Skip to content

Commit 60b93f8

Browse files
quic-boyuc and claude committed
observatory: json_report for accuracy + per_layer_accuracy lenses
Adds Frontend.json_report() overrides to the two demo lenses, producing machine-readable Report (JSON) content suited for CI, LLM triage, and regression detection.

accuracy (_AccuracyFrontend.json_report):
Aggregates primary metrics (psnr, mse, cosine_sim, top_k, ...) across session_records: mean, min, max, and worst_record. Internal _* keys, per-sample _min/_max stats, and _worst_idx indices are excluded.
worst_record semantics:
- for quality metrics (psnr, cosine_sim, top_k): worst_record = argmin (lower value = worse quality);
- for error metrics (mse, abs_err): worst_record = argmax (higher value = worse quality).

per_layer_accuracy (_PerLayerAccuracyFrontend.json_report):
Reports anchor/target record names, n_layers, sample_source, metric_ranges (from analyze() global_data), and worst_layers: top-N rows per metric sorted worst-first.
- psnr / cosine_sim: ascending sort (lower = worse).
- mse / abs_err: descending sort (higher = worse).
- Layer identity: uses from_node_root, falling back to target_node, matching the real row schema from observe().
- top_n: read from analysis.global_data["json_report_top_n"] (stored by analyze() from config["per_layer_accuracy"]["json_report_top_n"], default 10). No live config-stack access at call time.

Tests (test_json_report.py, 17 tests):
- Framework: invocation count, payload shape, no-ghost-keys for None returns, compare-mode archive grouping, export indentation, NaN survival.
- AccuracyLens: mean/min/max aggregation, internal-key exclusion, mse worst_record = argmax, records_measured count.
- PerLayerAccuracyLens: None when no data, sort direction for psnr (ascending) and mse (descending), metric_ranges propagation, top_n config knob via analyze().

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
1 parent d6b3391 commit 60b93f8

3 files changed

Lines changed: 619 additions & 20 deletions

File tree

devtools/observatory/lenses/accuracy.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,68 @@ def dashboard(self, session, session_records, analysis) -> Optional[ViewList]:
743743
]
744744
)
745745

746+
def json_report(self, session, session_records, analysis) -> Optional[Dict[str, Any]]:
    """Aggregate accuracy metrics across the session's records.

    Every record whose ``data["accuracy"]`` is a dict counts as measured.
    Each numeric (non-bool) primary metric in those digests contributes a
    ``mean``, ``min``, ``max`` and ``worst_record``:

    - error metrics (``mse``, ``abs_err``): ``worst_record`` is the name of
      the record with the highest value;
    - every other metric (``psnr``, ``cosine_sim``, ...): ``worst_record``
      is the name of the record with the lowest value.

    Keys starting with ``_`` and keys ending in ``_min``, ``_max`` or
    ``_worst_idx`` are excluded. On ties the earliest record wins.

    Non-finite values are tolerated: a metric that is ``inf`` in every
    record (e.g. PSNR of bit-exact outputs) reports ``inf`` for min/max
    rather than failing, and NaN samples never displace a real value as
    min/max/worst. The mean is a plain arithmetic mean, so it is NaN as
    soon as one sample is NaN.

    Args:
        session: Unused here.
        session_records: Iterable of records exposing ``.name`` and a
            ``.data`` mapping; ``None`` is treated as empty.
        analysis: Unused here.

    Returns:
        ``{"records_measured": int, "metrics": {name: {...}}}`` with metric
        names sorted alphabetically and values rounded to 4 decimals, or
        ``None`` when no record carries an accuracy digest.
    """
    import math

    # Error metrics: higher value = worse quality -> worst_record = argmax.
    # Quality metrics (psnr, cosine_sim, top_k, ...): lower = worse -> argmin.
    error_metrics = {"mse", "abs_err"}
    excluded_suffixes = ("_min", "_max", "_worst_idx")

    def _displaces(candidate: float, current: float, want_lower: bool) -> bool:
        """Return True when ``candidate`` should replace the running extreme."""
        if math.isnan(candidate):
            return False
        if math.isnan(current):
            # Any real number beats a NaN placeholder.
            return True
        return candidate < current if want_lower else candidate > current

    sums: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    mins: Dict[str, float] = {}
    maxs: Dict[str, float] = {}
    worst: Dict[str, str] = {}
    measured = 0

    for rec in session_records or []:
        digest = rec.data.get("accuracy")
        if not isinstance(digest, dict):
            continue
        measured += 1
        for key, raw in digest.items():
            if isinstance(raw, bool) or not isinstance(raw, (int, float)):
                continue
            if key.startswith("_") or key.endswith(excluded_suffixes):
                continue
            value = float(raw)
            sums[key] = sums.get(key, 0.0) + value
            counts[key] = counts.get(key, 0) + 1

            if key not in mins:
                # Seed from the first observation instead of +/-inf sentinels,
                # so inf/NaN-only metrics still get min/max/worst entries.
                mins[key] = maxs[key] = value
                worst[key] = rec.name
                continue

            is_error_metric = key in error_metrics
            if _displaces(value, mins[key], want_lower=True):
                mins[key] = value
                if not is_error_metric:
                    worst[key] = rec.name
            if _displaces(value, maxs[key], want_lower=False):
                maxs[key] = value
                if is_error_metric:
                    worst[key] = rec.name

    if measured == 0:
        return None

    return {
        "records_measured": measured,
        "metrics": {
            key: {
                "mean": round(sums[key] / counts[key], 4),
                "min": round(mins[key], 4),
                "max": round(maxs[key], 4),
                # Quality metrics: record holding the min.
                # Error metrics: record holding the max.
                "worst_record": worst[key],
            }
            for key in sorted(sums)
        },
    }
