Skip to content

Commit f586899

Browse files
FBumannclaude
andcommitted
fix: keep secondary traces visible in legend when combining figures
add_secondary_y was deduplicating legendgroups across both source figures, hiding the secondary's traces from the legend whenever the two figures shared legendgroup names (e.g. PX color= producing the same categories on both axes). Cross-source dedup is correct for overlay (same data, two display styles) but wrong for add_secondary_y, where each side plots different data on its own axis. Add a cross_source_dedup flag to _ensure_legend_visibility. overlay keeps the existing behavior; add_secondary_y opts out and instead namespaces colliding legendgroups with the source label and dedupes within each slice independently. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 44e599b commit f586899

2 files changed

Lines changed: 125 additions & 25 deletions

File tree

tests/test_figures.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,60 @@ def test_add_secondary_y_single_trace_with_names(self) -> None:
703703
assert combined.data[0].showlegend is True
704704
assert combined.data[1].showlegend is True
705705

706+
def test_add_secondary_y_multi_trace_shared_legendgroups(self) -> None:
707+
"""add_secondary_y must keep the secondary's traces visible in the legend
708+
even when both figures share legendgroup names (e.g. PX color=...)."""
709+
da1 = xr.DataArray(
710+
np.random.rand(10, 3),
711+
dims=["x", "cat"],
712+
coords={"cat": ["a", "b", "c"]},
713+
name="Var1",
714+
)
715+
da2 = xr.DataArray(
716+
np.random.rand(10, 3) * 100,
717+
dims=["x", "cat"],
718+
coords={"cat": ["a", "b", "c"]},
719+
name="Var2",
720+
)
721+
fig1 = xpx(da1).line()
722+
fig2 = xpx(da2).line()
723+
724+
combined = add_secondary_y(fig1, fig2)
725+
726+
# All 6 traces must end up visible in the legend with distinct legendgroups.
727+
assert all(t.showlegend is True for t in combined.data)
728+
legendgroups = [t.legendgroup for t in combined.data]
729+
assert len(set(legendgroups)) == len(legendgroups)
730+
# Secondary traces remain on y2.
731+
assert all(t.yaxis == "y" for t in combined.data[:3])
732+
assert all(t.yaxis == "y2" for t in combined.data[3:])
733+
734+
def test_add_secondary_y_after_overlay_keeps_secondary_visible(self) -> None:
735+
"""overlay → add_secondary_y must not hide the secondary's traces."""
736+
da1 = xr.DataArray(
737+
np.random.rand(10, 3),
738+
dims=["x", "cat"],
739+
coords={"cat": ["a", "b", "c"]},
740+
name="Var1",
741+
)
742+
da2 = xr.DataArray(
743+
np.random.rand(10, 3) * 100,
744+
dims=["x", "cat"],
745+
coords={"cat": ["a", "b", "c"]},
746+
name="Var2",
747+
)
748+
fig1 = xpx(da1).line()
749+
fig2 = xpx(da1).area()
750+
overlaid = overlay(fig1, fig2)
751+
fig3 = xpx(da2).line()
752+
753+
combined = add_secondary_y(overlaid, fig3)
754+
755+
# Secondary traces (last 3) must all be visible in the legend.
756+
for t in combined.data[-3:]:
757+
assert t.showlegend is True
758+
assert t.yaxis == "y2"
759+
706760
def test_overlay_faceted_legendgroup_dedup(self) -> None:
707761
"""Faceted overlay keeps only one showlegend=True per legendgroup."""
708762
da = xr.DataArray(

xarray_plotly/figures.py

Lines changed: 71 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,40 @@ def _get_yaxis_title(fig: go.Figure) -> str:
2828
return ""
2929

3030

31+
def _dedup_legend_within_traces(traces: list[Any]) -> None:
32+
"""Ensure one ``showlegend=True`` per ``legendgroup`` among the given traces."""
33+
from collections import defaultdict
34+
35+
grouped: dict[str, list[Any]] = defaultdict(list)
36+
ungrouped: list[Any] = []
37+
38+
for trace in traces:
39+
lg = getattr(trace, "legendgroup", None) or ""
40+
if lg:
41+
grouped[lg].append(trace)
42+
else:
43+
ungrouped.append(trace)
44+
45+
for group_traces in grouped.values():
46+
has_visible = False
47+
for t in group_traces:
48+
if has_visible:
49+
t.showlegend = False
50+
elif getattr(t, "name", None):
51+
t.showlegend = True
52+
has_visible = True
53+
54+
for trace in ungrouped:
55+
if getattr(trace, "name", None):
56+
trace.showlegend = True
57+
58+
3159
def _ensure_legend_visibility(
3260
combined: go.Figure,
3361
source_figs: list[go.Figure],
3462
trace_slices: list[slice],
63+
*,
64+
cross_source_dedup: bool = True,
3565
) -> None:
3666
"""Fix legend visibility on a combined figure.
3767
@@ -43,13 +73,19 @@ def _ensure_legend_visibility(
4373
figures. We ensure at least one trace per ``legendgroup`` (or each
4474
ungrouped named trace) has ``showlegend=True``.
4575
3. **Duplicate legend entries** — when two source figures share the same
46-
``legendgroup`` names, we deduplicate so only the first trace per
47-
group shows in the legend.
76+
``legendgroup`` names and ``cross_source_dedup=True`` (the default),
77+
we deduplicate so only the first trace per group shows in the legend.
78+
When ``cross_source_dedup=False``, traces from different sources are
79+
kept independent: colliding ``legendgroup`` names are namespaced with
80+
the source label so each source's traces get their own legend entries.
4881
4982
Args:
5083
combined: The combined Plotly figure (mutated in place).
5184
source_figs: The original source figures, in trace order.
5285
trace_slices: Slices into ``combined.data`` for each source figure.
86+
cross_source_dedup: If True (overlay default), dedup legend entries
87+
across all sources. If False (add_secondary_y), preserve each
88+
source's legend entries independently.
5389
"""
5490
from collections import defaultdict
5591

@@ -70,30 +106,39 @@ def _ensure_legend_visibility(
70106
trace.legendgroup = label
71107

72108
# --- Step 2 & 3: fix showlegend per legendgroup -----------------------
73-
grouped: dict[str, list[Any]] = defaultdict(list)
74-
ungrouped: list[Any] = []
75-
76-
for trace in combined.data:
77-
lg = getattr(trace, "legendgroup", None) or ""
78-
if lg:
79-
grouped[lg].append(trace)
80-
else:
81-
ungrouped.append(trace)
82-
83-
for traces in grouped.values():
84-
has_visible = False
85-
for t in traces:
86-
if has_visible:
87-
# Deduplicate: only first keeps showlegend
88-
t.showlegend = False
89-
elif getattr(t, "name", None):
90-
t.showlegend = True
91-
has_visible = True
109+
if cross_source_dedup:
110+
_dedup_legend_within_traces(list(combined.data))
111+
else:
112+
# Namespace legendgroups that collide across slices, so each source
113+
# keeps its own legend entries instead of being deduped away.
114+
slice_groups: list[set[str]] = []
115+
for sl in trace_slices:
116+
slice_groups.append(
117+
{
118+
getattr(t, "legendgroup", None)
119+
for t in combined.data[sl]
120+
if getattr(t, "legendgroup", None)
121+
} # type: ignore[misc]
122+
)
123+
group_counts: dict[str, int] = defaultdict(int)
124+
for sg in slice_groups:
125+
for g in sg:
126+
group_counts[g] += 1
127+
colliding = {g for g, cnt in group_counts.items() if cnt > 1}
128+
129+
for label, sl in zip(labels, trace_slices, strict=False):
130+
if not label:
131+
continue
132+
for trace in combined.data[sl]:
133+
lg = getattr(trace, "legendgroup", None)
134+
if lg and lg in colliding:
135+
new_lg = f"{lg} ({label})"
136+
trace.legendgroup = new_lg
137+
if getattr(trace, "name", None) == lg:
138+
trace.name = new_lg
92139

93-
# Ungrouped traces with a name should show in the legend
94-
for trace in ungrouped:
95-
if getattr(trace, "name", None):
96-
trace.showlegend = True
140+
for sl in trace_slices:
141+
_dedup_legend_within_traces(list(combined.data[sl]))
97142

98143
# --- Step 4: propagate style properties to animation frame traces ------
99144
# When Plotly animates, frame trace data overwrites fig.data properties.
@@ -557,6 +602,7 @@ def add_secondary_y(
557602
combined,
558603
[base, secondary],
559604
[slice(0, base_n), slice(base_n, base_n + sec_n)],
605+
cross_source_dedup=False,
560606
)
561607
_fix_animation_axis_ranges(combined)
562608
return combined

0 commit comments

Comments
 (0)