Skip to content

Commit e5c54fd

Browse files
authored
Infer fig from ax in show() instead of requiring it (#655)
1 parent 14b9ffe commit e5c54fd

3 files changed

Lines changed: 34 additions & 13 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sys
55
import warnings
66
from collections import OrderedDict
7-
from collections.abc import Callable, Sequence
7+
from collections.abc import Callable
88
from copy import deepcopy
99
from pathlib import Path
1010
from typing import Any, Literal, cast
@@ -1003,7 +1003,7 @@ def show(
10031003
legend_params,
10041004
)
10051005

1006-
if fig is not None and not isinstance(ax, Sequence):
1006+
if fig is not None:
10071007
warnings.warn(
10081008
"`fig` is being deprecated as an argument to `PlotAccessor.show` in spatialdata-plot. "
10091009
"To use a custom figure, create axes from it and pass them via `ax` instead: "

src/spatialdata_plot/pl/utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -294,22 +294,24 @@ def _prepare_params_plot(
294294
elif num_panels > 1:
295295
if not isinstance(ax, Sequence):
296296
raise TypeError(f"Expected `ax` to be a `Sequence`, but got {type(ax).__name__}")
297-
if ax is not None and len(ax) != num_panels:
297+
if len(ax) != num_panels:
298298
raise ValueError(f"Len of `ax`: {len(ax)} is not equal to number of panels: {num_panels}.")
299299
if fig is None:
300-
# TODO(#579): infer fig from ax[0].get_figure() instead of requiring it
301-
raise ValueError(
302-
f"Invalid value of `fig`: {fig}. If a list of `Axes` is passed, a `Figure` must also be specified."
303-
)
304-
assert ax is None or isinstance(ax, Sequence), f"Invalid type of `ax`: {type(ax)}, expected `Sequence`."
300+
fig = ax[0].get_figure()
305301
axs = ax
306302
if dpi is not None:
307303
fig.set_dpi(dpi)
308304
else:
309305
axs = None
310306
if ax is None:
311307
fig, ax = plt.subplots(figsize=figsize, dpi=resolved_dpi, constrained_layout=True)
312-
elif isinstance(ax, Axes):
308+
else:
309+
if isinstance(ax, Sequence):
310+
if len(ax) != 1:
311+
raise ValueError(f"Len of `ax`: {len(ax)} is not equal to number of panels: {num_panels}.")
312+
ax = ax[0]
313+
if not isinstance(ax, Axes):
314+
raise TypeError(f"Expected `ax` to be an `Axes` or a `Sequence` of `Axes`, but got {type(ax).__name__}")
313315
fig = ax.get_figure()
314316
if dpi is not None:
315317
fig.set_dpi(dpi)

tests/pl/test_show.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,16 +136,35 @@ def test_fig_parameter_default_no_warning(sdata_blobs: SpatialData):
136136
plt.close("all")
137137

138138

139-
def test_fig_parameter_no_warning_with_ax_list(sdata_blobs: SpatialData):
140-
"""Passing fig= with a list of axes should not warn (fig is still required there)."""
139+
def test_fig_parameter_warns_with_ax_list(sdata_blobs: SpatialData):
140+
"""Passing fig= alongside a list of axes should also emit the deprecation (regression for #625)."""
141141
set_transformation(sdata_blobs["blobs_image"], Identity(), "second_cs")
142142
fig, axs = plt.subplots(1, 2)
143-
with warnings.catch_warnings():
144-
warnings.simplefilter("error", DeprecationWarning)
143+
with pytest.warns(DeprecationWarning, match="`fig` is being deprecated"):
145144
sdata_blobs.pl.render_images(element="blobs_image").pl.show(fig=fig, ax=list(axs), show=False)
146145
plt.close("all")
147146

148147

148+
def test_show_ax_list_infers_fig(sdata_blobs: SpatialData):
149+
"""show(ax=[...]) should infer fig from the axes without requiring fig= (regression for #625)."""
150+
set_transformation(sdata_blobs["blobs_image"], Identity(), "second_cs")
151+
fig, axs = plt.subplots(1, 2)
152+
sdata_blobs.pl.render_images(element="blobs_image").pl.show(ax=list(axs), show=False)
153+
for ax in axs:
154+
assert ax.get_figure() is fig
155+
assert len(ax.get_images()) > 0
156+
plt.close(fig)
157+
158+
159+
def test_show_single_panel_accepts_ax_list(sdata_blobs: SpatialData):
160+
"""show(ax=[ax]) for a single coordinate system should be accepted (regression for #625)."""
161+
fig, ax = plt.subplots()
162+
sdata_blobs.pl.render_images(element="blobs_image").pl.show(ax=[ax], show=False)
163+
assert ax.get_figure() is fig
164+
assert len(ax.get_images()) > 0
165+
plt.close(fig)
166+
167+
149168
def test_frameon_false_multi_panel(sdata_blobs: SpatialData):
150169
"""frameon=False should apply to all panels in a multi-panel plot (regression for #204)."""
151170
set_transformation(sdata_blobs["blobs_image"], Identity(), "second_cs")

0 commit comments

Comments
 (0)