Skip to content

Commit f9d829e

Browse files
committed
Improved menu
1 parent 7a582d8 commit f9d829e

8 files changed

Lines changed: 188 additions & 38 deletions

File tree

scripts/verify.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ def _command_groups() -> dict[str, VerificationGroup]:
4444
),
4545
),
4646
),
47-
"package": (
48-
VerificationStep("build-dist", (python, "-m", "build", "--sdist", "--wheel")),
49-
),
47+
"package": (VerificationStep("build-dist", (python, "-m", "build", "--sdist", "--wheel")),),
5048
}
5149

5250

src/tensor_network_viz/_core/renderer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
RenderedAxes: TypeAlias = Axes | Axes3D
3636
_Dimensions = Literal[2, 3]
3737
_LayoutCacheKey: TypeAlias = tuple[int, int, int]
38+
# Main axes: use almost the full figure width (interactive widgets sit in the bottom margin).
39+
_FIGURE_ADJUST_LEFT: float = 0.006
40+
_FIGURE_ADJUST_RIGHT: float = 0.994
3841

3942
_layout_positions_by_id: dict[
4043
int,
@@ -423,7 +426,12 @@ def _plot_graph(
423426
build_scene_state=build_scene_state,
424427
)
425428
reserved_bottom = get_reserved_bottom(fig)
426-
fig.subplots_adjust(left=0.02, right=0.98, bottom=reserved_bottom, top=0.98)
429+
fig.subplots_adjust(
430+
left=_FIGURE_ADJUST_LEFT,
431+
right=_FIGURE_ADJUST_RIGHT,
432+
bottom=reserved_bottom,
433+
top=0.98,
434+
)
427435
return fig, resolved_ax
428436

429437

src/tensor_network_viz/_tensor_elements_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class _EinsumPlaybackStepRecord:
5151
result_name: str
5252
record: _TensorRecord | None
5353

54+
5455
def _detect_tensor_elements_engine(data: Any) -> tuple[EngineName, Any]:
5556
return _detect_tensor_engine_with_input(data)
5657

src/tensor_network_viz/_ui_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ def _reserve_figure_bottom(fig: Figure, bottom: float) -> None:
1515
fig.subplots_adjust(bottom=target)
1616

1717

18+
def _set_figure_bottom_reserved(fig: Figure, bottom: float) -> None:
19+
"""Store and apply *bottom*; unlike `_reserve_figure_bottom`, can shrink the reserved strip."""
20+
b = float(bottom)
21+
set_reserved_bottom(fig, b)
22+
p = fig.subplotpars
23+
fig.subplots_adjust(left=p.left, right=p.right, top=p.top, bottom=b)
24+
25+
1826
def _set_axes_visible(ax: Any, visible: bool) -> None:
1927
ax.set_visible(visible)
2028
ax.patch.set_visible(visible)
@@ -32,4 +40,9 @@ def _set_widget_active(widget: Any, active: bool) -> None:
3240
setter(bool(active))
3341

3442

35-
__all__ = ["_reserve_figure_bottom", "_set_axes_visible", "_set_widget_active"]
43+
__all__ = [
44+
"_reserve_figure_bottom",
45+
"_set_figure_bottom_reserved",
46+
"_set_axes_visible",
47+
"_set_widget_active",
48+
]

src/tensor_network_viz/contraction_viewer.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,41 @@
4040
_TNV_CONTRACTION_SCHEME_PATCH_GID: Final[str] = "tnv_contraction_scheme"
4141

4242
_TRANSPARENT: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0)
43-
_PLAYBACK_MAIN_BOTTOM: float = 0.40
4443
_PLAYBACK_DETAILS_BOUNDS: tuple[float, float, float, float] = (0.25, 0.116, 0.68, 0.12)
45-
_PLAYBACK_SLIDER_BOUNDS: tuple[float, float, float, float] = (0.33, 0.067, 0.345, 0.024)
44+
# Top of the cost / step-details axis; interactive chrome (checkboxes, 2d/3d) aligns to this y.
45+
_PLAYBACK_DETAILS_TOP: float = _PLAYBACK_DETAILS_BOUNDS[1] + _PLAYBACK_DETAILS_BOUNDS[3]
46+
# Aligned with the top of the cost / scheme chrome (no extra gap above the widgets).
47+
_MAIN_FIGURE_BOTTOM_RESERVED: float = _PLAYBACK_DETAILS_TOP
48+
_PLAYBACK_MAIN_BOTTOM: float = _MAIN_FIGURE_BOTTOM_RESERVED
49+
_CONTROLS_MAIN_BOTTOM: float = _MAIN_FIGURE_BOTTOM_RESERVED
50+
_PLAYBACK_SLIDER_HEIGHT: float = 0.058
51+
_PLAYBACK_SLIDER_BOUNDS: tuple[float, float, float, float] = (
52+
0.33,
53+
0.062,
54+
0.345,
55+
_PLAYBACK_SLIDER_HEIGHT,
56+
)
57+
_PLAYBACK_SLIDER_HANDLE_STYLE: dict[str, Any] = {
58+
"facecolor": "#2563eb",
59+
"edgecolor": "#1d4ed8",
60+
"size": 11,
61+
}
4662
_PLAYBACK_BUTTON_START_X: float = 0.73
4763
_PLAYBACK_BUTTON_Y: float = 0.058
4864
_PLAYBACK_BUTTON_WIDTH: float = 0.055
4965
_PLAYBACK_BUTTON_HEIGHT: float = 0.038
5066
_PLAYBACK_BUTTON_GAP: float = 0.012
5167
_PLAYBACK_RESET_WIDTH: float = 0.065
52-
_CONTROLS_MAIN_BOTTOM: float = _PLAYBACK_MAIN_BOTTOM
53-
_CONTROLS_CHECKBOX_BOUNDS: tuple[float, float, float, float] = (0.02, 0.045, 0.13, 0.10)
68+
_CONTROLS_CHECKBOX_TOP: float = _PLAYBACK_DETAILS_TOP
69+
_CONTROLS_CHECKBOX_HEIGHT: float = 0.10
70+
_CONTROLS_CHECKBOX_BOUNDS: tuple[float, float, float, float] = (
71+
0.02,
72+
_CONTROLS_CHECKBOX_TOP - _CONTROLS_CHECKBOX_HEIGHT,
73+
0.13,
74+
_CONTROLS_CHECKBOX_HEIGHT,
75+
)
76+
_PLAYBACK_TRAY_FACE: tuple[float, float, float] = (0.96, 0.96, 0.98)
77+
_PLAYBACK_TRAY_FRAME: tuple[float, float, float] = (0.78, 0.78, 0.82)
5478
_SCHEME_LABELS: tuple[str, str, str] = ("Scheme", "Playback", "Costs")
5579
_CONTROL_LABEL_PROPS: dict[str, Sequence[Any]] = {"fontsize": [9.5]}
5680
_CONTROL_FRAME_PROPS: dict[str, float] = {"s": 44.0, "linewidth": 0.9}
@@ -452,9 +476,15 @@ def _build_step_details_panel(self) -> None:
452476
ax_details = self.figure.add_axes(_PLAYBACK_DETAILS_BOUNDS)
453477
ax_details.set_xticks([])
454478
ax_details.set_yticks([])
455-
ax_details.patch.set_alpha(0.0)
479+
ax_details.set_navigate(False)
480+
ax_details.patch.set_facecolor(_PLAYBACK_TRAY_FACE)
481+
ax_details.patch.set_alpha(0.92)
482+
ax_details.patch.set_edgecolor(_PLAYBACK_TRAY_FRAME)
483+
ax_details.patch.set_linewidth(0.6)
456484
for spine in ax_details.spines.values():
457-
spine.set_visible(False)
485+
spine.set_visible(True)
486+
spine.set_linewidth(0.6)
487+
spine.set_color(_PLAYBACK_TRAY_FRAME)
458488
text = ax_details.text(
459489
0.0,
460490
1.0,
@@ -507,6 +537,7 @@ def build_ui(self, *, initialize_step: bool = True) -> None:
507537
float(max(0, n)),
508538
valinit=float(self._initial_step if self._initial_step is not None else n),
509539
valstep=1,
540+
handle_style=_PLAYBACK_SLIDER_HANDLE_STYLE,
510541
)
511542
self.slider = slider
512543

