Skip to content

Commit 686e565

Browse files
FBumannclaude
andcommitted
feat(plotting): add facet_titles kwarg to strip dim= prefix
Plotly Express renders faceted subplot titles as "<dim>=<value>" (e.g. "country=Brazil"). The new `facet_titles` keyword on every plotting/accessor method ("default" or "value") lets callers strip the prefix without a separate post-processing step. Default is "default" — no behavior change for existing users. Also adds a public `simplify_facet_titles(fig, mode)` helper for use on figures from any source (overlay/add_secondary_y outputs, raw PX figures, etc.). Both share the same Literal type alias `FacetTitlesMode` exported from `common`. Only annotations whose text starts with a Python-identifier prefix followed by `=` are touched, so user-added annotations are preserved. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c889e9a commit 686e565

6 files changed

Lines changed: 183 additions & 11 deletions

File tree

tests/test_figures.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
import pytest
1010
import xarray as xr
1111

12-
from xarray_plotly import add_secondary_y, overlay, subplots, xpx
12+
from xarray_plotly import (
13+
add_secondary_y,
14+
overlay,
15+
simplify_facet_titles,
16+
subplots,
17+
xpx,
18+
)
1319

1420

1521
class TestOverlayBasic:
@@ -922,3 +928,64 @@ def test_source_not_modified(self) -> None:
922928
original_count = len(fig.data)
923929
_ = subplots(fig, fig, cols=2)
924930
assert len(fig.data) == original_count
931+
932+
933+
class TestSimplifyFacetTitles:
934+
"""Tests for the simplify_facet_titles helper and the `facet_titles` kwarg."""
935+
936+
@pytest.fixture(autouse=True)
937+
def setup(self) -> None:
938+
self.da = xr.DataArray(
939+
np.random.rand(10, 3),
940+
dims=["x", "country"],
941+
coords={"country": ["United States", "China", "Brazil"]},
942+
name="value",
943+
)
944+
945+
def test_helper_strips_dim_prefix(self) -> None:
946+
fig = xpx(self.da).line(facet_col="country")
947+
# PX writes annotations like "country=United States"
948+
original_texts = [a.text for a in fig.layout.annotations]
949+
assert any(t and t.startswith("country=") for t in original_texts)
950+
951+
simplify_facet_titles(fig)
952+
953+
for ann in fig.layout.annotations:
954+
if ann.text:
955+
assert "=" not in ann.text or ann.text.split("=", 1)[0] != "country"
956+
957+
def test_helper_full_is_noop(self) -> None:
958+
fig = xpx(self.da).line(facet_col="country")
959+
before = [a.text for a in fig.layout.annotations]
960+
simplify_facet_titles(fig, mode="default")
961+
after = [a.text for a in fig.layout.annotations]
962+
assert before == after
963+
964+
def test_helper_invalid_mode_raises(self) -> None:
965+
fig = xpx(self.da).line(facet_col="country")
966+
with pytest.raises(ValueError, match="facet_titles must be"):
967+
simplify_facet_titles(fig, mode="bogus") # type: ignore[arg-type]
968+
969+
def test_helper_leaves_user_annotations_alone(self) -> None:
970+
"""User-added annotations without a Python-identifier prefix are preserved."""
971+
fig = xpx(self.da).line(facet_col="country")
972+
fig.add_annotation(text="Some note", x=0, y=0, showarrow=False)
973+
simplify_facet_titles(fig)
974+
texts = [a.text for a in fig.layout.annotations]
975+
assert "Some note" in texts
976+
977+
def test_kwarg_default_keeps_px_format(self) -> None:
978+
fig = xpx(self.da).line(facet_col="country")
979+
# At least one annotation still carries the dim= prefix.
980+
assert any(a.text and a.text.startswith("country=") for a in fig.layout.annotations)
981+
982+
def test_kwarg_value_strips_prefix(self) -> None:
983+
fig = xpx(self.da).line(facet_col="country", facet_titles="value")
984+
for ann in fig.layout.annotations:
985+
if ann.text:
986+
# Should not start with "country="; the dim prefix is stripped.
987+
assert not ann.text.startswith("country=")
988+
989+
def test_kwarg_invalid_raises(self) -> None:
990+
with pytest.raises(ValueError, match="facet_titles must be"):
991+
xpx(self.da).line(facet_col="country", facet_titles="bogus") # type: ignore[arg-type]

xarray_plotly/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from xarray_plotly.figures import (
5757
add_secondary_y,
5858
overlay,
59+
simplify_facet_titles,
5960
subplots,
6061
update_traces,
6162
)
@@ -67,6 +68,7 @@
6768
"auto",
6869
"config",
6970
"overlay",
71+
"simplify_facet_titles",
7072
"subplots",
7173
"update_traces",
7274
"xpx",

xarray_plotly/accessor.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from xarray import DataArray, Dataset
77

88
from xarray_plotly import plotting
9-
from xarray_plotly.common import Colors, SlotValue, auto
9+
from xarray_plotly.common import Colors, FacetTitlesMode, SlotValue, auto
1010
from xarray_plotly.config import _options
1111

