Skip to content

Commit 23feadb

Browse files
committed
Optimize interactive render hot paths
1 parent 31ac084 commit 23feadb

15 files changed

Lines changed: 951 additions & 110 deletions

src/tensor_network_viz/_core/draw/contraction_edges.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .constants import _CURVE_NEAR_PAIR_REF, _CURVE_OFFSET_FACTOR
1414
from .edge_labels import _plot_contraction_index_captions
1515
from .fonts_and_scale import _DrawScaleParams
16+
from .label_descriptors import _TextLabelDescriptor
1617
from .labels_misc import _contraction_hover_label_text
1718
from .plotter import _PlotAdapter
1819
from .scene_state import _RenderedEdgeGeometry
@@ -113,6 +114,7 @@ def _draw_contraction_edge_labels(
113114
ax: Any,
114115
scale: float,
115116
zorder_label: float | None = None,
117+
label_sink: list[_TextLabelDescriptor] | None = None,
116118
) -> None:
117119
hover_targets = getattr(plotter, "_hover_edge_targets", None)
118120
if config.hover_labels and hover_targets is not None:
@@ -139,6 +141,7 @@ def _draw_contraction_edge_labels(
139141
ax=ax,
140142
scale=scale,
141143
zorder_label=zorder_label,
144+
label_sink=label_sink,
142145
)
143146

144147

