Skip to content

Commit 30eab74

Browse files
committed
Fixed tensor inspector
1 parent 131db7e commit 30eab74

7 files changed

Lines changed: 293 additions & 31 deletions

File tree

examples/einsum_demo.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,19 @@ def _keep_trace_tensors_alive(trace: Any, *tensors: Any) -> None:
6868
trace._example_keepalive = keepalive
6969

7070

71+
def _cumulative_group_contraction_scheme(
72+
groups: tuple[tuple[str, ...], ...],
73+
) -> tuple[tuple[str, ...], ...]:
74+
if not groups:
75+
return ()
76+
running_names: list[str] = []
77+
steps: list[tuple[str, ...]] = []
78+
for group in groups:
79+
running_names.extend(group)
80+
steps.append(tuple(running_names))
81+
return tuple(steps)
82+
83+
7184
def _site_bond_dims(n_sites: int) -> list[int]:
7285
return [2 + (index % 3) for index in range(max(n_sites - 1, 1))]
7386

@@ -586,16 +599,13 @@ def _renderable_trace(trace: Any, args: ExampleCliArgs) -> Any:
586599

587600
def _scheme_steps(name: str, args: ExampleCliArgs) -> tuple[tuple[str, ...], ...] | None:
588601
if name == "mps":
589-
return cumulative_prefix_contraction_scheme(
590-
tuple(f"A{i}" for i in range(args.n_sites))
591-
+ tuple(f"x{i}" for i in range(args.n_sites))
602+
return _cumulative_group_contraction_scheme(
603+
tuple((f"A{i}", f"x{i}") for i in range(args.n_sites))
592604
)
593605
if name == "mpo":
594-
mpo_names = tuple(f"W{i}" for i in range(args.n_sites))
595-
vectors = tuple(f"d{i}" for i in range(args.n_sites)) + tuple(
596-
f"u{i}" for i in range(args.n_sites)
606+
return _cumulative_group_contraction_scheme(
607+
tuple((f"W{i}", f"d{i}", f"u{i}") for i in range(args.n_sites))
597608
)
598-
return cumulative_prefix_contraction_scheme(mpo_names + vectors)
599609
if name == "peps":
600610
names = tuple(f"P{i}_{j}" for i in range(args.lx) for j in range(args.ly))
601611
return cumulative_prefix_contraction_scheme(names)

src/tensor_network_viz/_interaction/controller.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434
from .._tensor_elements_support import _TensorRecord
3535
from .._typing import root_figure
36-
from .._ui_utils import _set_axes_visible, _set_figure_bottom_reserved
36+
from .._ui_utils import _set_axes_visible, _set_figure_bottom_reserved, _style_control_tray_axes
3737
from ..config import EngineName, PlotConfig, ViewName
3838
from ..contraction_viewer import _MAIN_FIGURE_BOTTOM_RESERVED, _PLAYBACK_DETAILS_TOP
3939
from ..einsum_module.trace import EinsumTrace
@@ -73,22 +73,27 @@
7373
_INTERACTIVE_CHECK_FRAME_PROPS: dict[str, float] = {"s": 44.0, "linewidth": 0.9}
7474
_INTERACTIVE_CHECK_MARK_PROPS: dict[str, float] = {"s": 34.0, "linewidth": 1.0}
7575
_INTERACTIVE_RADIO_PROPS: dict[str, float] = {"s": 38.0, "linewidth": 0.9}
76-
_CONTROL_TRAY_FACE: tuple[float, float, float] = (0.97, 0.97, 0.99)
77-
_CONTROL_TRAY_FRAME: tuple[float, float, float] = (0.78, 0.78, 0.82)
7876

7977

80-
def _style_interactive_control_axes(ax: Axes) -> None:
81-
ax.set_xticks([])
82-
ax.set_yticks([])
83-
ax.set_navigate(False)
84-
ax.patch.set_facecolor(_CONTROL_TRAY_FACE)
85-
ax.patch.set_alpha(0.88)
86-
ax.patch.set_edgecolor(_CONTROL_TRAY_FRAME)
87-
ax.patch.set_linewidth(0.6)
88-
for spine in ax.spines.values():
89-
spine.set_visible(True)
90-
spine.set_linewidth(0.6)
91-
spine.set_color(_CONTROL_TRAY_FRAME)
78+
def _reveal_auxiliary_figure(figure: Figure) -> None:
79+
manager = getattr(figure.canvas, "manager", None)
80+
manager_show = getattr(manager, "show", None)
81+
if callable(manager_show):
82+
with suppress(AttributeError, RuntimeError, TypeError, ValueError):
83+
manager_show()
84+
else:
85+
figure_show = getattr(figure, "show", None)
86+
if callable(figure_show):
87+
with suppress(AttributeError, RuntimeError, TypeError, ValueError):
88+
figure_show()
89+
draw_idle = getattr(figure.canvas, "draw_idle", None)
90+
if callable(draw_idle):
91+
with suppress(AttributeError, RuntimeError, TypeError, ValueError):
92+
draw_idle()
93+
flush_events = getattr(figure.canvas, "flush_events", None)
94+
if callable(flush_events):
95+
with suppress(AttributeError, RuntimeError, TypeError, ValueError):
96+
flush_events()
9297

9398

9499
class _LinkedTensorInspectorController:
@@ -111,6 +116,10 @@ def __init__(
111116
self._closing_programmatically: bool = False
112117
self._close_cid: int | None = None
113118

119+
@property
120+
def is_enabled(self) -> bool:
121+
return self._enabled
122+
114123
def bind_viewer(self, viewer: Any) -> None:
115124
if self._viewer is viewer:
116125
if self._enabled and self._viewer is not None:
@@ -128,9 +137,11 @@ def bind_viewer(self, viewer: Any) -> None:
128137
call_immediately=self._enabled,
129138
)
130139

131-
def set_enabled(self, enabled: bool) -> None:
140+
def set_enabled(self, enabled: bool, *, reveal: bool = False) -> None:
132141
target = bool(enabled)
133142
if target == self._enabled:
143+
if target and reveal and self._figure is not None:
144+
_reveal_auxiliary_figure(self._figure)
134145
if target and self._viewer is not None:
135146
self._sync_to_step(int(self._viewer.current_step))
136147
return
@@ -139,6 +150,8 @@ def set_enabled(self, enabled: bool) -> None:
139150
self._close_figure()
140151
return
141152
self._ensure_figure()
153+
if reveal and self._figure is not None:
154+
_reveal_auxiliary_figure(self._figure)
142155
if self._viewer is not None:
143156
self._sync_to_step(int(self._viewer.current_step))
144157
else:
@@ -295,6 +308,7 @@ def __init__(
295308
self.figure: Figure | None = None
296309
self._tensor_inspector: _LinkedTensorInspectorController | None = None
297310
self._figure_close_cid: int | None = None
311+
self._initialized: bool = False
298312
if self.tensor_inspector_available:
299313
self._tensor_inspector = _LinkedTensorInspectorController(
300314
trace=cast(EinsumTrace, network),
@@ -314,6 +328,7 @@ def initialize(self) -> tuple[Figure, RenderedAxes]:
314328
return figure, ax
315329
self._build_controls()
316330
self._apply_scene_state(self.current_scene)
331+
self._initialized = True
317332
set_interactive_controls(figure, self)
318333
set_active_axes(figure, ax)
319334
figure._tensor_network_viz_tensor_inspector = self._tensor_inspector # type: ignore[attr-defined]
@@ -436,7 +451,7 @@ def _build_controls(self) -> None:
436451
)
437452
cb_bottom = float(cb_bounds[1])
438453
check_ax = self.figure.add_axes(cb_bounds)
439-
_style_interactive_control_axes(check_ax)
454+
_style_control_tray_axes(check_ax)
440455
self._check_ax = check_ax
441456
if not self._external_ax:
442457
radio_bounds: tuple[float, float, float, float] = (
@@ -446,7 +461,7 @@ def _build_controls(self) -> None:
446461
_VIEW_SELECTOR_HEIGHT,
447462
)
448463
radio_ax = self.figure.add_axes(radio_bounds)
449-
_style_interactive_control_axes(radio_ax)
464+
_style_control_tray_axes(radio_ax)
450465
self._radio_ax = radio_ax
451466
active_index = 0 if self.current_view == "2d" else 1
452467
self._radio = RadioButtons(
@@ -569,7 +584,15 @@ def _apply_scene_state(self, scene: _InteractiveSceneState) -> None:
569584
if self._tensor_inspector is not None:
570585
self._tensor_inspector.bind_viewer(controls._viewer)
571586
if self._tensor_inspector is not None:
572-
self._tensor_inspector.set_enabled(self.tensor_inspector_on)
587+
reveal_inspector = bool(
588+
self._initialized
589+
and self.tensor_inspector_on
590+
and not self._tensor_inspector.is_enabled
591+
)
592+
self._tensor_inspector.set_enabled(
593+
self.tensor_inspector_on,
594+
reveal=reveal_inspector,
595+
)
573596
_apply_scene_hover_state(scene, hover_on=self.hover_on)
574597
self._sync_checkbuttons()
575598
if not self._external_ax:

src/tensor_network_viz/_tensor_elements_controller.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
_TextSummaryPayload,
2323
_valid_group_modes_for_record,
2424
)
25-
from ._ui_utils import _reserve_figure_bottom
25+
from ._ui_utils import _reserve_figure_bottom, _style_control_tray_axes
2626
from .tensor_elements_config import TensorElementsConfig, TensorElementsMode
2727

2828
TensorElementsGroup = Literal["basic", "complex", "diagnostic"]
@@ -32,9 +32,9 @@
3232
]
3333

3434
_GROUP_OPTIONS: Final[tuple[TensorElementsGroup, ...]] = ("basic", "complex", "diagnostic")
35-
_GROUP_SELECTOR_BOUNDS: Final[tuple[float, float, float, float]] = (0.02, 0.048, 0.12, 0.12)
36-
_MODE_SELECTOR_BOUNDS: Final[tuple[float, float, float, float]] = (0.16, 0.028, 0.18, 0.18)
37-
_TENSOR_SLIDER_BOUNDS: Final[tuple[float, float, float, float]] = (0.42, 0.058, 0.38, 0.03)
35+
_GROUP_SELECTOR_BOUNDS: Final[tuple[float, float, float, float]] = (0.02, 0.048, 0.15, 0.12)
36+
_MODE_SELECTOR_BOUNDS: Final[tuple[float, float, float, float]] = (0.19, 0.028, 0.21, 0.16)
37+
_TENSOR_SLIDER_BOUNDS: Final[tuple[float, float, float, float]] = (0.48, 0.052, 0.38, 0.05)
3838
_TENSOR_ELEMENTS_CONTROLS_BOTTOM: Final[float] = 0.24
3939
_INTERACTIVE_LABEL_PROPS: Final[dict[str, Sequence[Any]]] = {"fontsize": [9.5]}
4040
_INTERACTIVE_RADIO_PROPS: Final[dict[str, float]] = {"s": 38.0, "linewidth": 0.9}
@@ -209,6 +209,7 @@ def initialize(self, *, show_controls: bool) -> None:
209209
if show_controls:
210210
_reserve_figure_bottom(self._figure, _TENSOR_ELEMENTS_CONTROLS_BOTTOM)
211211
self._group_radio_ax = self._figure.add_axes(_GROUP_SELECTOR_BOUNDS)
212+
_style_control_tray_axes(self._group_radio_ax)
212213
self._group_radio = RadioButtons(
213214
self._group_radio_ax,
214215
_GROUP_OPTIONS,
@@ -221,6 +222,7 @@ def initialize(self, *, show_controls: bool) -> None:
221222

222223
if len(self._records) > 1:
223224
self._slider_ax = self._figure.add_axes(_TENSOR_SLIDER_BOUNDS)
225+
_style_control_tray_axes(self._slider_ax)
224226
self._slider = Slider(
225227
self._slider_ax,
226228
"Tensor",
@@ -246,6 +248,7 @@ def _rebuild_mode_radio(self) -> None:
246248
if self._mode_radio_ax is not None:
247249
self._mode_radio_ax.remove()
248250
self._mode_radio_ax = self._figure.add_axes(_MODE_SELECTOR_BOUNDS)
251+
_style_control_tray_axes(self._mode_radio_ax)
249252
mode_options = self._current_group_modes()
250253
active_index = mode_options.index(self._mode)
251254
self._mode_radio = RadioButtons(

src/tensor_network_viz/_ui_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
from contextlib import suppress
44
from typing import Any
55

6+
from matplotlib.axes import Axes
67
from matplotlib.figure import Figure
78

89
from ._matplotlib_state import get_reserved_bottom, set_reserved_bottom
910

11+
_CONTROL_TRAY_FACE: tuple[float, float, float] = (0.97, 0.97, 0.99)
12+
_CONTROL_TRAY_FRAME: tuple[float, float, float] = (0.78, 0.78, 0.82)
13+
1014

1115
def _reserve_figure_bottom(fig: Figure, bottom: float) -> None:
1216
current = get_reserved_bottom(fig)
@@ -40,9 +44,26 @@ def _set_widget_active(widget: Any, active: bool) -> None:
4044
setter(bool(active))
4145

4246

47+
def _style_control_tray_axes(ax: Axes) -> None:
48+
ax.set_xticks([])
49+
ax.set_yticks([])
50+
ax.set_navigate(False)
51+
ax.patch.set_facecolor(_CONTROL_TRAY_FACE)
52+
ax.patch.set_alpha(0.88)
53+
ax.patch.set_edgecolor(_CONTROL_TRAY_FRAME)
54+
ax.patch.set_linewidth(0.6)
55+
for spine in ax.spines.values():
56+
spine.set_visible(True)
57+
spine.set_linewidth(0.6)
58+
spine.set_color(_CONTROL_TRAY_FRAME)
59+
60+
4361
__all__ = [
62+
"_CONTROL_TRAY_FACE",
63+
"_CONTROL_TRAY_FRAME",
4464
"_reserve_figure_bottom",
4565
"_set_figure_bottom_reserved",
4666
"_set_axes_visible",
4767
"_set_widget_active",
68+
"_style_control_tray_axes",
4869
]

tests/test_examples.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import pytest
1010

11-
from tensor_network_viz import EinsumTrace
11+
from tensor_network_viz import EinsumTrace, show_tensor_network
1212
from tensor_network_viz._tensor_elements_data import _extract_einsum_playback_step_records
1313

1414
_EXAMPLES = Path(__file__).resolve().parent.parent / "examples"
@@ -256,6 +256,67 @@ def test_einsum_auto_examples_keep_tensors_alive_for_tensor_inspector(example_na
256256
assert all(step.record is not None for step in step_records)
257257

258258

259+
@pytest.mark.parametrize("example_name", ["mps", "mpo"])
260+
def test_einsum_auto_examples_keep_inspector_and_costs_aligned_to_real_trace_steps(
261+
example_name: str,
262+
) -> None:
263+
_require_torch()
264+
run_demo = _load_example_module(
265+
Path("examples/run_demo.py"),
266+
f"run_demo_{example_name}_inspector_alignment",
267+
)
268+
demo_cli = importlib.import_module("demo_cli")
269+
einsum_demo = importlib.import_module("einsum_demo")
270+
271+
args = run_demo.parse_args(
272+
["einsum", example_name, "--view", "2d", "--tensor-inspector", "--no-show"]
273+
)
274+
trace = einsum_demo._trace_steps_for(example_name, args)
275+
scheme_steps = einsum_demo._scheme_steps(example_name, args)
276+
277+
assert isinstance(trace, EinsumTrace)
278+
assert scheme_steps is not None
279+
assert len(scheme_steps) == len(trace)
280+
281+
config = demo_cli.finalize_demo_plot_config(
282+
args,
283+
engine="einsum",
284+
scheme_tensor_names=scheme_steps,
285+
)
286+
fig, _ax = show_tensor_network(
287+
trace,
288+
engine="einsum",
289+
view=args.view,
290+
config=config,
291+
show=False,
292+
)
293+
294+
controls = getattr(fig, "_tensor_network_viz_interactive_controls", None)
295+
assert controls is not None
296+
assert controls.current_scene.contraction_controls is not None
297+
viewer = controls.current_scene.contraction_controls._viewer
298+
assert viewer is not None
299+
assert viewer.current_step == len(trace)
300+
301+
inspector = getattr(fig, "_tensor_network_viz_tensor_inspector", None)
302+
assert inspector is not None
303+
assert inspector._figure is not None
304+
inspector_controls = getattr(
305+
inspector._figure,
306+
"_tensor_network_viz_tensor_elements_controls",
307+
None,
308+
)
309+
assert inspector_controls is not None
310+
assert f"r{len(trace) - 1}" in inspector_controls._panel.main_ax.get_title()
311+
312+
controls.cost_hover_on = True
313+
controls._apply_scene_state(controls.current_scene)
314+
assert viewer._cost_panel_ax is not None
315+
assert viewer._cost_panel_ax.get_visible()
316+
assert viewer._cost_text_artist is not None
317+
assert "Contraction:" in viewer._cost_text_artist.get_text()
318+
319+
259320
def test_tensornetwork_mera_ttn_saves_figure_without_showing() -> None:
260321
_require_tensornetwork()
261322
module = _load_example_module(Path("examples/run_demo.py"), "run_demo_tn_mera_ttn")

tests/test_plotting.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from matplotlib.collections import LineCollection
1717

1818
import tensor_network_viz._core.renderer as core_renderer_module
19+
import tensor_network_viz._interaction.controller as interaction_controller_module
1920
import tensor_network_viz.tensorkrowch.graph as tk_graph_module
2021
import tensor_network_viz.tensorkrowch.renderer as tk_renderer_module
2122
import tensor_network_viz.tensornetwork.graph as tn_graph_module
@@ -851,6 +852,46 @@ def test_show_tn_einsum_trace_inspector_checkbox_auto_enables_playback() -> None
851852
assert getattr(fig, "_tensor_network_viz_tensor_inspector", None) is not None
852853

853854

855+
def test_show_tn_reenabling_tensor_inspector_reveals_auxiliary_window(
856+
monkeypatch: pytest.MonkeyPatch,
857+
) -> None:
858+
trace = _build_einsum_trace_for_inspector()
859+
revealed: list[matplotlib.figure.Figure] = []
860+
861+
monkeypatch.setattr(
862+
interaction_controller_module,
863+
"_reveal_auxiliary_figure",
864+
lambda figure: revealed.append(figure),
865+
)
866+
867+
fig, _ax = show_tensor_network(
868+
trace,
869+
config=PlotConfig(
870+
contraction_tensor_inspector=False,
871+
),
872+
show=False,
873+
)
874+
875+
controls = getattr(fig, "_tensor_network_viz_interactive_controls", None)
876+
assert controls is not None
877+
assert controls._checkbuttons is not None
878+
879+
_click_checkbutton(controls._checkbuttons, 6)
880+
881+
inspector = getattr(fig, "_tensor_network_viz_tensor_inspector", None)
882+
assert inspector is not None
883+
assert inspector._figure is not None
884+
assert revealed == [inspector._figure]
885+
886+
_click_checkbutton(controls._checkbuttons, 6)
887+
assert inspector._figure is None
888+
889+
_click_checkbutton(controls._checkbuttons, 6)
890+
891+
assert inspector._figure is not None
892+
assert revealed == [revealed[0], inspector._figure]
893+
894+
854895
def test_show_tensor_network_non_einsum_inputs_do_not_expose_tensor_inspector_checkbox() -> None:
855896
left = DummyTensorKrowchNode("A", ["left"])
856897
right = DummyTensorKrowchNode("B", ["right"])

0 commit comments

Comments
 (0)