1212

@@ -54,6 +54,7 @@ def line(
5454
facet_row: SlotValue = auto,
5555
animation_frame: SlotValue = auto,
5656
colors: Colors = None,
57+
facet_titles: FacetTitlesMode = "default",
5758
**px_kwargs: Any,
5859
) -> go.Figure:
5960
"""Create an interactive line plot.
@@ -84,6 +85,7 @@ def line(
8485
facet_row=facet_row,
8586
animation_frame=animation_frame,
8687
colors=colors,
88+
facet_titles=facet_titles,
8789
**px_kwargs,
8890
)
8991

@@ -97,6 +99,7 @@ def bar(
9799
facet_row: SlotValue = auto,
98100
animation_frame: SlotValue = auto,
99101
colors: Colors = None,
102+
facet_titles: FacetTitlesMode = "default",
100103
**px_kwargs: Any,
101104
) -> go.Figure:
102105
"""Create an interactive bar chart.
@@ -125,6 +128,7 @@ def bar(
125128
facet_row=facet_row,
126129
animation_frame=animation_frame,
127130
colors=colors,
131+
facet_titles=facet_titles,
128132
**px_kwargs,
129133
)
130134

@@ -138,6 +142,7 @@ def area(
138142
facet_row: SlotValue = auto,
139143
animation_frame: SlotValue = auto,
140144
colors: Colors = None,
145+
facet_titles: FacetTitlesMode = "default",
141146
**px_kwargs: Any,
142147
) -> go.Figure:
143148
"""Create an interactive stacked area chart.
@@ -166,6 +171,7 @@ def area(
166171
facet_row=facet_row,
167172
animation_frame=animation_frame,
168173
colors=colors,
174+
facet_titles=facet_titles,
169175
**px_kwargs,
170176
)
171177

@@ -178,6 +184,7 @@ def fast_bar(
178184
facet_row: SlotValue = auto,
179185
animation_frame: SlotValue = auto,
180186
colors: Colors = None,
187+
facet_titles: FacetTitlesMode = "default",
181188
**px_kwargs: Any,
182189
) -> go.Figure:
183190
"""Create a bar-like chart using stacked areas for better performance.
@@ -204,6 +211,7 @@ def fast_bar(
204211
facet_row=facet_row,
205212
animation_frame=animation_frame,
206213
colors=colors,
214+
facet_titles=facet_titles,
207215
**px_kwargs,
208216
)
209217

@@ -218,6 +226,7 @@ def scatter(
218226
facet_row: SlotValue = auto,
219227
animation_frame: SlotValue = auto,
220228
colors: Colors = None,
229+
facet_titles: FacetTitlesMode = "default",
221230
**px_kwargs: Any,
222231
) -> go.Figure:
223232
"""Create an interactive scatter plot.
@@ -252,6 +261,7 @@ def scatter(
252261
facet_row=facet_row,
253262
animation_frame=animation_frame,
254263
colors=colors,
264+
facet_titles=facet_titles,
255265
**px_kwargs,
256266
)
257267

@@ -264,6 +274,7 @@ def box(
264274
facet_row: SlotValue = None,
265275
animation_frame: SlotValue = None,
266276
colors: Colors = None,
277+
facet_titles: FacetTitlesMode = "default",
267278
**px_kwargs: Any,
268279
) -> go.Figure:
269280
"""Create an interactive box plot.
@@ -293,6 +304,7 @@ def box(
293304
facet_row=facet_row,
294305
animation_frame=animation_frame,
295306
colors=colors,
307+
facet_titles=facet_titles,
296308
**px_kwargs,
297309
)
298310

@@ -305,6 +317,7 @@ def imshow(
305317
animation_frame: SlotValue = auto,
306318
robust: bool = False,
307319
colors: Colors = None,
320+
facet_titles: FacetTitlesMode = "default",
308321
**px_kwargs: Any,
309322
) -> go.Figure:
310323
"""Create an interactive heatmap image.
@@ -337,6 +350,7 @@ def imshow(
337350
animation_frame=animation_frame,
338351
robust=robust,
339352
colors=colors,
353+
facet_titles=facet_titles,
340354
**px_kwargs,
341355
)
342356

@@ -348,6 +362,7 @@ def pie(
348362
facet_col: SlotValue = auto,
349363
facet_row: SlotValue = auto,
350364
colors: Colors = None,
365+
facet_titles: FacetTitlesMode = "default",
351366
**px_kwargs: Any,
352367
) -> go.Figure:
353368
"""Create an interactive pie chart.
@@ -372,6 +387,7 @@ def pie(
372387
facet_col=facet_col,
373388
facet_row=facet_row,
374389
colors=colors,
390+
facet_titles=facet_titles,
375391
**px_kwargs,
376392
)
377393

