|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
| 7 | +from dataclasses import fields |
| 8 | + |
7 | 9 | from executorch.devtools.observatory.graph_hub import GraphHub |
8 | | -from executorch.devtools.observatory.interfaces import RecordAnalysis |
| 10 | +from executorch.devtools.observatory.interfaces import ( |
| 11 | + GraphLayerContribution, |
| 12 | + RecordAnalysis, |
| 13 | +) |
9 | 14 | from executorch.devtools.fx_viewer import ( |
| 15 | + GraphExtension, |
10 | 16 | GraphExtensionNodePayload, |
11 | 17 | GraphExtensionPayload, |
12 | 18 | ) |
@@ -54,3 +60,189 @@ def test_build_viewer_payload() -> None: |
54 | 60 | payload = GraphHub.build_viewer_payload(graph_assets, graph_layers, "r1") |
55 | 61 | assert payload["base"]["nodes"][0]["id"] == "a" |
56 | 62 | assert "x/y" in payload["extensions"] |
| 63 | + |
| 64 | + |
| 65 | +# --------------------------------------------------------------------------- |
| 66 | +# Python ↔ JS extension-payload schema contract |
| 67 | +# |
| 68 | +# JS GraphDataStore reads `name`, `legend`, `nodes`, `sync_keys`, and |
| 69 | +# `has_label_formatter` off `payload.extensions[extId]`. The JS LRU and the |
| 70 | +# layer panel's "L" badge depend on `has_label_formatter`, so a silent drop |
| 71 | +# of any of these fields breaks node-label rendering on canvas. |
| 72 | +# --------------------------------------------------------------------------- |
| 73 | + |
| 74 | +# Fields that GraphHub MUST forward verbatim from a GraphExtensionPayload into |
| 75 | +# the per-layer dict that lands in `payload.extensions[extId]` for the JS |
| 76 | +# runtime. `id` is intentionally excluded — it lives in the dict key. |
| 77 | +_REQUIRED_EXTENSION_DICT_FIELDS = frozenset( |
| 78 | + f.name for f in fields(GraphExtensionPayload) if f.name != "id" |
| 79 | +) |
| 80 | + |
| 81 | + |
| 82 | +def _make_label_extension( |
| 83 | + ext_id: str = "m", *, with_label: bool = True |
| 84 | +) -> GraphExtension: |
| 85 | + ext = GraphExtension(id=ext_id, name="Metric") |
| 86 | + ext.add_node_data("n0", {"foo": 1.0}) |
| 87 | + if with_label: |
| 88 | + ext.set_label_formatter(lambda d: [f"foo={d.get('foo', 0):.2f}"]) |
| 89 | + return ext |
| 90 | + |
| 91 | + |
| 92 | +def test_graph_hub_preserves_has_label_formatter() -> None: |
| 93 | + """JS reads `has_label_formatter` to gate label rendering and the L badge.""" |
| 94 | + hub = GraphHub() |
| 95 | + hub.register_asset("r0", {"legend": [], "nodes": [{"id": "n0"}], "edges": []}, {}) |
| 96 | + analysis = RecordAnalysis() |
| 97 | + analysis.add_graph_layer("m", _make_label_extension()) |
| 98 | + |
| 99 | + hub.add_analysis_layers("r0", "lens", analysis) |
| 100 | + |
| 101 | + slot = hub.build_payload()["graph_layers"]["r0"]["lens/m"] |
| 102 | + assert slot.get("has_label_formatter") is True |
| 103 | + |
| 104 | + |
| 105 | +def test_graph_hub_no_label_formatter_emits_false() -> None: |
| 106 | + """Layer without a label formatter must serialize an explicit `False`.""" |
| 107 | + hub = GraphHub() |
| 108 | + hub.register_asset("r0", {"legend": [], "nodes": [{"id": "n0"}], "edges": []}, {}) |
| 109 | + analysis = RecordAnalysis() |
| 110 | + analysis.add_graph_layer("m", _make_label_extension(with_label=False)) |
| 111 | + |
| 112 | + hub.add_analysis_layers("r0", "lens", analysis) |
| 113 | + |
| 114 | + slot = hub.build_payload()["graph_layers"]["r0"]["lens/m"] |
| 115 | + assert slot.get("has_label_formatter") is False |
| 116 | + |
| 117 | + |
| 118 | +def test_graph_hub_layer_dict_has_full_extension_contract() -> None: |
| 119 | + """All non-id GraphExtensionPayload fields must round-trip through GraphHub. |
| 120 | +
|
| 121 | + Adding a new field to GraphExtensionPayload without forwarding it through |
| 122 | + GraphHub silently breaks the JS contract. This test pins the contract. |
| 123 | + """ |
| 124 | + hub = GraphHub() |
| 125 | + hub.register_asset("r0", {"legend": [], "nodes": [{"id": "n0"}], "edges": []}, {}) |
| 126 | + payload = GraphExtensionPayload( |
| 127 | + id="m", |
| 128 | + name="Metric", |
| 129 | + legend=[{"label": "L", "color": "#000"}], |
| 130 | + sync_keys=["debug_handle"], |
| 131 | + has_label_formatter=True, |
| 132 | + nodes={"n0": GraphExtensionNodePayload(label_append=["x=1"])}, |
| 133 | + ) |
| 134 | + analysis = RecordAnalysis() |
| 135 | + analysis.add_graph_layer("m", payload) |
| 136 | + hub.add_analysis_layers("r0", "lens", analysis) |
| 137 | + |
| 138 | + slot = hub.build_payload()["graph_layers"]["r0"]["lens/m"] |
| 139 | + missing = _REQUIRED_EXTENSION_DICT_FIELDS - set(slot) |
| 140 | + assert not missing, f"GraphHub dropped fields: {sorted(missing)}" |
| 141 | + |
| 142 | + |
| 143 | +def test_graph_hub_node_payload_dict_has_full_node_contract() -> None: |
| 144 | + """Per-node extension entries must carry every GraphExtensionNodePayload field.""" |
| 145 | + hub = GraphHub() |
| 146 | + hub.register_asset("r0", {"legend": [], "nodes": [{"id": "n0"}], "edges": []}, {}) |
| 147 | + payload = GraphExtensionPayload( |
| 148 | + id="m", |
| 149 | + name="Metric", |
| 150 | + nodes={ |
| 151 | + "n0": GraphExtensionNodePayload( |
| 152 | + info={"k": 1}, |
| 153 | + tooltip=["t"], |
| 154 | + label_append=["lbl"], |
| 155 | + fill_color="#abc", |
| 156 | + ) |
| 157 | + }, |
| 158 | + ) |
| 159 | + analysis = RecordAnalysis() |
| 160 | + analysis.add_graph_layer("m", payload) |
| 161 | + hub.add_analysis_layers("r0", "lens", analysis) |
| 162 | + |
| 163 | + node_dict = hub.build_payload()["graph_layers"]["r0"]["lens/m"]["nodes"]["n0"] |
| 164 | + expected_fields = {f.name for f in fields(GraphExtensionNodePayload)} |
| 165 | + missing = expected_fields - set(node_dict) |
| 166 | + assert not missing, f"node payload dropped fields: {sorted(missing)}" |
| 167 | + |
| 168 | + |
| 169 | +def test_to_payload_with_overrides_preserves_has_label_formatter() -> None: |
| 170 | + """`id_override`/`name_override` must not strip the rest of the contract. |
| 171 | +
|
| 172 | + The override branch in interfaces.GraphLayerContribution.to_payload re- |
| 173 | + constructs a GraphExtensionPayload; if it omits a field, JS loses it. |
| 174 | + """ |
| 175 | + contribution = GraphLayerContribution( |
| 176 | + extension=_make_label_extension(), |
| 177 | + id_override="renamed_id", |
| 178 | + name_override="Renamed", |
| 179 | + ) |
| 180 | + payload = contribution.to_payload() |
| 181 | + |
| 182 | + assert payload.id == "renamed_id" |
| 183 | + assert payload.name == "Renamed" |
| 184 | + assert payload.has_label_formatter is True |
| 185 | + |
| 186 | + |
| 187 | +def test_to_payload_with_overrides_preserves_sync_keys() -> None: |
| 188 | + payload_in = GraphExtensionPayload( |
| 189 | + id="m", |
| 190 | + name="Metric", |
| 191 | + sync_keys=["debug_handle", "from_node"], |
| 192 | + has_label_formatter=True, |
| 193 | + nodes={"n0": GraphExtensionNodePayload(label_append=["x"])}, |
| 194 | + ) |
| 195 | + contribution = GraphLayerContribution( |
| 196 | + extension=payload_in, |
| 197 | + id_override="renamed", |
| 198 | + ) |
| 199 | + payload_out = contribution.to_payload() |
| 200 | + |
| 201 | + assert payload_out.sync_keys == ["debug_handle", "from_node"] |
| 202 | + assert payload_out.has_label_formatter is True |
| 203 | + |
| 204 | + |
| 205 | +def test_graph_hub_preserves_has_label_formatter_through_overrides() -> None: |
| 206 | + """End-to-end: lens contributes with overrides → JS-bound dict still flagged.""" |
| 207 | + hub = GraphHub() |
| 208 | + hub.register_asset("r0", {"legend": [], "nodes": [{"id": "n0"}], "edges": []}, {}) |
| 209 | + analysis = RecordAnalysis() |
| 210 | + analysis.add_graph_layer( |
| 211 | + "m", |
| 212 | + _make_label_extension(), |
| 213 | + id_override="metric_v2", |
| 214 | + name_override="Metric V2", |
| 215 | + ) |
| 216 | + |
| 217 | + hub.add_analysis_layers("r0", "lens", analysis) |
| 218 | + |
| 219 | + slot = hub.build_payload()["graph_layers"]["r0"]["lens/m"] |
| 220 | + assert slot.get("name") == "Metric V2" |
| 221 | + assert slot.get("has_label_formatter") is True |
| 222 | + |
| 223 | + |
| 224 | +def test_build_viewer_payload_includes_layout_constants() -> None: |
| 225 | + """JS reads `payload.layout` for line_height / max_label_extensions / etc. |
| 226 | +
|
| 227 | + Without this, JS silently uses the fallback constants in graph_data_store.js; |
| 228 | + those happen to match today's Python defaults but will drift on first |
| 229 | + Python-side change to `_NODE_LINE_HEIGHT` or `_MAX_LABEL_EXTENSIONS`. |
| 230 | + """ |
| 231 | + from executorch.devtools.fx_viewer.exporter import FXGraphExporter |
| 232 | + |
| 233 | + graph_assets = { |
| 234 | + "r1": {"base": {"legend": [], "nodes": [{"id": "a"}], "edges": []}, "meta": {}} |
| 235 | + } |
| 236 | + payload = GraphHub.build_viewer_payload(graph_assets, {"r1": {}}, "r1") |
| 237 | + |
| 238 | + assert "layout" in payload |
| 239 | + expected = FXGraphExporter._layout_constants_payload() |
| 240 | + for key in ("line_height", "y_padding", "max_label_extensions", "base_font_px"): |
| 241 | + assert payload["layout"][key] == expected[key] |
| 242 | + |
| 243 | + |
| 244 | +def test_build_viewer_payload_layout_present_for_missing_graph_ref() -> None: |
| 245 | + """Empty asset path must still emit layout — JS fallback path shouldn't kick in.""" |
| 246 | + payload = GraphHub.build_viewer_payload({}, {}, "missing") |
| 247 | + assert "layout" in payload |
| 248 | + assert "max_label_extensions" in payload["layout"] |
0 commit comments