Skip to content

Commit a66f82a

Browse files
quic-boyucclaude
andcommitted
fx_viewer: label budget, bbox shrink, info-panel dedup + LRU
Three interrelated fixes for per-layer-accuracy usability in Observatory: **Build-time label cap (exporter.py)** - Add `_MAX_LABEL_EXTENSIONS = 2` (cap on label rows reserved in the layout slot) and `_layout_constants_payload()` which emits line_height, y_padding, font sizes, and `max_label_extensions` into `payload.layout` so the JS runtime reads the same values Python used for layout. - `_select_top_label_extensions()`: for each node, picks the 2 extensions with the largest label text to reserve bbox space, keeping the reserved slot bounded to 2 extra rows regardless of how many label-bearing layers exist. - `_compute_layout` and `relayout_payload_base` both use the new selector. - `GraphExtensionPayload.has_label_formatter` flag (models.py + extension.py) lets JS identify which extensions are label-bearing without inspecting nodes. - `GraphPayload.layout` field carries the constants block. **JS LRU queue (graph_data_store.js + view_controller.js + ui_manager.js)** - `GraphDataStore` gets `labelLru` (Set, insertion-order LRU, cap from `layoutConstants.max_label_extensions`) and `recordLabelActivation / recordLabelDeactivation / extensionHasLabels` methods. - `upsertExtension` preserves `has_label_formatter` and `sync_keys`. - `computeActiveGraph` gates `label_append` on LRU membership; non-LRU labeled extensions still contribute color/info/tooltip. - `ViewerController.setState` diffs prev/next active set, calling `recordLabelActivation` for newly-activated label-bearing exts (LRU evicts oldest on overflow) and `recordLabelDeactivation` for newly-removed. - Seeded at construction time so default layers render labels on first paint. - Layer panel shows a small "L" badge on the ≤2 extensions whose labels are currently visible; badge synced in `syncControlsFromState`. **Canvas visual-rect shrink + edge re-anchor (canvas_renderer.js)** - `_visualHeight(node)`: `Math.max(floor, Math.min(desired, node.height))`. Hard-clamped to Python-reserved `node.height` so no rendering can exceed layout. Floor ensures a rect is always drawn. - `_visualBottom(node)`: `node.y - node.height/2 + _visualHeight(node)`. Top of rect is fixed at the layout-slot top; bottom floats. - `_effectiveSourceStart(edge)`: substitutes the polyline's first point y with source's `_visualBottom`, re-anchoring outgoing edges to the visible bottom of the drawn rect. Target-side (incoming) endpoints unchanged. - All node draw calls (fillRect/strokeRect), label positioning, highlight overlays, and hit-testing updated to use `_visualHeight`/`_visualBottom`/ `_effectiveSourceStart`. - `edge.bounds.minY` widened at `_initTopology` time by the worst-case source-endpoint substitution delta so AABB hover-test covers the new range. - JS magic literals (16, 14px, 12px) replaced with `layoutConstants.*`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 54dfde3 commit a66f82a

7 files changed

Lines changed: 264 additions & 37 deletions

File tree

devtools/fx_viewer/exporter.py

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,18 @@ def __init__(self, graph_module: torch.fx.GraphModule):
5353
_NODE_X_PADDING = 20
5454
_NODE_LINE_HEIGHT = 16
5555
_NODE_Y_PADDING = 20
56+
_NODE_BASE_FONT_PX = 14
57+
_NODE_EXT_FONT_PX = 12
5658
_LAYOUT_XSPACE = 50
5759
_LAYOUT_YSPACE = 30
5860
_DUMMY_SIZE_X = 100 # dummy nodes (from fast-sugiyama) occupy no real width/height
5961
_DUMMY_SIZE_Y = 30 # dummy nodes (from fast-sugiyama) occupy no real width/height
6062
_SPINE_COHESION_ITER = 20
63+
# JS canvas LRU enforces the same cap. Layout reserves vertical space for
64+
# the two extensions whose label formatters produce the most text per node,
65+
# so any active subset of size <= _MAX_LABEL_EXTENSIONS fits without
66+
# overlap. The constant is also emitted to the JS via _layout_constants_payload.
67+
_MAX_LABEL_EXTENSIONS = 2
6168

6269
def _default_base_label(self, node: GraphNode) -> str:
6370
target = str(node.info.get("target") or node.info.get("op") or "")
@@ -237,6 +244,55 @@ def _ext_label_lines_for_layout(self, extension: GraphExtension, node_id: str) -
237244
context=f"extension '{extension.id}' label formatter(node='{node_id}')",
238245
)
239246

