Skip to content

Commit c764384

Browse files
timtreisclaude
andcommitted
Add per-channel norm support for render_images (#460)
Accept a list of Normalize objects in render_images so each channel can be normalized independently — essential for multi-channel protein data with vastly different intensity ranges. The rendering pipeline already reads per-channel norms from CmapParams, so this change only widens the input validation and routes per-channel norms into the existing CmapParams creation loop. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 367204d commit c764384

3 files changed

Lines changed: 80 additions & 7 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ def render_images(
513513
*,
514514
channel: list[str] | list[int] | str | int | None = None,
515515
cmap: list[Colormap | str] | Colormap | str | None = None,
516-
norm: Normalize | None = None,
516+
norm: list[Normalize] | Normalize | None = None,
517517
na_color: ColorLike | None = "default",
518518
palette: list[str] | str | None = None,
519519
alpha: float | int = 1.0,
@@ -544,9 +544,11 @@ def render_images(
544544
cmap : list[Colormap | str] | Colormap | str | None
545545
Colormap or list of colormaps for continuous annotations, see :class:`matplotlib.colors.Colormap`.
546546
Each colormap applies to a corresponding channel.
547-
norm : Normalize | None, optional
547+
norm : list[Normalize] | Normalize | None, optional
548548
Colormap normalization for continuous annotations, see :class:`matplotlib.colors.Normalize`.
549-
Applies to all channels if set.
549+
A single :class:`~matplotlib.colors.Normalize` applies to all channels.
550+
A list of :class:`~matplotlib.colors.Normalize` objects applies per-channel
551+
(length must match the number of colormaps/channels).
550552
na_color : ColorLike | None, default "default" (gets set to "lightgray")
551553
Color to be used for NAs values, if present. Can either be a named color ("red"), a hex representation
552554
("#000000ff") or a list of floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). When None, the values
@@ -631,13 +633,21 @@ def render_images(
631633
for element, param_values in params_dict.items():
632634
cmap_params: list[CmapParams] | CmapParams
633635
if isinstance(cmap, list):
636+
if isinstance(norm, list):
637+
if len(norm) != len(cmap):
638+
raise ValueError(
639+
f"Length of 'norm' list ({len(norm)}) must match the number of colormaps ({len(cmap)})."
640+
)
641+
norms = norm
642+
else:
643+
norms = [norm] * len(cmap)
634644
cmap_params = [
635645
_prepare_cmap_norm(
636646
cmap=c,
637-
norm=norm,
647+
norm=n,
638648
na_color=param_values["na_color"],
639649
)
640-
for c in cmap
650+
for c, n in zip(cmap, norms, strict=True)
641651
]
642652

643653
else:

src/spatialdata_plot/pl/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2409,8 +2409,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
24092409

24102410
norm = param_dict.get("norm")
24112411
if norm is not None:
2412-
if element_type in {"images", "labels"} and not isinstance(norm, Normalize):
2413-
raise TypeError("Parameter 'norm' must be of type Normalize.")
2412+
if element_type in {"images", "labels"}:
2413+
if isinstance(norm, list):
2414+
if not norm:
2415+
raise ValueError("Parameter 'norm' list must not be empty.")
2416+
if not all(isinstance(n, Normalize) for n in norm):
2417+
raise TypeError("Every item in 'norm' list must be a Normalize instance.")
2418+
elif not isinstance(norm, Normalize):
2419+
raise TypeError("Parameter 'norm' must be a Normalize or a list of Normalize instances.")
24142420
if element_type in {"shapes", "points"} and not isinstance(norm, bool | Normalize):
24152421
raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.")
24162422

tests/pl/test_render_images.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,60 @@ def test_no_clipping_warning_palette_compositing(self):
416416
plt.close("all")
417417
clip_warns = [x for x in w if "Clipping input data" in str(x.message)]
418418
assert len(clip_warns) == 0, f"Got unexpected clipping warning: {clip_warns[0].message}"
419+
420+
421+
def _make_multichannel_sdata():
422+
"""Create a 3-channel image with different intensity ranges."""
423+
rng = np.random.default_rng(42)
424+
data = np.stack(
425+
[
426+
rng.uniform(0, 0.05, (50, 50)), # dim
427+
rng.uniform(0, 1.0, (50, 50)), # full range
428+
rng.uniform(0, 0.5, (50, 50)), # medium
429+
],
430+
axis=0,
431+
).astype(np.float32)
432+
img = Image2DModel.parse(data, dims=("c", "y", "x"), c_coords=[0, 1, 2])
433+
return SpatialData(images={"img": img})
434+
435+
436+
def test_per_channel_norm_list():
437+
"""Per-channel norm list is accepted and renders without error (#460)."""
438+
sdata = _make_multichannel_sdata()
439+
norms = [
440+
Normalize(vmin=0, vmax=0.05, clip=True),
441+
Normalize(vmin=0, vmax=1.0, clip=True),
442+
Normalize(vmin=0, vmax=0.5, clip=True),
443+
]
444+
fig, ax = plt.subplots()
445+
sdata.pl.render_images("img", channel=[0, 1, 2], norm=norms, cmap=[plt.cm.gray] * 3).pl.show(ax=ax)
446+
plt.close(fig)
447+
448+
449+
def test_single_norm_with_multiple_channels():
450+
"""A single Normalize shared across channels still works."""
451+
sdata = _make_multichannel_sdata()
452+
fig, ax = plt.subplots()
453+
sdata.pl.render_images("img", channel=[0, 1, 2], norm=Normalize(0, 1), cmap=[plt.cm.gray] * 3).pl.show(ax=ax)
454+
plt.close(fig)
455+
456+
457+
def test_norm_list_length_mismatch_raises():
458+
"""Norm list length must match cmap list length."""
459+
sdata = _make_multichannel_sdata()
460+
with pytest.raises(ValueError, match="must match"):
461+
sdata.pl.render_images("img", channel=[0, 1, 2], norm=[Normalize(0, 1)] * 2, cmap=[plt.cm.gray] * 3).pl.show()
462+
463+
464+
def test_norm_list_empty_raises():
465+
"""Empty norm list is rejected."""
466+
sdata = _make_multichannel_sdata()
467+
with pytest.raises(ValueError, match="must not be empty"):
468+
sdata.pl.render_images("img", norm=[]).pl.show()
469+
470+
471+
def test_norm_list_with_invalid_element_raises():
472+
"""Non-Normalize items in norm list are rejected."""
473+
sdata = _make_multichannel_sdata()
474+
with pytest.raises(TypeError, match="Normalize instance"):
475+
sdata.pl.render_images("img", norm=["not_a_norm"]).pl.show()

0 commit comments

Comments
 (0)