@@ -566,6 +566,43 @@ def _metric_specs() -> Dict[str, Dict[str, Any]]:
566566 },
567567 }
568568
569+ @classmethod
570+ def _build_context_extension (
571+ cls ,
572+ rows : List [Dict [str , Any ]],
573+ ) -> GraphExtension :
574+ """Umbrella extension carrying identity/context fields once per node.
575+
576+ Each metric extension stores only its own value, so the per-node
577+ identity (target_node, anchor_node, shapes, topo indices, key_kind,
578+ from_node_root, numel_compared, sparse_match_key) lives here exactly
579+ once instead of being duplicated four times across the metric
580+ extensions.
581+ """
582+ ext = GraphExtension (
583+ id = "ctx" ,
584+ name = "Per-Layer Accuracy Context" ,
585+ )
586+ for row in rows :
587+ node_id = str (row ["target_node" ])
588+ ext .add_node_data (
589+ node_id ,
590+ {
591+ "sparse_match_key" : row .get ("match_key" , "" ),
592+ "key_kind" : row .get ("key_kind" , "" ),
593+ "from_node_root" : row .get ("from_node_root" , "" ),
594+ "anchor_node" : row .get ("anchor_node" , "" ),
595+ "target_node" : row .get ("target_node" , "" ),
596+ "anchor_topo_index" : row .get ("anchor_topo_index" , - 1 ),
597+ "target_topo_index" : row .get ("target_topo_index" , - 1 ),
598+ "numel_compared" : row .get ("numel_compared" , 0 ),
599+ "anchor_shape" : row .get ("anchor_shape" , "n/a" ),
600+ "target_shape" : row .get ("target_shape" , "n/a" ),
601+ },
602+ )
603+ ext .set_sync_key ("sparse_match_key" )
604+ return ext
605+
569606 @classmethod
570607 def _build_metric_extension (
571608 cls ,
@@ -581,23 +618,13 @@ def _build_metric_extension(
581618 ext = GraphExtension (id = metric_name , name = str (spec ["name" ]))
582619 for row in rows :
583620 node_id = str (row ["target_node" ])
584- info = {
585- "sparse_match_key" : row .get ("match_key" , "" ),
586- "key_kind" : row .get ("key_kind" , "" ),
587- "from_node_root" : row .get ("from_node_root" , "" ),
588- "anchor_node" : row .get ("anchor_node" , "" ),
589- "target_node" : row .get ("target_node" , "" ),
590- "anchor_topo_index" : row .get ("anchor_topo_index" , - 1 ),
591- "target_topo_index" : row .get ("target_topo_index" , - 1 ),
592- "numel_compared" : row .get ("numel_compared" , 0 ),
593- "anchor_shape" : row .get ("anchor_shape" , "n/a" ),
594- "target_shape" : row .get ("target_shape" , "n/a" ),
595- "psnr" : cls ._safe_float (row .get ("psnr" , 0.0 )),
596- "cosine_sim" : cls ._safe_float (row .get ("cosine_sim" , 0.0 )),
597- "mse" : cls ._safe_float (row .get ("mse" , 0.0 )),
598- "abs_err" : cls ._safe_float (row .get ("abs_err" , 0.0 )),
599- }
600- ext .add_node_data (node_id , info )
621+ ext .add_node_data (
622+ node_id ,
623+ {
624+ "sparse_match_key" : row .get ("match_key" , "" ),
625+ metric_name : cls ._safe_float (row .get (metric_name , 0.0 )),
626+ },
627+ )
601628
602629 ext .set_sync_key ("sparse_match_key" )
603630
@@ -611,13 +638,16 @@ def _label_formatter(d: Dict[str, Any]) -> List[str]:
611638 primary_label = str (spec ["label" ])
612639 return [f"{ primary_label } ={ _format_metric_value (primary )} " ]
613640
641+ # All metric extensions opt into label formatters; the run-time LRU
642+ # in the JS canvas (cap = MAX_LABEL_EXTENSIONS = 2) keeps at most two
643+ # active at once, and the build-time bbox reservation in
644+ # FXGraphExporter sizes nodes for two label rows.
614645 ext .set_label_formatter (_label_formatter )
615646
616647 def _tooltip_formatter (d : Dict [str , Any ]) -> List [str ]:
617648 primary = cls ._safe_float (d .get (metric_name , 0.0 ))
618649 primary_label = str (spec ["label" ])
619650 return [
620- f"target_node={ d .get ('target_node' , 'n/a' )} " ,
621651 f"match_key={ d .get ('sparse_match_key' , '' )} " ,
622652 f"{ primary_label } ={ _format_metric_value (primary , tooltip = True )} " ,
623653 ]
@@ -674,6 +704,12 @@ def analyze(records: List[RecordDigest], config: Dict[str, Any]) -> AnalysisResu
674704 metric_ranges = PerLayerAccuracyLens ._aggregate_metric_ranges (records )
675705 if metric_ranges :
676706 result .global_data ["metric_ranges" ] = metric_ranges
707+ # Stash json_report_top_n so json_report can read it from analysis.global_data
708+ # without touching the live config stack (which may be empty at render time).
709+ top_n = int (
710+ (config .get ("per_layer_accuracy" ) or {}).get ("json_report_top_n" , 10 )
711+ )
712+ result .global_data ["json_report_top_n" ] = top_n
677713
678714 for record in records :
679715 digest = record .data .get ("per_layer_accuracy" )
@@ -691,11 +727,18 @@ def analyze(records: List[RecordDigest], config: Dict[str, Any]) -> AnalysisResu
691727 }
692728 )
693729
730+ # Identity context, contributed once per node, regardless of which
731+ # metric layers the user has active. Replaces the per-metric-ext
732+ # duplication of identity fields.
733+ analysis .add_graph_layer (
734+ "ctx" , PerLayerAccuracyLens ._build_context_extension (rows )
735+ )
736+
694737 for metric_name in ("cosine_sim" , "psnr" , "mse" , "abs_err" ):
695738 r = metric_ranges .get (metric_name )
696739 fixed_range = (r [0 ], r [1 ]) if r else None
697740 metric_ext = PerLayerAccuracyLens ._build_metric_extension (
698- rows , metric_name , fixed_range = fixed_range
741+ rows , metric_name , fixed_range = fixed_range ,
699742 )
700743 analysis .add_graph_layer (metric_name , metric_ext )
701744
@@ -924,7 +967,10 @@ def record(
924967 id = "per_layer_accuracy_graph" ,
925968 title = "Per-layer Accuracy Graph" ,
926969 graph_ref = graph_ref ,
927- default_layers = [f"{ lens_name } /cosine_sim" ],
970+ default_layers = [
971+ f"{ lens_name } /ctx" ,
972+ f"{ lens_name } /cosine_sim" ,
973+ ],
928974 default_color_by = f"{ lens_name } /cosine_sim" ,
929975 compare = GraphCompareSpec (
930976 default_sync = {
@@ -955,6 +1001,73 @@ def check_badges(
9551001 ]
9561002 return []
9571003
1004+ def json_report (
1005+ self , session , session_records , analysis
1006+ ) -> Optional [Dict [str , Any ]]:
1007+ """Machine-readable per-layer accuracy summary for Report (JSON).
1008+
1009+ Finds the first target record in ``session_records`` that carries
1010+ per-layer row data, then returns:
1011+
1012+ - ``anchor``/``target``: record names
1013+ - ``n_layers``: number of matched layer pairs
1014+ - ``sample_source``: how the sample index was chosen
1015+ - ``metric_ranges``: {metric: [min, max]} from the archive-wide analysis
1016+ - ``worst_layers``: {metric: [top-N rows sorted worst-first]}
1017+
1018+ Top-N depth is controlled by
1019+ ``config["per_layer_accuracy"]["json_report_top_n"]`` (default 10).
1020+ ``config`` is passed via ``analyze()`` → ``analysis.global_data``
1021+ so the value is available even when called outside a live
1022+ ``Observatory.enable_context`` block.
1023+ """
1024+ target_record = next (
1025+ (
1026+ r
1027+ for r in session_records or []
1028+ if isinstance (r .data .get ("per_layer_accuracy" ), dict )
1029+ and r .data ["per_layer_accuracy" ].get ("rows" )
1030+ ),
1031+ None ,
1032+ )
1033+ if target_record is None :
1034+ return None
1035+
1036+ digest = target_record .data ["per_layer_accuracy" ]
1037+ rows = digest .get ("rows" ) or []
1038+ global_data = analysis .global_data or {}
1039+ metric_ranges = global_data .get ("metric_ranges" ) or {}
1040+ top_n = int (global_data .get ("json_report_top_n" , 10 ))
1041+
1042+ worst : Dict [str , Any ] = {}
1043+ for metric in ("psnr" , "cosine_sim" , "mse" , "abs_err" ):
1044+ ranked = [r for r in rows if isinstance (r .get (metric ), (int , float ))]
1045+ # psnr / cosine_sim: lower == worse (ascending sort = worst first)
1046+ # mse / abs_err: higher == worse (descending sort = worst first)
1047+ reverse = metric in ("mse" , "abs_err" )
1048+ ranked .sort (key = lambda r : r [metric ], reverse = reverse )
1049+ entry = []
1050+ for r in ranked [:top_n ]:
1051+ # Layer identity: prefer the human-readable from_node_root;
1052+ # fall back to the target_node id for unrooted nodes.
1053+ layer_id = r .get ("from_node_root" ) or r .get ("target_node" , "?" )
1054+ row : Dict [str , Any ] = {"layer" : layer_id }
1055+ for m in ("psnr" , "cosine_sim" , "mse" , "abs_err" ):
1056+ v = r .get (m )
1057+ if isinstance (v , (int , float )):
1058+ row [m ] = round (float (v ), 4 )
1059+ entry .append (row )
1060+ worst [metric ] = entry
1061+
1062+ return {
1063+ "anchor" : digest .get ("anchor_record" ),
1064+ "target" : target_record .name ,
1065+ "n_layers" : len (rows ),
1066+ "sample_source" : digest .get ("sample_source" ),
1067+ "metric_ranges" : metric_ranges ,
1068+ "worst_layers" : worst ,
1069+ }
1070+
9581071 @staticmethod
9591072 def get_frontend_spec () -> Frontend :
9601073 return PerLayerAccuracyLens ._PerLayerAccuracyFrontend ()
0 commit comments