@@ -452,6 +468,7 @@ def line(
452468
facet_row: SlotValue = auto,
453469
animation_frame: SlotValue = auto,
454470
colors: Colors = None,
471+
facet_titles: FacetTitlesMode = "default",
455472
**px_kwargs: Any,
456473
) -> go.Figure:
457474
"""Create an interactive line plot.
@@ -482,6 +499,7 @@ def line(
482499
facet_row=facet_row,
483500
animation_frame=animation_frame,
484501
colors=colors,
502+
facet_titles=facet_titles,
485503
**px_kwargs,
486504
)
487505

@@ -496,6 +514,7 @@ def bar(
496514
facet_row: SlotValue = auto,
497515
animation_frame: SlotValue = auto,
498516
colors: Colors = None,
517+
facet_titles: FacetTitlesMode = "default",
499518
**px_kwargs: Any,
500519
) -> go.Figure:
501520
"""Create an interactive bar chart.
@@ -524,6 +543,7 @@ def bar(
524543
facet_row=facet_row,
525544
animation_frame=animation_frame,
526545
colors=colors,
546+
facet_titles=facet_titles,
527547
**px_kwargs,
528548
)
529549

@@ -538,6 +558,7 @@ def area(
538558
facet_row: SlotValue = auto,
539559
animation_frame: SlotValue = auto,
540560
colors: Colors = None,
561+
facet_titles: FacetTitlesMode = "default",
541562
**px_kwargs: Any,
542563
) -> go.Figure:
543564
"""Create an interactive stacked area chart.
@@ -566,6 +587,7 @@ def area(
566587
facet_row=facet_row,
567588
animation_frame=animation_frame,
568589
colors=colors,
590+
facet_titles=facet_titles,
569591
**px_kwargs,
570592
)
571593

@@ -579,6 +601,7 @@ def fast_bar(
579601
facet_row: SlotValue = auto,
580602
animation_frame: SlotValue = auto,
581603
colors: Colors = None,
604+
facet_titles: FacetTitlesMode = "default",
582605
**px_kwargs: Any,
583606
) -> go.Figure:
584607
"""Create a bar-like chart using stacked areas for better performance.
@@ -605,6 +628,7 @@ def fast_bar(
605628
facet_row=facet_row,
606629
animation_frame=animation_frame,
607630
colors=colors,
631+
facet_titles=facet_titles,
608632
**px_kwargs,
609633
)
610634

@@ -620,6 +644,7 @@ def scatter(
620644
facet_row: SlotValue = auto,
621645
animation_frame: SlotValue = auto,
622646
colors: Colors = None,
647+
facet_titles: FacetTitlesMode = "default",
623648
**px_kwargs: Any,
624649
) -> go.Figure:
625650
"""Create an interactive scatter plot.
@@ -650,6 +675,7 @@ def scatter(
650675
facet_row=facet_row,
651676
animation_frame=animation_frame,
652677
colors=colors,
678+
facet_titles=facet_titles,
653679
**px_kwargs,
654680
)
655681

@@ -663,6 +689,7 @@ def box(
663689
facet_row: SlotValue = None,
664690
animation_frame: SlotValue = None,
665691
colors: Colors = None,
692+
facet_titles: FacetTitlesMode = "default",
666693
**px_kwargs: Any,
667694
) -> go.Figure:
668695
"""Create an interactive box plot.
@@ -689,6 +716,7 @@ def box(
689716
facet_row=facet_row,
690717
animation_frame=animation_frame,
691718
colors=colors,
719+
facet_titles=facet_titles,
692720
**px_kwargs,
693721
)
694722

@@ -701,6 +729,7 @@ def pie(
701729
facet_col: SlotValue = auto,
702730
facet_row: SlotValue = auto,
703731
colors: Colors = None,
732+
facet_titles: FacetTitlesMode = "default",
704733
**px_kwargs: Any,
705734
) -> go.Figure:
706735
"""Create an interactive pie chart.
@@ -725,5 +754,6 @@ def pie(
725754
facet_col=facet_col,
726755
facet_row=facet_row,
727756
colors=colors,
757+
facet_titles=facet_titles,
728758
**px_kwargs,
729759
)

xarray_plotly/common.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import functools
66
import warnings
77
from collections.abc import Hashable, Mapping, Sequence
8-
from typing import TYPE_CHECKING, Any
8+
from typing import TYPE_CHECKING, Any, Literal
99

1010
import plotly.express as px
1111

@@ -39,6 +39,13 @@ def __repr__(self) -> str:
3939
- None: Use Plotly defaults
4040
"""
4141

42+
FacetTitlesMode = Literal["value", "default"]
43+
"""Type alias for facet_titles parameter.
44+
45+
- "default" (default): keep PX's ``"<dim>=<value>"`` subplot titles.
46+
- "value": strip the ``<dim>=`` prefix, leaving just the value.
47+
"""
48+
4249
# Re-export for backward compatibility
4350
SLOT_ORDERS = DEFAULT_SLOT_ORDERS
4451
"""Slot orders per plot type.

0 commit comments

Comments
 (0)