807+
746808
def record(
747809
self, digest: Any, analysis: Dict[str, Any], context: Dict[str, Any]
748810
) -> Optional[ViewList]:

devtools/observatory/lenses/per_layer_accuracy.py

Lines changed: 133 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,43 @@ def _metric_specs() -> Dict[str, Dict[str, Any]]:
566566
},
567567
}
568568

569+
@classmethod
def _build_context_extension(
    cls,
    rows: List[Dict[str, Any]],
) -> GraphExtension:
    """Build the ``ctx`` graph extension holding per-node identity fields.

    One node-data entry is added per row, keyed by ``str(row["target_node"])``
    (a row without ``target_node`` raises ``KeyError``). The entry carries the
    identity/context fields (match key, key kind, root name, anchor/target
    node, topo indices, compared element count and shapes) so they are stored
    once here rather than in each metric extension. The extension's sync key
    is ``sparse_match_key``.
    """
    # (output field, row key to read, default when the row lacks it)
    field_table = (
        ("sparse_match_key", "match_key", ""),
        ("key_kind", "key_kind", ""),
        ("from_node_root", "from_node_root", ""),
        ("anchor_node", "anchor_node", ""),
        ("target_node", "target_node", ""),
        ("anchor_topo_index", "anchor_topo_index", -1),
        ("target_topo_index", "target_topo_index", -1),
        ("numel_compared", "numel_compared", 0),
        ("anchor_shape", "anchor_shape", "n/a"),
        ("target_shape", "target_shape", "n/a"),
    )

    context_ext = GraphExtension(
        id="ctx",
        name="Per-Layer Accuracy Context",
    )
    for entry in rows:
        node_key = str(entry["target_node"])
        payload = {
            out_field: entry.get(row_field, fallback)
            for out_field, row_field, fallback in field_table
        }
        context_ext.add_node_data(node_key, payload)

    context_ext.set_sync_key("sparse_match_key")
    return context_ext
