Commit 291c072

quic-boyuc and claude committed
observatory: pipeline_graph_collector groups records into quantization+edge regions
Brings the AOT pipeline records under two top-level Region groups so the
HTML left-panel tree view stays compact and meaningful:

  Session "<script-name>"
  ├── quantization/                  top-level (aten-graph work)
  │   ├── Annotated Model            (prepare_pt2e)
  │   ├── Calibrated Model           (convert_pt2e input)
  │   └── Quantized Model            (convert_pt2e output)
  └── edge/                          top-level (edge dialect)
      ├── Pre-EdgeTransform/<method>
      ├── EdgeProgramManager EP
      └── etrecord/                  lazy nested under edge
          ├── ETRecord Exported/<method>
          ├── ETRecord Edge/<method>
          └── ETRecord Extra/<module>

Design rationale:
- Quantization outputs (Annotated/Calibrated/Quantized Model) are still
  aten graphs, not edge dialect, so they belong under their own top-level
  region rather than under "edge".
- ETRecord operations are AOT-time (they save the float aten and edge
  dialect programs), so they live under "edge" via a nested "etrecord"
  group rather than appearing as a standalone "runtime" region.
- No per-call sub-regions ("prepare_pt2e", "convert_pt2e", "etc.") --
  every region holds at least 2 records, and the per-call identity is
  already in the record name.

Implementation:
- The runtime region stack only ever holds one chain at a time, so the
  lens opens "quantization" and "edge" lazily through transition helpers
  (`_transition_to_quantization`, `_transition_to_edge`) and closes the
  previous sibling when transitioning. This is monotonic for the typical
  AOT order (prepare -> convert -> to_edge -> ETRecord) but tolerates a
  backward transition defensively.
- `_ensure_etrecord_region` first transitions to edge, then opens
  "etrecord" as a child via a second contextlib.ExitStack.
- Three ExitStacks total (quantization, edge, etrecord) all close in
  on_session_end in safe order (innermost first).

observatory.py fix: enter_context now pushes the Region onto the stack
*before* firing on_session_start so a lens that opens its own
enter_context inside that hook (as pipeline_graph_collector does) sees
the new frame as an inner Region rather than recursing into a fresh
outermost call.

tests/test_pipeline_graph_collector_regions.py (new):
- 8 tests covering: lazy region opening at session_start, transition
  to_quantization, forward transition to_edge closes quantization,
  etrecord nesting under edge, idempotent lazy etrecord open, every
  region holds >=2 records, on_session_end closes all open stacks, lens
  hooks fire exactly once per CLI run.

Verification:
  PYTHONPATH=~ python -m pytest \
    ~/executorch/devtools/observatory/tests/ -v
  -> 49 passed, 1 pre-existing unrelated failure
     (test_per_layer_accuracy_lens, "psnr" key, untouched here).

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
1 parent d86c43a commit 291c072
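
The lazy sibling transitions described in the commit message can be illustrated with a minimal, self-contained sketch. It is not the lens code: `enter_context` here is a toy stand-in for `Observatory.enter_context`, and `RegionTransitions`, `events`, and `demo` are hypothetical names used only for illustration. Each region is held open by its own `contextlib.ExitStack`, so it can be closed later from a different call site.

```python
import contextlib
from typing import List, Optional

events: List[str] = []        # trace of region opens and closes
region_stack: List[str] = []  # toy version of the framework's region stack


@contextlib.contextmanager
def enter_context(name: str):
    """Toy stand-in for Observatory.enter_context (assumption, not the real API)."""
    region_stack.append(name)
    events.append("open " + "/".join(region_stack))
    try:
        yield
    finally:
        events.append("close " + "/".join(region_stack))
        region_stack.pop()


class RegionTransitions:
    """Holds one ExitStack per lazily opened region."""

    def __init__(self) -> None:
        self.quantization: Optional[contextlib.ExitStack] = None
        self.edge: Optional[contextlib.ExitStack] = None
        self.etrecord: Optional[contextlib.ExitStack] = None

    @staticmethod
    def _open(name: str) -> contextlib.ExitStack:
        stack = contextlib.ExitStack()
        stack.enter_context(enter_context(name))
        return stack

    def _close(self, attr: str) -> None:
        stack = getattr(self, attr)
        if stack is not None:
            stack.close()
            setattr(self, attr, None)

    def to_quantization(self) -> None:
        # Backward transition: close the inner etrecord region before edge.
        self._close("etrecord")
        self._close("edge")
        if self.quantization is None:
            self.quantization = self._open("quantization")

    def to_edge(self) -> None:
        self._close("quantization")
        if self.edge is None:
            self.edge = self._open("edge")

    def ensure_etrecord(self) -> None:
        self.to_edge()
        if self.etrecord is None:
            self.etrecord = self._open("etrecord")

    def close_all(self) -> None:
        # Innermost first, mirroring the order used at session end.
        for attr in ("etrecord", "edge", "quantization"):
            self._close(attr)


def demo() -> List[str]:
    events.clear()
    t = RegionTransitions()
    t.to_quantization()   # e.g. prepare_pt2e
    t.to_quantization()   # e.g. convert_pt2e; no second open
    t.to_edge()           # closes quantization, opens edge
    t.ensure_etrecord()   # opens etrecord under edge
    t.ensure_etrecord()   # idempotent
    t.close_all()
    return list(events)
```

Calling `demo()` walks the typical prepare -> convert -> to_edge -> ETRecord order and returns the open/close trace, which makes it easy to check that only one chain is ever on the stack.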

3 files changed

Lines changed: 366 additions & 10 deletions

devtools/observatory/lenses/pipeline_graph_collector.py

Lines changed: 152 additions & 5 deletions
@@ -10,6 +10,45 @@
 capture graph artifacts at each stage of the export → quantize → lower pipeline.
 All patches are installed on session start and removed on session end.

+Region structure (cf. RFC §4.5):
+
+    Session "<script-name>"              ← outermost (opened by CLI / user wrapper)
+    ├── quantization/                    ← top-level coarse stage; aten-graph work
+    │   ├── prepare_pt2e/                ← per-call regions
+    │   │   └── record Annotated Model
+    │   └── convert_pt2e/
+    │       ├── record Calibrated Model
+    │       └── record Quantized Model
+    └── edge/                            ← top-level coarse stage; edge-dialect work
+        ├── to_edge_transform_and_lower/
+        │   └── record EdgeProgramManager EP (+ Pre-EdgeTransform/<method>)
+        └── etrecord/                    ← lazy nested region under edge; opens on
+            │                              first ETRecord.add_* call
+            ├── exported_program/
+            │   └── record ETRecord Exported/<method>
+            ├── edge_dialect_program/
+            │   └── record ETRecord Edge/<method>
+            └── extra_export_modules/
+                └── record ETRecord Extra/<module>
+
+`quantization` and `edge` are **sibling** top-level regions (not nested).
+Annotated/Calibrated/Quantized models are aten graphs (not edge dialect),
+so they live under `quantization`. `to_edge_transform_and_lower` and the
+ETRecord operations live under `edge`.
+
+Because the runtime region stack holds only one chain at a time, the lens
+opens these top-level regions **lazily** through transition helpers:
+`_transition_to_quantization` opens quantization (closing edge+etrecord if
+they were open from a backward transition). `_transition_to_edge` closes
+quantization (if open) and opens edge. `_ensure_etrecord_region` ensures
+edge is open and then opens the nested etrecord region. All open stacks
+are closed in `on_session_end`.
+
+Lens lifecycle hooks fire **once** per CLI invocation: `on_session_start`
+records the framework's `enter_context`/`collect` callables and installs
+patches; `on_session_end` closes any open transition stacks and restores
+patches.
+
 Collection points (in pipeline order):
   1. torch.export.export → "Exported Float" (ExportedProgram)
   2. prepare_pt2e → "Annotated Model" (GraphModule with observers)
@@ -31,6 +70,7 @@

 from __future__ import annotations

+import contextlib
 import logging
 from typing import Any, Callable, Dict, List, Optional


@@ -43,6 +83,10 @@ class PipelineGraphCollectorLens(Lens):
     _installed: bool = False
     _originals: Dict[str, Any] = {}
     _collect_fn: Optional[Callable[[str, Any], None]] = None
+    _enter_context_fn: Optional[Callable[..., Any]] = None
+    _quantization_stack: Optional[contextlib.ExitStack] = None
+    _edge_stack: Optional[contextlib.ExitStack] = None
+    _etrecord_stack: Optional[contextlib.ExitStack] = None
     # Cross-lens contract for AccuracyLens fallback dataset.
     _last_calibration_dataset: Optional[list] = None
     # Backend-specific patch installers registered via register_backend_patches().
@@ -57,9 +101,10 @@ def register_backend_patches(
         """Register a backend-specific patch installer.

         The installer receives the lens class and should use cls._originals,
-        cls._collect_fn, and cls._set_accuracy_fallback_dataset() for
-        standard integration. It may also append to cls._backend_uninstallers
-        to register cleanup logic.
+        cls._collect_fn, cls._enter_context_fn, and
+        cls._set_accuracy_fallback_dataset() for standard integration.
+        It may also append to cls._backend_uninstallers to register cleanup
+        logic.
         """
         if installer not in cls._backend_patch_installers:
             cls._backend_patch_installers.append(installer)
@@ -76,6 +121,13 @@ def on_session_start(cls, context: ObservationContext) -> None:
         from ..observatory import Observatory

         cls._collect_fn = Observatory.collect
+        cls._enter_context_fn = Observatory.enter_context
+
+        # Top-level "quantization" and "edge" regions are opened lazily by
+        # the transition helpers (`_transition_to_quantization`,
+        # `_transition_to_edge`). The "etrecord" nested region inside edge
+        # is opened lazily on the first ETRecord.add_* call.
+
         # Install backend-agnostic patches first.
         cls._install_quantizer_patches()
         cls._install_edge_lower_patch()
@@ -113,6 +165,59 @@ def digest(cls, observation: Any, context: ObservationContext) -> Any:
     def analyze(records: List[RecordDigest], config: Dict[str, Any]) -> AnalysisResult:
         return AnalysisResult()

+    @classmethod
+    def _transition_to_quantization(cls) -> None:
+        """Ensure the top-level `quantization` region is the active sibling.
+
+        Closes edge+etrecord if they were open from a previous transition
+        (rare backward order; supported defensively).
+        """
+
+        if cls._enter_context_fn is None:
+            return
+        if cls._etrecord_stack is not None:
+            cls._etrecord_stack.close()
+            cls._etrecord_stack = None
+        if cls._edge_stack is not None:
+            cls._edge_stack.close()
+            cls._edge_stack = None
+        if cls._quantization_stack is None:
+            cls._quantization_stack = contextlib.ExitStack()
+            cls._quantization_stack.enter_context(
+                cls._enter_context_fn("quantization")
+            )
+
+    @classmethod
+    def _transition_to_edge(cls) -> None:
+        """Ensure the top-level `edge` region is the active sibling.
+
+        Closes the `quantization` region first (if it was open). Idempotent
+        when edge is already open.
+        """
+
+        if cls._enter_context_fn is None:
+            return
+        if cls._quantization_stack is not None:
+            cls._quantization_stack.close()
+            cls._quantization_stack = None
+        if cls._edge_stack is None:
+            cls._edge_stack = contextlib.ExitStack()
+            cls._edge_stack.enter_context(cls._enter_context_fn("edge"))
+
+    @classmethod
+    def _ensure_etrecord_region(cls) -> None:
+        """Lazy-open the nested `etrecord` region inside `edge`.
+
+        Guarantees that `edge` is the active top-level region first
+        (calls `_transition_to_edge`), then opens `etrecord` as its child.
+        Idempotent when etrecord is already open.
+        """
+
+        cls._transition_to_edge()
+        if cls._etrecord_stack is None and cls._enter_context_fn is not None:
+            cls._etrecord_stack = contextlib.ExitStack()
+            cls._etrecord_stack.enter_context(cls._enter_context_fn("etrecord"))
+
     @classmethod
     def _set_accuracy_fallback_dataset(cls, dataset: Any, source: str) -> None:
         """Store dataset for AccuracyLens fallback.
@@ -150,6 +255,7 @@ def _install_quantizer_patches(cls) -> None:
             cls._originals["prepare_pt2e"] = original_prepare

             def patched_prepare_pt2e(model, *args, **kwargs):
+                cls._transition_to_quantization()
                 result = original_prepare(model, *args, **kwargs)
                 try:
                     cls._collect_fn("Annotated Model", result)
@@ -168,6 +274,7 @@ def patched_prepare_pt2e(model, *args, **kwargs):
             cls._originals["convert_pt2e"] = original_convert

             def patched_convert_pt2e(model, *args, **kwargs):
+                cls._transition_to_quantization()
                 try:
                     cls._collect_fn("Calibrated Model", model)
                 except Exception as exc:
@@ -203,7 +310,7 @@ def patched_convert_pt2e(model, *args, **kwargs):
     def _install_edge_lower_patch(cls) -> None:
         try:
             import executorch.exir.program._program as program_module
-            import executorch.exir as exir_module
+            import executorch.exir as exir_module

             def _collect_pre_edge_transform_inputs(args, kwargs):
                 programs = kwargs.get("programs")
@@ -233,11 +340,14 @@ def _collect_pre_edge_transform_inputs(args, kwargs):

             def _make_patched_to_edge_transform_and_lower(original_fn):
                 def patched_to_edge_transform_and_lower(*args, **kwargs):
+                    cls._transition_to_edge()
                     _collect_pre_edge_transform_inputs(args, kwargs)
                     kwargs["generate_etrecord"] = True
                     result = original_fn(*args, **kwargs)
                     try:
-                        cls._collect_fn("EdgeProgramManager EP", result.exported_program())
+                        cls._collect_fn(
+                            "EdgeProgramManager EP", result.exported_program()
+                        )
                     except Exception as exc:
                         logging.debug(
                             "[PipelineGraphCollector] collect skipped (EdgeProgramManager EP): %s",
@@ -292,6 +402,7 @@ def _safe_collect(name: str, artifact: Any) -> None:

             def _wrap_add_exported_program(original):
                 def wrapped(self, exported_program):
+                    cls._ensure_etrecord_region()
                     result = original(self, exported_program)
                     if exported_program is None:
                         return result
@@ -306,6 +417,7 @@ def wrapped(self, exported_program):

             def _wrap_add_edge_dialect_program(original):
                 def wrapped(self, edge_dialect_program):
+                    cls._ensure_etrecord_region()
                     result = original(self, edge_dialect_program)
                     processed = getattr(self, "edge_dialect_program", None)
                     if isinstance(processed, dict):
@@ -319,6 +431,7 @@ def wrapped(self, edge_dialect_program):

             def _wrap_add_extra_export_modules(original):
                 def wrapped(self, extra_recorded_export_modules):
+                    cls._ensure_etrecord_region()
                     result = original(self, extra_recorded_export_modules)
                     graph_map = getattr(self, "graph_map", {}) or {}
                     for module_name, program in graph_map.items():
@@ -388,6 +501,7 @@ def _uninstall_all(cls) -> None:

         cls._originals.clear()
         cls._collect_fn = None
+        cls._enter_context_fn = None
         cls._last_calibration_dataset = None
         for uninstaller in cls._backend_uninstallers:
             try:
@@ -396,5 +510,38 @@ def _uninstall_all(cls) -> None:
                 logging.warning(
                     "[PipelineGraphCollector] Backend uninstall failed: %s", exc
                 )
+
+        # Close any open transition regions in reverse-nesting order:
+        # innermost etrecord first, then edge, then quantization. Whichever
+        # of edge / quantization is currently active gets closed; the other
+        # is already None.
+        if cls._etrecord_stack is not None:
+            try:
+                cls._etrecord_stack.close()
+            except Exception as exc:
+                logging.warning(
+                    "[PipelineGraphCollector] Failed to close etrecord region: %s",
+                    exc,
+                )
+            cls._etrecord_stack = None
+        if cls._edge_stack is not None:
+            try:
+                cls._edge_stack.close()
+            except Exception as exc:
+                logging.warning(
+                    "[PipelineGraphCollector] Failed to close edge region: %s",
+                    exc,
+                )
+            cls._edge_stack = None
+        if cls._quantization_stack is not None:
+            try:
+                cls._quantization_stack.close()
+            except Exception as exc:
+                logging.warning(
+                    "[PipelineGraphCollector] Failed to close quantization region: %s",
+                    exc,
+                )
+            cls._quantization_stack = None
+
         cls._installed = False
         logging.info("[PipelineGraphCollector] Uninstalled all patches")
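
The patch-and-restore pattern that the hunks above extend (save the original, wrap it so the region transition runs before the call, collect the result, restore on uninstall) can be sketched in isolation. This is a minimal illustration under stated assumptions: `fake_quantizer`, `collect`, `transitions`, `install_prepare_patch`, and `uninstall_all` are hypothetical stand-ins, not the lens's or torch's real objects.

```python
import logging
import types
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical stand-in for the module that owns the patched function.
fake_quantizer = types.SimpleNamespace(
    prepare_pt2e=lambda model: f"annotated({model})"
)

originals: Dict[str, Callable[..., Any]] = {}
collected: List[Tuple[str, Any]] = []
transitions: List[str] = []  # records which region was requested


def collect(name: str, artifact: Any) -> None:
    """Toy collector; the real one is supplied by the framework."""
    collected.append((name, artifact))


def install_prepare_patch() -> None:
    original = fake_quantizer.prepare_pt2e
    originals["prepare_pt2e"] = original

    def patched_prepare_pt2e(model, *args, **kwargs):
        # The region transition happens before the wrapped call, so the
        # record collected afterwards lands in the right region.
        transitions.append("quantization")
        result = original(model, *args, **kwargs)
        try:
            collect("Annotated Model", result)
        except Exception as exc:
            # Collection problems must never break the user's pipeline.
            logging.debug("collect skipped (Annotated Model): %s", exc)
        return result

    fake_quantizer.prepare_pt2e = patched_prepare_pt2e


def uninstall_all() -> None:
    # Restore every saved original, then forget them.
    for name, fn in originals.items():
        setattr(fake_quantizer, name, fn)
    originals.clear()
```

The wrapper returns the original result unchanged, so callers cannot tell whether the patch is installed apart from the side effects on `collected` and `transitions`.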

devtools/observatory/observatory.py

Lines changed: 12 additions & 5 deletions
@@ -226,20 +226,27 @@ def merge_config_dict(base: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, An

         cls._config_stack.append(context_config)

-        if push_region and is_outermost:
-            cls._open_session(effective_name)
-
+        # Push the Region onto the stack BEFORE firing on_session_start so
+        # that any lens-side `enter_context` calls in the hook see this
+        # frame as an outer Region (not as a fresh outermost — which would
+        # recurse into _open_session).
         if push_region:
             cls._region_stack.append(effective_name)

+        if push_region and is_outermost:
+            cls._open_session(effective_name)
+
         try:
             yield
         finally:
+            # Mirror the order in reverse: close Session first (firing
+            # on_session_end while the region is still on the stack so the
+            # lens can clean up its own nested regions), then pop.
+            if push_region and is_outermost:
+                cls._close_session(effective_name)
             if push_region:
                 cls._region_stack.pop()
             cls._config_stack.pop()
-            if push_region and is_outermost:
-                cls._close_session(effective_name)

     @classmethod
     @contextmanager
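
Why the push has to come before the session-start hook can be seen in a toy model. The sketch below is an assumption-laden reduction, not the real `Observatory`: `ToyObservatory`, its `lens_stack`, and the region name "inner" are hypothetical, and "outermost" is modelled simply as "the region stack was empty on entry". The session hooks stand in for a lens that opens its own region in `on_session_start` and closes it in `on_session_end`.

```python
import contextlib
from typing import List, Optional


class ToyObservatory:
    """Hypothetical stand-in; only the push/open/close/pop ordering is modelled."""

    region_stack: List[str] = []
    session_opens: List[str] = []
    lens_stack: Optional[contextlib.ExitStack] = None

    @classmethod
    def _open_session(cls, name: str) -> None:
        cls.session_opens.append(name)
        # A lens opening its own region from inside the session-start hook.
        cls.lens_stack = contextlib.ExitStack()
        cls.lens_stack.enter_context(cls.enter_context("inner"))

    @classmethod
    def _close_session(cls, name: str) -> None:
        # The lens closes its nested region while the session frame is
        # still on the stack.
        if cls.lens_stack is not None:
            cls.lens_stack.close()
            cls.lens_stack = None

    @classmethod
    @contextlib.contextmanager
    def enter_context(cls, name: str):
        is_outermost = not cls.region_stack
        # Push first: a nested enter_context call made by the hook then
        # finds a non-empty stack and is not treated as outermost.
        cls.region_stack.append(name)
        if is_outermost:
            cls._open_session(name)
        try:
            yield
        finally:
            # Reverse order: close the session, then pop this frame.
            if is_outermost:
                cls._close_session(name)
            popped = cls.region_stack.pop()
            assert popped == name  # frames come off in LIFO order


def demo():
    with ToyObservatory.enter_context("session"):
        inside = list(ToyObservatory.region_stack)
    return inside, list(ToyObservatory.session_opens), list(ToyObservatory.region_stack)
```

With the push moved after `_open_session`, the nested call in this toy would find an empty stack, consider itself outermost, and call `_open_session` again, which is the recursion the commit describes.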
