Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import abc
from collections.abc import Sequence
from copy import copy
from typing import Any
from typing import Any, Literal

import dask
import dask.dataframe as dd
Expand Down Expand Up @@ -80,6 +80,24 @@
_Normalize = Normalize | abc.Sequence[Normalize]


def _get_top_data_array(element: xr.DataArray | DataTree) -> xr.DataArray:
if isinstance(element, DataTree):
return next(iter(next(iter(element.values())).data_vars.values()))
return element


def _guard_2d_only(element: xr.DataArray | DataTree, element_name: str, kind: Literal["images", "labels"]) -> None:
top = _get_top_data_array(element)
if "z" in top.dims:
z_size = top.sizes["z"]
raise ValueError(
f"render_{kind} does not support 3D {kind}. Element '{element_name}' has a 'z' dimension "
f"with {z_size} slices. Select a 2D slice before plotting:\n"
f" sdata['{element_name}'].isel(z=0)\n"
"or use sd.bounding_box_query() to extract a 2D region."
)


def _want_decorations(color_vector: Any, na_color: Color) -> bool:
"""Return whether legend/colorbar decorations should be shown.

Expand Down Expand Up @@ -1247,6 +1265,7 @@ def _render_images(

palette = render_params.palette
img = sdata_filt[render_params.element]
_guard_2d_only(img, render_params.element, "images")
extent = get_extent(img, coordinate_system=coordinate_system)
scale = render_params.scale

Expand Down Expand Up @@ -1674,6 +1693,7 @@ def _render_labels(
)

label = sdata_filt.labels[element]
_guard_2d_only(label, element, "labels")
extent = get_extent(label, coordinate_system=coordinate_system)

# get best scale out of multiscale label
Expand Down
17 changes: 16 additions & 1 deletion tests/pl/test_render_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from matplotlib.colors import LogNorm, Normalize
from spatial_image import to_spatial_image
from spatialdata import SpatialData
from spatialdata.models import Image2DModel
from spatialdata.models import Image2DModel, Image3DModel

import spatialdata_plot # noqa: F401
from spatialdata_plot._logging import logger, logger_warns
Expand Down Expand Up @@ -720,6 +720,21 @@ def test_channels_as_legend_coexists_with_other_elements(self, sdata_blobs: Spat
plt.close("all")


@pytest.mark.parametrize("scale_factors", [None, [2]])
def test_render_images_raises_on_3d(scale_factors):
# Regression test for #608: 3D images must raise a clear ValueError, not crash
# deep in matplotlib with "Invalid shape" / opaque numpy errors.
img = np.random.default_rng(0).random((2, 4, 16, 16), dtype=np.float32)
image3d = Image3DModel.parse(img, dims=["c", "z", "y", "x"], c_coords=["DAPI", "GFP"], scale_factors=scale_factors)
sdata = SpatialData(images={"img3d": image3d})
fig, ax = plt.subplots()
try:
with pytest.raises(ValueError, match=r"render_images does not support 3D.*img3d.*z.*4"):
sdata.pl.render_images("img3d").pl.show(ax=ax)
finally:
plt.close(fig)


def test_lognorm_with_zeros_suppresses_colorbar_with_warning():
# regression test for #604: LogNorm + non-positive data must not raise an opaque
# matplotlib ValueError; instead suppress the colorbar with an actionable UserWarning.
Expand Down
17 changes: 16 additions & 1 deletion tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from matplotlib.colors import Normalize
from spatial_image import to_spatial_image
from spatialdata import SpatialData, deepcopy, get_element_instances
from spatialdata.models import Labels2DModel, TableModel
from spatialdata.models import Labels2DModel, Labels3DModel, TableModel

import spatialdata_plot # noqa: F401
from spatialdata_plot._logging import logger, logger_warns
Expand Down Expand Up @@ -550,3 +550,18 @@ def test_render_labels_disjoint_instance_ids_clear_error():
sdata.pl.render_labels("lbl", color="cat", table_name="t").pl.show(ax=ax)
finally:
plt.close(fig)


@pytest.mark.parametrize("scale_factors", [None, [2]])
def test_render_labels_raises_on_3d(scale_factors):
# Regression test for #608: 3D labels must raise a clear ValueError, not crash
# deep in numpy with an opaque concatenation error.
arr = np.random.default_rng(0).integers(0, 5, size=(4, 16, 16), dtype=np.int32)
labels3d = Labels3DModel.parse(arr, dims=["z", "y", "x"], scale_factors=scale_factors)
sdata = SpatialData(labels={"lbl3d": labels3d})
fig, ax = plt.subplots()
try:
with pytest.raises(ValueError, match=r"render_labels does not support 3D.*lbl3d.*z.*4"):
sdata.pl.render_labels("lbl3d").pl.show(ax=ax)
finally:
plt.close(fig)
Loading