605+
569606
@classmethod
570607
def _build_metric_extension(
571608
cls,
@@ -581,23 +618,13 @@ def _build_metric_extension(
581618
ext = GraphExtension(id=metric_name, name=str(spec["name"]))
582619
for row in rows:
583620
node_id = str(row["target_node"])
584-
info = {
585-
"sparse_match_key": row.get("match_key", ""),
586-
"key_kind": row.get("key_kind", ""),
587-
"from_node_root": row.get("from_node_root", ""),
588-
"anchor_node": row.get("anchor_node", ""),
589-
"target_node": row.get("target_node", ""),
590-
"anchor_topo_index": row.get("anchor_topo_index", -1),
591-
"target_topo_index": row.get("target_topo_index", -1),
592-
"numel_compared": row.get("numel_compared", 0),
593-
"anchor_shape": row.get("anchor_shape", "n/a"),
594-
"target_shape": row.get("target_shape", "n/a"),
595-
"psnr": cls._safe_float(row.get("psnr", 0.0)),
596-
"cosine_sim": cls._safe_float(row.get("cosine_sim", 0.0)),
597-
"mse": cls._safe_float(row.get("mse", 0.0)),
598-
"abs_err": cls._safe_float(row.get("abs_err", 0.0)),
599-
}
600-
ext.add_node_data(node_id, info)
621+
ext.add_node_data(
622+
node_id,
623+
{
624+
"sparse_match_key": row.get("match_key", ""),
625+
metric_name: cls._safe_float(row.get(metric_name, 0.0)),
626+
},
627+
)
601628

602629
ext.set_sync_key("sparse_match_key")
603630

@@ -611,13 +638,16 @@ def _label_formatter(d: Dict[str, Any]) -> List[str]:
611638
primary_label = str(spec["label"])
612639
return [f"{primary_label}={_format_metric_value(primary)}"]
613640

641+
# All metric extensions opt into label formatters; the run-time LRU
642+
# in the JS canvas (cap = MAX_LABEL_EXTENSIONS = 2) keeps at most two
643+
# active at once, and the build-time bbox reservation in
644+
# FXGraphExporter sizes nodes for two label rows.
614645
ext.set_label_formatter(_label_formatter)
615646

616647
def _tooltip_formatter(d: Dict[str, Any]) -> List[str]:
617648
primary = cls._safe_float(d.get(metric_name, 0.0))
618649
primary_label = str(spec["label"])
619650
return [
620-
f"target_node={d.get('target_node', 'n/a')}",
621651
f"match_key={d.get('sparse_match_key', '')}",
622652
f"{primary_label}={_format_metric_value(primary, tooltip=True)}",
623653
]
@@ -674,6 +704,12 @@ def analyze(records: List[RecordDigest], config: Dict[str, Any]) -> AnalysisResu
674704
metric_ranges = PerLayerAccuracyLens._aggregate_metric_ranges(records)
675705
if metric_ranges:
676706
result.global_data["metric_ranges"] = metric_ranges
707+
# Stash json_report_top_n so json_report can read it from analysis.global_data
708+
# without touching the live config stack (which may be empty at render time).
709+
top_n = int(
710+
(config.get("per_layer_accuracy") or {}).get("json_report_top_n", 10)
711+
)
712+
result.global_data["json_report_top_n"] = top_n
677713

678714
for record in records:
679715
digest = record.data.get("per_layer_accuracy")
@@ -691,11 +727,18 @@ def analyze(records: List[RecordDigest], config: Dict[str, Any]) -> AnalysisResu
691727
}
692728
)
693729

730+
# Identity context, contributed once per node, regardless of which
731+
# metric layers the user has active. Replaces the per-metric-ext
732+
# duplication of identity fields.
733+
analysis.add_graph_layer(
734+
"ctx", PerLayerAccuracyLens._build_context_extension(rows)
735+
)
736+
694737
for metric_name in ("cosine_sim", "psnr", "mse", "abs_err"):
695738
r = metric_ranges.get(metric_name)
696739
fixed_range = (r[0], r[1]) if r else None
697740
metric_ext = PerLayerAccuracyLens._build_metric_extension(
698-
rows, metric_name, fixed_range=fixed_range
741+
rows, metric_name, fixed_range=fixed_range,
699742
)
700743
analysis.add_graph_layer(metric_name, metric_ext)
701744