src/tensor_network_viz/_core/draw/dangling_self_edges.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
_node_label_clearance,
2323
_self_loop_hover_label_text,
2424
)
25+
from .label_descriptors import _TextLabelDescriptor
2526
from .plotter import _PlotAdapter
2627
from .scene_state import _RenderedEdgeGeometry
2728
from .vectors import _perpendicular_2d
@@ -119,6 +120,7 @@ def _draw_dangling_edge_labels(
119120
ax: Any,
120121
scale: float,
121122
zorder_label: float | None = None,
123+
label_sink: list[_TextLabelDescriptor] | None = None,
122124
) -> None:
123125
if not edge.label and not config.hover_labels:
124126
return
@@ -151,6 +153,11 @@ def _draw_dangling_edge_labels(
151153
return
152154

153155
raw_label = edge.label
156+
stub_segment = np.stack([np.asarray(start, dtype=float), np.asarray(end, dtype=float)], axis=0)
157+
stub_length = _polyline_arc_length_total(stub_segment)
158+
distance_from_tip = float(_PHYS_DANGLING_2D_FRAC_FROM_TIP) * stub_length
159+
point, tangent = _point_tangent_along_polyline_from_end(stub_segment, distance_from_tip)
160+
154161
fontsize = _edge_index_fontsize_for_bond(
155162
raw_label,
156163
bond_start=start,
@@ -166,10 +173,6 @@ def _draw_dangling_edge_labels(
166173
bbox_pad=p.index_bbox_pad,
167174
zorder=zorder_label,
168175
)
169-
stub_segment = np.stack([np.asarray(start, dtype=float), np.asarray(end, dtype=float)], axis=0)
170-
stub_length = _polyline_arc_length_total(stub_segment)
171-
distance_from_tip = float(_PHYS_DANGLING_2D_FRAC_FROM_TIP) * stub_length
172-
point, tangent = _point_tangent_along_polyline_from_end(stub_segment, distance_from_tip)
173176
if dimensions == 2:
174177
start_2d = np.asarray(start[:2], dtype=float)
175178
end_2d = np.asarray(end[:2], dtype=float)
@@ -203,11 +206,18 @@ def _draw_dangling_edge_labels(
203206
scale=scale,
204207
fontsize_pt=float(fontsize),
205208
)
206-
plotter.plot_text(
207-
label_pos,
208-
format_tensor_node_label(raw_label),
209-
**{**text_kwargs, **align_kwargs},
210-
)
209+
formatted = format_tensor_node_label(raw_label)
210+
kwargs = {**text_kwargs, **align_kwargs}
211+
if label_sink is not None:
212+
label_sink.append(
213+
_TextLabelDescriptor(
214+
position=np.asarray(label_pos, dtype=float).copy(),
215+
text=formatted,
216+
kwargs=dict(kwargs),
217+
)
218+
)
219+
return
220+
plotter.plot_text(label_pos, formatted, **kwargs)
211221

212222

213223
def _draw_self_loop_edge(
@@ -302,6 +312,7 @@ def _draw_self_loop_edge_labels(
302312
ax: Any,
303313
scale: float,
304314
zorder_label: float | None = None,
315+
label_sink: list[_TextLabelDescriptor] | None = None,
305316
) -> None:
306317
hover_targets = getattr(plotter, "_hover_edge_targets", None)
307318
if config.hover_labels and hover_targets is not None:
@@ -391,11 +402,22 @@ def _draw_self_loop_edge_labels(
391402
dimensions=dimensions,
392403
),
393404
}
394-
plotter.plot_text(
395-
np.asarray(q_a, dtype=float) + offset_a,
396-
format_tensor_node_label(caption_a),
397-
**text_kwargs_a,
398-
)
405+
position_a = np.asarray(q_a, dtype=float) + offset_a
406+
formatted_a = format_tensor_node_label(caption_a)
407+
if label_sink is not None:
408+
label_sink.append(
409+
_TextLabelDescriptor(
410+
position=np.asarray(position_a, dtype=float).copy(),
411+
text=formatted_a,
412+
kwargs=dict(text_kwargs_a),
413+
)
414+
)
415+
else:
416+
plotter.plot_text(
417+
position_a,
418+
formatted_a,
419+
**text_kwargs_a,
420+
)
399421
if caption_b:
400422
offset_b = (
401423
-direction_unit
@@ -433,9 +455,20 @@ def _draw_self_loop_edge_labels(
433455
dimensions=dimensions,
434456
),
435457
}
458+
formatted_b = format_tensor_node_label(caption_b)
459+
position_b = np.asarray(q_b, dtype=float) + offset_b
460+
if label_sink is not None:
461+
label_sink.append(
462+
_TextLabelDescriptor(
463+
position=np.asarray(position_b, dtype=float).copy(),
464+
text=formatted_b,
465+
kwargs=dict(text_kwargs_b),
466+
)
467+
)
468+
return
436469
plotter.plot_text(
437-
np.asarray(q_b, dtype=float) + offset_b,
438-
format_tensor_node_label(caption_b),
470+
position_b,
471+
formatted_b,
439472
**text_kwargs_b,
440473
)
441474

src/tensor_network_viz/_core/draw/edge_labels.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from ..layout import NodePositions
1818
from .fonts_and_scale import _DrawScaleParams
19+
from .label_descriptors import _TextLabelDescriptor
1920
from .labels_misc import _edge_index_text_kwargs
2021
from .plotter import _PlotAdapter
2122
from .viewport_geometry import (
@@ -46,6 +47,7 @@ def _plot_contraction_index_captions(
4647
ax: Any,
4748
scale: float,
4849
zorder_label: float | None = None,
50+
label_sink: list[_TextLabelDescriptor] | None = None,
4951
) -> None:
5052
ep_l, ep_r = _require_contraction_endpoints(edge)
5153
cap_l: str | None = _endpoint_index_caption(ep_l, edge, graph)
@@ -129,11 +131,18 @@ def _plot_contraction_index_captions(
129131
scale=scale,
130132
fontsize_pt=float(fontsize),
131133
)
132-
plotter.plot_text(
133-
position,
134-
format_tensor_node_label(cap),
135-
**{**text_kwargs, **align_kwargs},
136-
)
134+
formatted = format_tensor_node_label(cap)
135+
kwargs = {**text_kwargs, **align_kwargs}
136+
if label_sink is not None:
137+
label_sink.append(
138+
_TextLabelDescriptor(
139+
position=np.asarray(position, dtype=float).copy(),
140+
text=formatted,
141+
kwargs=dict(kwargs),
142+
)
143+
)
144+
continue
145+
plotter.plot_text(position, formatted, **kwargs)
137146

138147

139148
__all__ = ["_plot_contraction_index_captions"]

src/tensor_network_viz/_core/draw/fonts_and_scale.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,17 @@ def _on_2d_limits_changed(ax: Axes) -> None:
152152

153153

154154
def _register_2d_zoom_font_scaling(ax: Axes) -> None:
155-
old_cids = get_zoom_cids(ax)
156-
for cid in old_cids:
157-
with suppress(ValueError, KeyError):
158-
ax.callbacks.disconnect(cid)
159155
x0, x1 = ax.get_xlim()
160156
y0, y1 = ax.get_ylim()
161157
ref_span = max(float(x1 - x0), float(y1 - y0), 1e-9)
162158
sizes = {t: float(t.get_fontsize()) for t in ax.texts}
163159
set_zoom_font_state(ax, ref_span=ref_span, sizes=sizes)
164160

161+
old_cids = get_zoom_cids(ax)
162+
if old_cids:
163+
_on_2d_limits_changed(ax)
164+
return
165+
165166
def _cb(_: object) -> None:
166167
_on_2d_limits_changed(ax)
167168

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any, Literal, TypeAlias
5+
6+
import numpy as np
7+
8+
9+
@dataclass(frozen=True)
10+
class _TextLabelDescriptor:
11+
position: np.ndarray
12+
text: str
13+
kwargs: dict[str, Any]
14+
node_id: int | None = None
15+
16+
17+
@dataclass(frozen=True)
18+
class _DeferredBondLabelDescriptor:
19+
text: str
20+
point: np.ndarray
21+
tangent_geom: np.ndarray
22+
tangent_align: np.ndarray
23+
bond_start: np.ndarray
24+
bond_end: np.ndarray
25+
text_endpoint: Literal["left", "right"]
26+
stub_kind: Literal["bond", "dangling"]
27+
is_physical: bool
28+
peer_captions_for_width: tuple[str, ...] | None = None
29+
zorder: float | None = None
30+
31+
32+
@dataclass(frozen=True)
33+
class _DeferredSelfLoopLabelDescriptor:
34+
text: str
35+
point: np.ndarray
36+
tangent: np.ndarray
37+
bond_start: np.ndarray
38+
bond_end: np.ndarray
39+
offset_direction: np.ndarray
40+
offset_scale: float
41+
text_endpoint: Literal["left", "right"]
42+
peer_captions_for_width: tuple[str, ...] | None = None
43+
zorder: float | None = None
44+
45+
46+
_AnyLabelDescriptor: TypeAlias = (
47+
_TextLabelDescriptor | _DeferredBondLabelDescriptor | _DeferredSelfLoopLabelDescriptor
48+
)
49+
50+
51+
__all__ = [
52+
"_AnyLabelDescriptor",
53+
"_DeferredBondLabelDescriptor",
54+
"_DeferredSelfLoopLabelDescriptor",
55+
"_TextLabelDescriptor",
56+
]

src/tensor_network_viz/_core/draw/scene_state.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ..layout import AxisDirections, NodePositions
1313
from .fonts_and_scale import _DrawScaleParams
1414
from .hover import _RenderHoverState
15+
from .label_descriptors import _AnyLabelDescriptor, _TextLabelDescriptor
1516
from .plotter import _PlotAdapter
1617

1718

@@ -40,6 +41,8 @@ class _InteractiveSceneState:
4041
tensor_disk_radius_px_3d: float | None
4142
tensor_label_artists: list[Artist] = field(default_factory=list)
4243
edge_label_artists: list[Artist] = field(default_factory=list)
44+
tensor_label_descriptors: tuple[_AnyLabelDescriptor, ...] | None = None
45+
edge_label_descriptors: tuple[_AnyLabelDescriptor, ...] | None = None
4346
tensor_hover_payload: dict[int, tuple[str, float]] | None = None
4447
edge_hover_payload: tuple[tuple[np.ndarray, str], ...] | None = None
4548
contraction_controls: Any = None
@@ -48,4 +51,6 @@ class _InteractiveSceneState:
4851
__all__ = [
4952
"_InteractiveSceneState",
5053
"_RenderedEdgeGeometry",
54+
"_AnyLabelDescriptor",
55+
"_TextLabelDescriptor",
5156
]

src/tensor_network_viz/_core/draw/tensors.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from .disk_metrics import _tensor_disk_radius_px
2424
from .fonts_and_scale import _DrawScaleParams
25+
from .label_descriptors import _TextLabelDescriptor
2526
from .plotter import _PlotAdapter, _visible_degree_one_mask
2627
from .viewport_geometry import _stack_visible_tensor_coords
2728

@@ -97,8 +98,9 @@ def _draw_labels(
9798
visible_draw_order: list[int] | None = None,
9899
tensor_label_zorder_by_node: dict[int, float] | None = None,
99100
tensor_disk_radius_px_3d: float | None = None,
101+
label_sink: list[_TextLabelDescriptor] | None = None,
100102
) -> None:
101-
if show_tensor_labels or tensor_hover_by_node is not None:
103+
if show_tensor_labels or tensor_hover_by_node is not None or label_sink is not None:
102104
fig = ax.figure
103105
ordered_ids: list[int]
104106
if visible_draw_order is not None:
@@ -124,14 +126,30 @@ def _draw_labels(
124126
if dimensions == 3:
125127
cap_tensor = float(p.font_tensor_label_max) * _LABEL_FONT_3D_SCALE
126128
fs = min(float(fs) * _LABEL_FONT_3D_SCALE, cap_tensor)
127-
if tensor_hover_by_node is not None:
128-
tensor_hover_by_node[node_id] = (display_name, float(fs))
129-
if not show_tensor_labels:
130-
continue
131129
if tensor_label_zorder_by_node is None:
132130
z_lbl = float(_ZORDER_TENSOR_NAME)
133131
else:
134132
z_lbl = float(tensor_label_zorder_by_node.get(node_id, _ZORDER_TENSOR_NAME))
133+
if tensor_hover_by_node is not None:
134+
tensor_hover_by_node[node_id] = (display_name, float(fs))
135+
if label_sink is not None:
136+
label_sink.append(
137+
_TextLabelDescriptor(
138+
position=np.asarray(pos, dtype=float).copy(),
139+
text=display_name,
140+
kwargs={
141+
"color": config.tensor_label_color,
142+
"ha": "center",
143+
"va": "center",
144+
"fontsize": float(fs),
145+
"zorder": float(z_lbl),
146+
"gid": _TENSOR_LABEL_GID,
147+
},
148+
node_id=int(node_id),
149+
)
150+
)
151+
if not show_tensor_labels:
152+
continue
135153
plotter.plot_text(
136154
pos,
137155
display_name,

0 commit comments

Comments
 (0)