247+
@classmethod
248+
def _select_top_label_extensions(
249+
cls,
250+
per_ext_lines: dict[str, list[str]],
251+
) -> list[str]:
252+
"""Pick label lines from at most ``_MAX_LABEL_EXTENSIONS`` extensions.
253+
254+
For layout sizing only. Selects the extensions whose label formatter
255+
produces the largest text (sum of character lengths across the
256+
extension's lines for that node) and concatenates their lines. The JS
257+
canvas LRU caps the same number of extensions at run time, so any
258+
active subset fits within the reserved bbox.
259+
"""
260+
if not per_ext_lines:
261+
return []
262+
if len(per_ext_lines) <= cls._MAX_LABEL_EXTENSIONS:
263+
flat: list[str] = []
264+
for lines in per_ext_lines.values():
265+
flat.extend(lines)
266+
return flat
267+
ranked = sorted(
268+
per_ext_lines.values(),
269+
key=lambda lns: sum(len(s) for s in lns),
270+
reverse=True,
271+
)[: cls._MAX_LABEL_EXTENSIONS]
272+
flat = []
273+
for lines in ranked:
274+
flat.extend(lines)
275+
return flat
276+
277+
@classmethod
278+
def _layout_constants_payload(cls) -> dict[str, Any]:
279+
"""Layout/font constants embedded into ``payload.layout`` for the JS runtime.
280+
281+
Single source of truth for sizing. Any drift between Python build-time
282+
layout and JS render-time drawing is eliminated by reading these values
283+
on the JS side instead of hard-coding them in the canvas renderer.
284+
"""
285+
return {
286+
"line_height": cls._NODE_LINE_HEIGHT,
287+
"y_padding": cls._NODE_Y_PADDING,
288+
"x_padding": cls._NODE_X_PADDING,
289+
"char_width": cls._NODE_CHAR_WIDTH,
290+
"min_width": cls._NODE_MIN_WIDTH,
291+
"base_font_px": cls._NODE_BASE_FONT_PX,
292+
"ext_font_px": cls._NODE_EXT_FONT_PX,
293+
"max_label_extensions": cls._MAX_LABEL_EXTENSIONS,
294+
}
295+
240296
@classmethod
241297
def _compute_node_box_size(
242298
cls,
@@ -619,10 +675,14 @@ def _segment_crosses_aabb(
619675
def _compute_layout(self, nodes: dict[str, GraphNode], edges: list[GraphEdge]) -> None:
620676
ext_label_lines_by_node: dict[str, list[str]] = {}
621677
for node_id in nodes:
622-
ext_lines: list[str] = []
678+
per_ext_lines: dict[str, list[str]] = {}
623679
for ext in self.extensions:
624-
ext_lines.extend(self._ext_label_lines_for_layout(ext, node_id))
625-
ext_label_lines_by_node[node_id] = ext_lines
680+
lines = self._ext_label_lines_for_layout(ext, node_id)
681+
if lines:
682+
per_ext_lines[ext.id] = lines
683+
ext_label_lines_by_node[node_id] = self._select_top_label_extensions(
684+
per_ext_lines
685+
)
626686

627687
self._compute_layout_with_ext_lines(
628688
nodes,
@@ -695,7 +755,9 @@ def relayout_payload_base(
695755
else:
696756
active_layer_ids = [layer_id for layer_id in include_layers if layer_id in ext_payloads]
697757

698-
ext_label_lines_by_node: dict[str, list[str]] = {node_id: [] for node_id in nodes}
758+
per_node_per_ext: dict[str, dict[str, list[str]]] = {
759+
node_id: {} for node_id in nodes
760+
}
699761
for layer_id in active_layer_ids:
700762
layer_payload = ext_payloads.get(layer_id)
701763
if not isinstance(layer_payload, dict):
@@ -704,11 +766,16 @@ def relayout_payload_base(
704766
if not isinstance(layer_nodes, dict):
705767
continue
706768
for node_id, node_payload in layer_nodes.items():
707-
if node_id not in ext_label_lines_by_node or not isinstance(node_payload, dict):
769+
if node_id not in per_node_per_ext or not isinstance(node_payload, dict):
708770
continue
709-
ext_label_lines_by_node[node_id].extend(
710-
cls._coerce_str_lines(node_payload.get("label_append"))
711-
)
771+
lines = cls._coerce_str_lines(node_payload.get("label_append"))
772+
if lines:
773+
per_node_per_ext[node_id][layer_id] = lines
774+
775+
ext_label_lines_by_node: dict[str, list[str]] = {
776+
node_id: cls._select_top_label_extensions(per_ext_lines)
777+
for node_id, per_ext_lines in per_node_per_ext.items()
778+
}
712779

713780
cls._compute_layout_with_ext_lines(
714781
nodes,
@@ -769,7 +836,11 @@ def generate_json_payload(self) -> Dict[str, Any]:
769836
self._compute_layout(nodes, edges)
770837
base_payload = self._build_base_payload(nodes, edges)
771838
extensions_payload = self._build_extensions_payload()
772-
payload = GraphPayload(base=base_payload, extensions=extensions_payload)
839+
payload = GraphPayload(
840+
base=base_payload,
841+
extensions=extensions_payload,
842+
layout=self._layout_constants_payload(),
843+
)
773844
return asdict(payload)
774845

775846
def export_json(self, output_path: str):

devtools/fx_viewer/extension.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def build_payload(self) -> GraphExtensionPayload:
129129
legend=legend,
130130
nodes=compiled_nodes,
131131
sync_keys=list(self.sync_keys),
132+
has_label_formatter=self.label_formatter is not None,
132133
)
133134

134135
def build(self) -> Dict[str, Any]:

devtools/fx_viewer/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,11 @@ class GraphExtensionPayload:
5757
legend: list[dict[str, str]] = field(default_factory=list)
5858
nodes: dict[str, GraphExtensionNodePayload] = field(default_factory=dict)
5959
sync_keys: list[str] = field(default_factory=list)
60+
has_label_formatter: bool = False
6061

6162

6263
@dataclass
6364
class GraphPayload:
6465
base: BaseGraphPayload
6566
extensions: dict[str, GraphExtensionPayload]
67+
layout: dict[str, Any] = field(default_factory=dict)

devtools/fx_viewer/templates/canvas_renderer.js

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -135,18 +135,20 @@ class CanvasRenderer {
135135
detectHover(graphX, graphY) {
136136
let nearestNode = null;
137137
let nearestEdge = null;
138-
138+
139139
for (let i = 0; i < this.viewer.store.baseData.nodes.length; i++) {
140140
const node = this.viewer.store.baseData.nodes[i];
141+
const active = this.viewer.store.activeNodeMap.get(node.id) || node;
141142
const w = node.width;
142-
const h = node.height;
143+
const h = this._visualHeight(active);
144+
const top = node.y - node.height / 2;
143145
if (graphX >= node.x - w/2 && graphX <= node.x + w/2 &&
144-
graphY >= node.y - h/2 && graphY <= node.y + h/2) {
146+
graphY >= top && graphY <= top + h) {
145147
nearestNode = node.id;
146-
break;
148+
break;
147149
}
148150
}
149-
151+
150152
if (!nearestNode) {
151153
const transform = this.viewer.controller.transform;
152154
const hoverDist = 5 / transform.k;
@@ -160,7 +162,11 @@ class CanvasRenderer {
160162
const w = this.viewer.store.activeNodeMap.get(edge.w);
161163
let min_d = Infinity;
162164
if (edge.points && edge.points.length > 0) {
163-
for (let j = 0; j < edge.points.length - 1; j++) {
165+
const start = this._effectiveSourceStart(edge) || edge.points[0];
166+
const a = start;
167+
const b = edge.points.length > 1 ? edge.points[1] : edge.points[0];
168+
min_d = Math.min(min_d, this.distToSegment({x: graphX, y: graphY}, a, b));
169+
for (let j = 1; j < edge.points.length - 1; j++) {
164170
const d = this.distToSegment({x: graphX, y: graphY}, edge.points[j], edge.points[j+1]);
165171
min_d = Math.min(min_d, d);
166172
}
@@ -185,6 +191,42 @@ class CanvasRenderer {
185191
return Math.hypot(p.x - (v.x + t * (w.x - v.x)), p.y - (v.y + t * (w.y - v.y)));
186192
}
187193

194+
/**
195+
* Visual height of a node's drawn rect. Sized for the actual count of
196+
* label_append entries (LRU-filtered upstream in graph_data_store), but
197+
* clamped to never exceed the Python-reserved layout slot (`node.height`).
198+
* Floor at one base label line so empty-label nodes still draw a sane rect.
199+
*/
200+
_visualHeight(node) {
201+
const layout = this.viewer.store.layoutConstants;
202+
const labelLines = (node && node.label_append) ? node.label_append.length : 0;
203+
const desired = (1 + labelLines) * layout.line_height + layout.y_padding;
204+
const floor = layout.line_height + layout.y_padding;
205+
const ceiling = node && node.height ? node.height : desired;
206+
return Math.max(floor, Math.min(desired, ceiling));
207+
}
208+
209+
/**
210+
* Y-coordinate of the bottom edge of the drawn rect. Top of the rect is
211+
* the layout-authoritative `node.y - node.height/2`; bottom floats with
212+
* the active label count. Used to re-anchor source-side edge endpoints.
213+
*/
214+
_visualBottom(node) {
215+
return node.y - node.height / 2 + this._visualHeight(node);
216+
}
217+
218+
/**
219+
* Source-side edge attach point: the polyline's first waypoint with the y
220+
* shifted to the source node's current visual bottom. Target-side
221+
* endpoint is unchanged (top of next-layer node is fixed).
222+
*/
223+
_effectiveSourceStart(edge) {
224+
if (!edge.points || edge.points.length === 0) return null;
225+
const source = this.viewer.store.activeNodeMap.get(edge.v);
226+
if (!source) return { x: edge.points[0].x, y: edge.points[0].y };
227+
return { x: edge.points[0].x, y: this._visualBottom(source) };
228+
}
229+
188230
render() {
189231
const dpr = window.devicePixelRatio || 1;
190232
const ctx = this.ctx;
@@ -274,7 +316,8 @@ class CanvasRenderer {
274316
ctx.beginPath();
275317
let midX = 0, midY = 0;
276318
if (edge.points && edge.points.length > 0) {
277-
ctx.moveTo(edge.points[0].x, edge.points[0].y);
319+
const start = this._effectiveSourceStart(edge) || edge.points[0];
320+
ctx.moveTo(start.x, start.y);
278321
for (let i = 1; i < edge.points.length; i++) {
279322
ctx.lineTo(edge.points[i].x, edge.points[i].y);
280323
}
@@ -387,8 +430,11 @@ class CanvasRenderer {
387430
ctx.fillStyle = renderedFill;
388431
}
389432

390-
ctx.fillRect(node.x - node.width/2, node.y - node.height/2, node.width, node.height);
391-
433+
const layout = this.viewer.store.layoutConstants;
434+
const visH = this._visualHeight(node);
435+
const rectTop = node.y - node.height / 2;
436+
ctx.fillRect(node.x - node.width/2, rectTop, node.width, visH);
437+
392438
if (isSelected || isPreview || isHovered) {
393439
ctx.strokeStyle = theme.edgeHover;
394440
if (isSelected || isPreview) {
@@ -401,24 +447,27 @@ class CanvasRenderer {
401447
} else {
402448
ctx.setLineDash([]);
403449
}
404-
ctx.strokeRect(node.x - node.width/2, node.y - node.height/2, node.width, node.height);
450+
ctx.strokeRect(node.x - node.width/2, rectTop, node.width, visH);
405451
ctx.setLineDash([]);
406452
}
407-
453+
408454
ctx.fillStyle = node.fill_color
409455
? (fxReadableTextColor(renderedFill) || theme.text)
410456
: theme.text;
411457
let allLines = [node.label || node.id];
412458
if (node.label_append && node.label_append.length > 0) {
413459
allLines = allLines.concat(node.label_append);
414460
}
415-
416-
const lineHeight = 16;
417-
const startY = node.y - ((allLines.length - 1) * lineHeight) / 2;
418-
461+
462+
const lineHeight = layout.line_height;
463+
// Anchor label rows to the top of the visible rect; previous code
464+
// centered around node.y, which drifted off the visible rect once
465+
// the rect started shrinking from full layout slot.
466+
const startY = rectTop + (layout.y_padding / 2) + lineHeight / 2;
467+
419468
for (let i = 0; i < allLines.length; i++) {
420-
if (i === 0) ctx.font = 'bold 14px sans-serif';
421-
else ctx.font = '12px sans-serif';
469+
if (i === 0) ctx.font = `bold ${layout.base_font_px}px sans-serif`;
470+
else ctx.font = `${layout.ext_font_px}px sans-serif`;
422471
ctx.fillText(allLines[i], node.x, startY + (i * lineHeight));
423472
}
424473

@@ -441,11 +490,13 @@ class CanvasRenderer {
441490
nodeIds.forEach((id) => {
442491
const node = this.viewer.store.activeNodeMap.get(id);
443492
if (!node) return;
493+
const visH = this._visualHeight(node);
494+
const rectTop = node.y - node.height / 2;
444495
ctx.strokeRect(
445496
node.x - node.width / 2 - outerOffset,
446-
node.y - node.height / 2 - outerOffset,
497+
rectTop - outerOffset,
447498
node.width + outerOffset * 2,
448-
node.height + outerOffset * 2
499+
visH + outerOffset * 2
449500
);
450501
});
451502
});
@@ -534,7 +585,7 @@ class CanvasRenderer {
534585

535586
const padding = 8 / transform.k;
536587
const tw = maxW + padding * 2;
537-
const lineHeight = 16 / transform.k;
588+
const lineHeight = this.viewer.store.layoutConstants.line_height / transform.k;
538589
const th = (tooltipLines.length * lineHeight) + padding * 2;
539590

540591
const viewLeft = -transform.x / transform.k;

0 commit comments

Comments
 (0)