@@ -924,7 +967,10 @@ def record(
924967
id="per_layer_accuracy_graph",
925968
title="Per-layer Accuracy Graph",
926969
graph_ref=graph_ref,
927-
default_layers=[f"{lens_name}/cosine_sim"],
970+
default_layers=[
971+
f"{lens_name}/ctx",
972+
f"{lens_name}/cosine_sim",
973+
],
928974
default_color_by=f"{lens_name}/cosine_sim",
929975
compare=GraphCompareSpec(
930976
default_sync={
@@ -955,6 +1001,73 @@ def check_badges(
9551001
]
9561002
return []
9571003

1004+
def json_report(
    self, session, session_records, analysis
) -> Optional[Dict[str, Any]]:
    """Machine-readable per-layer accuracy summary for Report (JSON).

    Uses the first record in ``session_records`` whose
    ``data["per_layer_accuracy"]`` is a dict with a non-empty ``rows`` list
    and returns:

    - ``anchor``: the digest's ``anchor_record`` value
    - ``target``: name of the record the rows were taken from
    - ``n_layers``: number of rows in the digest
    - ``sample_source``: the digest's ``sample_source`` value
    - ``metric_ranges``: ``analysis.global_data["metric_ranges"]`` (or ``{}``)
    - ``worst_layers``: {metric: [top-N rows sorted worst-first]}

    Ranking rules per metric:

    - ``psnr`` / ``cosine_sim``: lower is worse (ascending order);
    - ``mse`` / ``abs_err``: higher is worse (descending order);
    - rows whose value is NaN are listed first, because NaN cannot be
      ordered against real numbers and would otherwise make the ranking
      depend on input order;
    - rows whose value is missing, non-numeric or a bool are not ranked;
    - ties keep their original row order.

    Top-N depth is read from ``analysis.global_data["json_report_top_n"]``
    (default 10 when absent or ``None``); negative values are clamped to 0
    so they yield empty lists instead of slicing rows off the tail.

    Args:
        session: Unused here.
        session_records: Iterable of records exposing ``.name`` and a
            ``.data`` mapping; ``None`` is treated as empty.
        analysis: Object exposing ``global_data`` (a mapping or ``None``).

    Returns:
        The summary dict described above, or ``None`` when no record
        carries per-layer rows.
    """
    import math

    metric_names = ("psnr", "cosine_sim", "mse", "abs_err")
    higher_is_worse = {"mse", "abs_err"}

    def _is_number(value: Any) -> bool:
        """True for int/float metric values; bools are not metrics."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _worst_first_key(metric: str):
        """Build a sort key placing NaN first, then worst-to-best values."""
        descending = metric in higher_is_worse

        def _key(row: Dict[str, Any]):
            value = float(row[metric])
            if math.isnan(value):
                return (0, 0.0)
            # Negating instead of reverse=True keeps ties in row order.
            return (1, -value if descending else value)

        return _key

    target_record = next(
        (
            r
            for r in session_records or []
            if isinstance(r.data.get("per_layer_accuracy"), dict)
            and r.data["per_layer_accuracy"].get("rows")
        ),
        None,
    )
    if target_record is None:
        return None

    digest = target_record.data["per_layer_accuracy"]
    rows = digest.get("rows") or []
    global_data = analysis.global_data or {}
    metric_ranges = global_data.get("metric_ranges") or {}
    raw_top_n = global_data.get("json_report_top_n")
    top_n = 10 if raw_top_n is None else max(0, int(raw_top_n))

    worst: Dict[str, Any] = {}
    for metric in metric_names:
        candidates = [r for r in rows if _is_number(r.get(metric))]
        ranked = sorted(candidates, key=_worst_first_key(metric))
        entry = []
        for r in ranked[:top_n]:
            # Layer identity: prefer the human-readable from_node_root;
            # fall back to the target_node id for unrooted nodes.
            layer_id = r.get("from_node_root") or r.get("target_node", "?")
            summary: Dict[str, Any] = {"layer": layer_id}
            for m in metric_names:
                v = r.get(m)
                if _is_number(v):
                    summary[m] = round(float(v), 4)
            entry.append(summary)
        worst[metric] = entry

    return {
        "anchor": digest.get("anchor_record"),
        "target": target_record.name,
        "n_layers": len(rows),
        "sample_source": digest.get("sample_source"),
        "metric_ranges": metric_ranges,
        "worst_layers": worst,
    }
1070+
9581071
@staticmethod
9591072
def get_frontend_spec() -> Frontend:
9601073
return PerLayerAccuracyLens._PerLayerAccuracyFrontend()

0 commit comments

Comments
 (0)