Skip to content

Commit 3ebefe1

Browse files
authored
Raise clear error for 3D (z-stack) images and labels (#675)
1 parent 7cb3133 commit 3ebefe1

3 files changed

Lines changed: 53 additions & 3 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import abc
55
from collections.abc import Sequence
66
from copy import copy
7-
from typing import Any
7+
from typing import Any, Literal
88

99
import dask
1010
import dask.dataframe as dd
@@ -80,6 +80,24 @@
8080
_Normalize = Normalize | abc.Sequence[Normalize]
8181

8282

83+
def _get_top_data_array(element: xr.DataArray | DataTree) -> xr.DataArray:
84+
if isinstance(element, DataTree):
85+
return next(iter(next(iter(element.values())).data_vars.values()))
86+
return element
87+
88+
89+
def _guard_2d_only(element: xr.DataArray | DataTree, element_name: str, kind: Literal["images", "labels"]) -> None:
90+
top = _get_top_data_array(element)
91+
if "z" in top.dims:
92+
z_size = top.sizes["z"]
93+
raise ValueError(
94+
f"render_{kind} does not support 3D {kind}. Element '{element_name}' has a 'z' dimension "
95+
f"with {z_size} slices. Select a 2D slice before plotting:\n"
96+
f" sdata['{element_name}'].isel(z=0)\n"
97+
"or use sd.bounding_box_query() to extract a 2D region."
98+
)
99+
100+
83101
def _want_decorations(color_vector: Any, na_color: Color) -> bool:
84102
"""Return whether legend/colorbar decorations should be shown.
85103
@@ -1247,6 +1265,7 @@ def _render_images(
12471265

12481266
palette = render_params.palette
12491267
img = sdata_filt[render_params.element]
1268+
_guard_2d_only(img, render_params.element, "images")
12501269
extent = get_extent(img, coordinate_system=coordinate_system)
12511270
scale = render_params.scale
12521271

@@ -1674,6 +1693,7 @@ def _render_labels(
16741693
)
16751694

16761695
label = sdata_filt.labels[element]
1696+
_guard_2d_only(label, element, "labels")
16771697
extent = get_extent(label, coordinate_system=coordinate_system)
16781698

16791699
# get best scale out of multiscale label

tests/pl/test_render_images.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from matplotlib.colors import LogNorm, Normalize
88
from spatial_image import to_spatial_image
99
from spatialdata import SpatialData
10-
from spatialdata.models import Image2DModel
10+
from spatialdata.models import Image2DModel, Image3DModel
1111

1212
import spatialdata_plot # noqa: F401
1313
from spatialdata_plot._logging import logger, logger_warns
@@ -720,6 +720,21 @@ def test_channels_as_legend_coexists_with_other_elements(self, sdata_blobs: Spat
720720
plt.close("all")
721721

722722

723+
@pytest.mark.parametrize("scale_factors", [None, [2]])
724+
def test_render_images_raises_on_3d(scale_factors):
725+
# Regression test for #608: 3D images must raise a clear ValueError, not crash
726+
# deep in matplotlib with "Invalid shape" / opaque numpy errors.
727+
img = np.random.default_rng(0).random((2, 4, 16, 16), dtype=np.float32)
728+
image3d = Image3DModel.parse(img, dims=["c", "z", "y", "x"], c_coords=["DAPI", "GFP"], scale_factors=scale_factors)
729+
sdata = SpatialData(images={"img3d": image3d})
730+
fig, ax = plt.subplots()
731+
try:
732+
with pytest.raises(ValueError, match=r"render_images does not support 3D.*img3d.*z.*4"):
733+
sdata.pl.render_images("img3d").pl.show(ax=ax)
734+
finally:
735+
plt.close(fig)
736+
737+
723738
def test_lognorm_with_zeros_suppresses_colorbar_with_warning():
724739
# regression test for #604: LogNorm + non-positive data must not raise an opaque
725740
# matplotlib ValueError; instead suppress the colorbar with an actionable UserWarning.

tests/pl/test_render_labels.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from matplotlib.colors import Normalize
1010
from spatial_image import to_spatial_image
1111
from spatialdata import SpatialData, deepcopy, get_element_instances
12-
from spatialdata.models import Labels2DModel, TableModel
12+
from spatialdata.models import Labels2DModel, Labels3DModel, TableModel
1313

1414
import spatialdata_plot # noqa: F401
1515
from spatialdata_plot._logging import logger, logger_warns
@@ -550,3 +550,18 @@ def test_render_labels_disjoint_instance_ids_clear_error():
550550
sdata.pl.render_labels("lbl", color="cat", table_name="t").pl.show(ax=ax)
551551
finally:
552552
plt.close(fig)
553+
554+
555+
@pytest.mark.parametrize("scale_factors", [None, [2]])
556+
def test_render_labels_raises_on_3d(scale_factors):
557+
# Regression test for #608: 3D labels must raise a clear ValueError, not crash
558+
# deep in numpy with an opaque concatenation error.
559+
arr = np.random.default_rng(0).integers(0, 5, size=(4, 16, 16), dtype=np.int32)
560+
labels3d = Labels3DModel.parse(arr, dims=["z", "y", "x"], scale_factors=scale_factors)
561+
sdata = SpatialData(labels={"lbl3d": labels3d})
562+
fig, ax = plt.subplots()
563+
try:
564+
with pytest.raises(ValueError, match=r"render_labels does not support 3D.*lbl3d.*z.*4"):
565+
sdata.pl.render_labels("lbl3d").pl.show(ax=ax)
566+
finally:
567+
plt.close(fig)

0 commit comments

Comments
 (0)