@@ -687,6 +718,17 @@ def __init__(
687718

688719
def _build_controls(self) -> None:
689720
controls_ax = self.figure.add_axes(_CONTROLS_CHECKBOX_BOUNDS)
721+
controls_ax.set_xticks([])
722+
controls_ax.set_yticks([])
723+
controls_ax.set_navigate(False)
724+
controls_ax.patch.set_facecolor((0.97, 0.97, 0.99))
725+
controls_ax.patch.set_alpha(0.88)
726+
controls_ax.patch.set_edgecolor(_PLAYBACK_TRAY_FRAME)
727+
controls_ax.patch.set_linewidth(0.6)
728+
for spine in controls_ax.spines.values():
729+
spine.set_visible(True)
730+
spine.set_linewidth(0.6)
731+
spine.set_color(_PLAYBACK_TRAY_FRAME)
690732
self._controls_ax = controls_ax
691733
self._checkbuttons = CheckButtons(
692734
controls_ax,

src/tensor_network_viz/interactive_viewer.py

Lines changed: 112 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,32 +32,61 @@
3232
)
3333
from ._tensor_elements_support import _TensorRecord
3434
from ._typing import root_figure
35-
from ._ui_utils import _reserve_figure_bottom, _set_axes_visible
35+
from ._ui_utils import _set_axes_visible, _set_figure_bottom_reserved
3636
from .config import EngineName, PlotConfig, ViewName
37+
from .contraction_viewer import _MAIN_FIGURE_BOTTOM_RESERVED, _PLAYBACK_DETAILS_TOP
3738
from .einsum_module.trace import EinsumTrace
3839
from .tensor_elements import _show_tensor_records
3940
from .tensor_elements_config import TensorElementsConfig
4041

4142
RenderedAxes = Axes | Axes3D
4243

43-
# 2d/3d: width 0.053 (= 60% of 0.088); height 0.063; bottom lowered ~0.03 vs earlier slider-row alignment.
44-
_VIEW_SELECTOR_BOUNDS: tuple[float, float, float, float] = (0.213, 0.025, 0.053, 0.063)
45-
_BASE_INTERACTIVE_CHECKBOX_BOUNDS: tuple[float, float, float, float] = (0.02, 0.028, 0.19, 0.09)
46-
_SCHEME_INTERACTIVE_CHECKBOX_BOUNDS: tuple[float, float, float, float] = (0.02, 0.028, 0.19, 0.142)
47-
_SCHEME_INSPECTOR_INTERACTIVE_CHECKBOX_BOUNDS: tuple[float, float, float, float] = (
44+
# Menu column: fixed bottom = tallest stack (inspector + scheme). Without playback, checkboxes/radio
45+
# stay as low as when the bottom row exists. Top aligned with cost-details top.
46+
_VIEW_SELECTOR_LEFT: float = 0.213
47+
_VIEW_SELECTOR_WIDTH: float = 0.053
48+
_VIEW_SELECTOR_HEIGHT: float = 0.063
49+
# Manual axes positions: 2D extends slightly below *base*, 3D starts higher (base + lift).
50+
_INTERACTIVE_2D_BOTTOM_EXTRA: float = 0.022
51+
_INTERACTIVE_3D_BOTTOM_LIFT: float = 0.084
52+
_BASE_INTERACTIVE_HEIGHT: float = 0.09
53+
_SCHEME_INSPECTOR_INTERACTIVE_HEIGHT: float = 0.172
54+
_INTERACTIVE_MENU_COLUMN_HEIGHT: float = _SCHEME_INSPECTOR_INTERACTIVE_HEIGHT
55+
_INTERACTIVE_MENU_COLUMN_BOTTOM: float = _PLAYBACK_DETAILS_TOP - _INTERACTIVE_MENU_COLUMN_HEIGHT
56+
_INTERACTIVE_CHECKBOX_AXES_BOUNDS: tuple[float, float, float, float] = (
4857
0.02,
49-
0.028,
58+
_INTERACTIVE_MENU_COLUMN_BOTTOM,
5059
0.19,
51-
0.172,
60+
_INTERACTIVE_MENU_COLUMN_HEIGHT,
61+
)
62+
# When Scheme is off, main axes bottom (not tied to menu column bottom after unifying menus).
63+
_SCHEME_OFF_FIGURE_BOTTOM_PAD: float = 0.02
64+
_MAIN_FIGURE_BOTTOM_SCHEME_OFF: float = (
65+
_PLAYBACK_DETAILS_TOP - _BASE_INTERACTIVE_HEIGHT + _SCHEME_OFF_FIGURE_BOTTOM_PAD
5266
)
53-
_INTERACTIVE_CONTROLS_BOTTOM: float = 0.26
5467
_BASE_TOGGLE_LABELS: tuple[str, str, str] = ("Hover", "Tensor labels", "Edge labels")
5568
_SCHEME_TOGGLE_LABELS: tuple[str, str, str] = ("Scheme", "Playback", "Costs")
5669
_TENSOR_INSPECTOR_LABEL: str = "Tensor inspector"
5770
_INTERACTIVE_LABEL_PROPS: dict[str, Sequence[Any]] = {"fontsize": [9.5]}
5871
_INTERACTIVE_CHECK_FRAME_PROPS: dict[str, float] = {"s": 44.0, "linewidth": 0.9}
5972
_INTERACTIVE_CHECK_MARK_PROPS: dict[str, float] = {"s": 34.0, "linewidth": 1.0}
6073
_INTERACTIVE_RADIO_PROPS: dict[str, float] = {"s": 38.0, "linewidth": 0.9}
74+
_CONTROL_TRAY_FACE: tuple[float, float, float] = (0.97, 0.97, 0.99)
75+
_CONTROL_TRAY_FRAME: tuple[float, float, float] = (0.78, 0.78, 0.82)
76+
77+
78+
def _style_interactive_control_axes(ax: Axes) -> None:
79+
ax.set_xticks([])
80+
ax.set_yticks([])
81+
ax.set_navigate(False)
82+
ax.patch.set_facecolor(_CONTROL_TRAY_FACE)
83+
ax.patch.set_alpha(0.88)
84+
ax.patch.set_edgecolor(_CONTROL_TRAY_FRAME)
85+
ax.patch.set_linewidth(0.6)
86+
for spine in ax.spines.values():
87+
spine.set_visible(True)
88+
spine.set_linewidth(0.6)
89+
spine.set_color(_CONTROL_TRAY_FRAME)
6190

6291

6392
@dataclass
@@ -222,11 +251,8 @@ def _interactive_checkbox_bounds(
222251
include_scheme_toggles: bool,
223252
include_tensor_inspector: bool,
224253
) -> tuple[float, float, float, float]:
225-
if include_scheme_toggles and include_tensor_inspector:
226-
return _SCHEME_INSPECTOR_INTERACTIVE_CHECKBOX_BOUNDS
227-
if include_scheme_toggles:
228-
return _SCHEME_INTERACTIVE_CHECKBOX_BOUNDS
229-
return _BASE_INTERACTIVE_CHECKBOX_BOUNDS
254+
_ = include_scheme_toggles, include_tensor_inspector
255+
return _INTERACTIVE_CHECKBOX_AXES_BOUNDS
230256

231257

232258
class _InteractiveTensorFigureController:
@@ -290,8 +316,6 @@ def initialize(self) -> tuple[Figure, RenderedAxes]:
290316
self.figure = figure
291317
if self._view_caches[self.current_view].scene is None:
292318
return figure, ax
293-
if not self._external_ax:
294-
_reserve_figure_bottom(figure, _INTERACTIVE_CONTROLS_BOTTOM)
295319
self._build_controls()
296320
self._apply_scene_state(self.current_scene)
297321
set_interactive_controls(figure, self)
@@ -346,6 +370,60 @@ def _build_view(
346370
scene.contraction_controls = get_contraction_controls(rendered_ax)
347371
return fig, rendered_ax
348372

373+
def _shared_data_axes_top(self) -> float:
374+
ax3 = self._view_caches["3d"].ax
375+
if ax3 is not None:
376+
p = ax3.get_position()
377+
return float(p.y0 + p.height)
378+
ax2 = self._view_caches["2d"].ax
379+
if ax2 is not None:
380+
p = ax2.get_position()
381+
return float(p.y0 + p.height)
382+
return 0.9
383+
384+
def _interactive_scheme_chrome_on(self) -> bool:
385+
return self.current_scene.contraction_controls is not None and self.scheme_on
386+
387+
def _interactive_main_axes_bottom(self) -> float:
388+
return float(
389+
_MAIN_FIGURE_BOTTOM_RESERVED
390+
if self._interactive_scheme_chrome_on()
391+
else _MAIN_FIGURE_BOTTOM_SCHEME_OFF
392+
)
393+
394+
def _figure_bottom_margin(self) -> float:
395+
base = self._interactive_main_axes_bottom()
396+
lows: list[float] = []
397+
if self._view_caches["2d"].ax is not None:
398+
lows.append(base - float(_INTERACTIVE_2D_BOTTOM_EXTRA))
399+
if self._view_caches["3d"].ax is not None:
400+
lows.append(base + float(_INTERACTIVE_3D_BOTTOM_LIFT))
401+
return min(lows) if lows else base
402+
403+
def _apply_interactive_figure_layout(self) -> None:
404+
if self.figure is None or self._external_ax:
405+
return
406+
_set_figure_bottom_reserved(self.figure, self._figure_bottom_margin())
407+
self._sync_data_axes_vertical_layout()
408+
409+
def _sync_data_axes_vertical_layout(self) -> None:
410+
if self.figure is None or self._external_ax:
411+
return
412+
base = self._interactive_main_axes_bottom()
413+
top = self._shared_data_axes_top()
414+
ax2 = self._view_caches["2d"].ax
415+
ax3 = self._view_caches["3d"].ax
416+
if ax2 is not None:
417+
bottom_2d = base - float(_INTERACTIVE_2D_BOTTOM_EXTRA)
418+
pos = ax2.get_position()
419+
height = max(top - bottom_2d, 0.08)
420+
ax2.set_position([pos.x0, bottom_2d, pos.width, height])
421+
if ax3 is not None:
422+
bottom_3d = base + float(_INTERACTIVE_3D_BOTTOM_LIFT)
423+
pos = ax3.get_position()
424+
height = max(top - bottom_3d, 0.08)
425+
ax3.set_position([pos.x0, bottom_3d, pos.width, height])
426+
349427
def _build_controls(self) -> None:
350428
assert self.figure is not None
351429
labels = list(_BASE_TOGGLE_LABELS)
@@ -355,8 +433,23 @@ def _build_controls(self) -> None:
355433
labels.extend(_SCHEME_TOGGLE_LABELS)
356434
if has_tensor_inspector:
357435
labels.append(_TENSOR_INSPECTOR_LABEL)
436+
cb_bounds = _interactive_checkbox_bounds(
437+
include_scheme_toggles=has_scheme_toggles,
438+
include_tensor_inspector=has_tensor_inspector,
439+
)
440+
cb_bottom = float(cb_bounds[1])
441+
check_ax = self.figure.add_axes(cb_bounds)
442+
_style_interactive_control_axes(check_ax)
443+
self._check_ax = check_ax
358444
if not self._external_ax:
359-
radio_ax = self.figure.add_axes(_VIEW_SELECTOR_BOUNDS)
445+
radio_bounds: tuple[float, float, float, float] = (
446+
_VIEW_SELECTOR_LEFT,
447+
cb_bottom,
448+
_VIEW_SELECTOR_WIDTH,
449+
_VIEW_SELECTOR_HEIGHT,
450+
)
451+
radio_ax = self.figure.add_axes(radio_bounds)
452+
_style_interactive_control_axes(radio_ax)
360453
self._radio_ax = radio_ax
361454
active_index = 0 if self.current_view == "2d" else 1
362455
self._radio = RadioButtons(
@@ -367,13 +460,6 @@ def _build_controls(self) -> None:
367460
radio_props=_INTERACTIVE_RADIO_PROPS,
368461
)
369462
self._radio.on_clicked(self._on_view_clicked)
370-
check_ax = self.figure.add_axes(
371-
_interactive_checkbox_bounds(
372-
include_scheme_toggles=has_scheme_toggles,
373-
include_tensor_inspector=has_tensor_inspector,
374-
)
375-
)
376-
self._check_ax = check_ax
377463
statuses = [
378464
self.hover_on,
379465
self.tensor_labels_on,
@@ -489,6 +575,8 @@ def _apply_scene_state(self, scene: _InteractiveSceneState) -> None:
489575
self._tensor_inspector.set_enabled(self.tensor_inspector_on)
490576
_apply_scene_hover_state(scene, hover_on=self.hover_on)
491577
self._sync_checkbuttons()
578+
if not self._external_ax:
579+
self._apply_interactive_figure_layout()
492580
scene.ax.figure.canvas.draw_idle()
493581

494582
def set_view(self, view: ViewName) -> None:

tests/test_engineering_baseline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def test_compute_axis_directions_dense_dangling_chain_completes_quickly() -> Non
153153
def test_pyproject_declares_smoke_and_perf_markers() -> None:
154154
content = Path("pyproject.toml").read_text(encoding="utf-8")
155155

156-
assert 'markers = [' in content
156+
assert "markers = [" in content
157157
assert '"perf: runtime-sensitive regression checks and throughput guards"' in content
158158
assert '"smoke: lightweight render smoke checks"' in content
159159

tests/test_plotting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,9 +626,9 @@ def test_show_tensor_network_places_view_selector_between_options_and_playback_s
626626

627627
assert check_bounds[2] <= 0.21
628628
assert radio_bounds[0] >= check_right - 0.02
629-
assert radio_bounds[2] <= 0.06
629+
assert radio_bounds[2] <= 0.09
630630
assert slider_bounds[0] >= radio_right - 0.02
631-
assert 0.02 <= radio_bounds[1] <= 0.06
631+
assert abs(radio_bounds[1] - check_bounds[1]) < 0.02
632632
assert play_bounds[0] > slider_right
633633

634634

0 commit comments

Comments
 (0)