Skip to content

Commit c43161d

Browse files
timtreisclaude
andcommitted
Add channels_as_categories parameter to render_images (#459)
Add a `channels_as_categories: bool = False` parameter to `render_images()` that shows a categorical legend mapping each channel name to its compositing color for multi-channel images. Legend entries are accumulated across chained `render_images()` calls and drawn as a single combined legend after the render loop, following the same deferred pattern as colorbars. When labels or shapes also contribute legends on the same axes, the channel entries merge automatically via matplotlib's artist-based legend collection. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 303140c commit c43161d

9 files changed

Lines changed: 209 additions & 0 deletions

src/spatialdata_plot/pl/basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from spatialdata_plot._accessor import register_spatial_data_accessor
3131
from spatialdata_plot._logging import _log_context, logger
3232
from spatialdata_plot.pl.render import (
33+
_draw_channel_legend,
3334
_render_images,
3435
_render_labels,
3536
_render_points,
@@ -40,6 +41,7 @@
4041
CBAR_DEFAULT_FRACTION,
4142
CBAR_DEFAULT_LOCATION,
4243
CBAR_DEFAULT_PAD,
44+
ChannelLegendEntry,
4345
CmapParams,
4446
ColorbarSpec,
4547
ImageRenderParams,
@@ -523,6 +525,7 @@ def render_images(
523525
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None,
524526
colorbar: bool | str | None = "auto",
525527
colorbar_params: dict[str, object] | None = None,
528+
channels_as_categories: bool = False,
526529
**kwargs: Any,
527530
) -> sd.SpatialData:
528531
"""
@@ -600,6 +603,13 @@ def render_images(
600603
colorbar_params :
601604
Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
602605
and ``label``.
606+
channels_as_categories : bool, default False
607+
When ``True`` and rendering multiple channels, show a categorical
608+
legend mapping each channel name to its compositing color. The
609+
legend uses the ``legend_*`` parameters from :meth:`show`.
610+
Ignored for single-channel and RGB(A) images. When multiple
611+
``render_images`` calls use this flag on the same axes, all
612+
channel entries are combined into a single legend.
603613
kwargs
604614
Additional arguments to be passed to cmap, norm, and other rendering functions.
605615
@@ -681,6 +691,7 @@ def render_images(
681691
colorbar_params=param_values["colorbar_params"],
682692
transfunc=transfunc,
683693
grayscale=grayscale,
694+
channels_as_categories=channels_as_categories,
684695
)
685696
n_steps += 1
686697

@@ -1140,6 +1151,7 @@ def _draw_colorbar(
11401151
ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i]
11411152
assert isinstance(ax, Axes)
11421153
axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None
1154+
axis_channel_legend_entries: list[ChannelLegendEntry] = []
11431155

11441156
wants_images = False
11451157
wants_labels = False
@@ -1170,6 +1182,7 @@ def _draw_colorbar(
11701182
scalebar_params=scalebar_params,
11711183
legend_params=legend_params,
11721184
colorbar_requests=axis_colorbar_requests,
1185+
channel_legend_entries=axis_channel_legend_entries,
11731186
rasterize=rasterize,
11741187
)
11751188

@@ -1279,6 +1292,9 @@ def _draw_colorbar(
12791292
if legend_params.colorbar and axis_colorbar_requests:
12801293
pending_colorbars.append((ax, axis_colorbar_requests))
12811294

1295+
if axis_channel_legend_entries:
1296+
_draw_channel_legend(ax, axis_channel_legend_entries, legend_params, fig_params)
1297+
12821298
if pending_colorbars and fig_params.fig is not None:
12831299
fig = fig_params.fig
12841300
fig.canvas.draw()

src/spatialdata_plot/pl/render.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import dataclasses
44
from collections import abc
5+
from collections.abc import Sequence
56
from copy import copy
67
from typing import Any
78

@@ -18,9 +19,11 @@
1819
import spatialdata as sd
1920
import xarray as xr
2021
from anndata import AnnData
22+
from matplotlib import patheffects
2123
from matplotlib.cm import ScalarMappable
2224
from matplotlib.colors import ListedColormap, Normalize
2325
from scanpy._settings import settings as sc_settings
26+
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
2427
from spatialdata import get_extent, get_values, join_spatialelement_table
2528
from spatialdata._core.query.relational_query import match_table_to_element
2629
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
@@ -41,6 +44,7 @@
4144
_render_ds_outlines,
4245
)
4346
from spatialdata_plot.pl.render_params import (
47+
ChannelLegendEntry,
4448
CmapParams,
4549
Color,
4650
ColorbarSpec,
@@ -1094,6 +1098,78 @@ def _is_rgb_image(channel_coords: list[Any]) -> tuple[bool, bool]:
10941098
return False, False
10951099

10961100

1101+
def _collect_channel_legend_entries(
1102+
channels: Sequence[str | int],
1103+
seed_colors: Sequence[str | tuple[float, ...]],
1104+
channel_legend_entries: list[ChannelLegendEntry],
1105+
) -> None:
1106+
"""Accumulate channel-to-color mappings for a deferred combined legend."""
1107+
channel_names = [str(ch) for ch in channels]
1108+
if len(set(channel_names)) != len(channel_names):
1109+
logger.warning("channels_as_categories: duplicate channel names detected; skipping legend entries.")
1110+
return
1111+
1112+
color_hexes = [matplotlib.colors.to_hex(c, keep_alpha=False) for c in seed_colors]
1113+
for name, color in zip(channel_names, color_hexes, strict=True):
1114+
channel_legend_entries.append(ChannelLegendEntry(channel_name=name, color_hex=color))
1115+
1116+
1117+
def _draw_channel_legend(
1118+
ax: matplotlib.axes.SubplotBase,
1119+
entries: list[ChannelLegendEntry],
1120+
legend_params: LegendParams,
1121+
fig_params: FigParams,
1122+
) -> None:
1123+
"""Draw a single combined categorical legend from accumulated channel entries.
1124+
1125+
Because ``_add_categorical_legend`` adds invisible labeled scatter artists,
1126+
calling it here automatically merges with any earlier legend entries
1127+
(e.g. from labels or shapes) on the same axes via ``ax.legend()``.
1128+
1129+
``multi_panel`` is only set when no prior legend exists on the axis,
1130+
to avoid shrinking the axes twice (once for labels/shapes, once for
1131+
channels).
1132+
"""
1133+
# Deduplicate: if the same channel name appears twice, keep the last color
1134+
palette_dict: dict[str, str] = {}
1135+
for entry in entries:
1136+
palette_dict[entry.channel_name] = entry.color_hex
1137+
1138+
legend_loc = legend_params.legend_loc
1139+
if legend_loc == "on data":
1140+
logger.warning(
1141+
"legend_loc='on data' is not supported for channel legends (no scatter coordinates); "
1142+
"falling back to 'right margin'."
1143+
)
1144+
legend_loc = "right margin"
1145+
1146+
categories = pd.Categorical(list(palette_dict))
1147+
1148+
path_effect = (
1149+
[patheffects.withStroke(linewidth=legend_params.legend_fontoutline, foreground="w")]
1150+
if legend_params.legend_fontoutline is not None
1151+
else []
1152+
)
1153+
1154+
# Only apply multi_panel shrink if no legend already exists on this axis
1155+
# (labels/shapes draw their legend during the render loop and already shrink).
1156+
has_existing_legend = ax.get_legend() is not None
1157+
needs_multi_panel = fig_params.axs is not None and not has_existing_legend
1158+
1159+
_add_categorical_legend(
1160+
ax,
1161+
categories,
1162+
palette=palette_dict,
1163+
legend_loc=legend_loc,
1164+
legend_fontweight=legend_params.legend_fontweight,
1165+
legend_fontsize=legend_params.legend_fontsize,
1166+
legend_fontoutline=path_effect,
1167+
na_color=["lightgray"],
1168+
na_in_legend=False,
1169+
multi_panel=needs_multi_panel,
1170+
)
1171+
1172+
10971173
def _render_images(
10981174
sdata: sd.SpatialData,
10991175
render_params: ImageRenderParams,
@@ -1104,6 +1180,7 @@ def _render_images(
11041180
legend_params: LegendParams,
11051181
rasterize: bool,
11061182
colorbar_requests: list[ColorbarSpec] | None = None,
1183+
channel_legend_entries: list[ChannelLegendEntry] | None = None,
11071184
) -> None:
11081185
_log_context.set("render_images")
11091186
sdata_filt = sdata.filter_by_coordinate_system(
@@ -1319,10 +1396,14 @@ def _render_images(
13191396

13201397
layers[ch] = ch_norm(layers[ch])
13211398

1399+
# Colors for the channel legend (set by each branch if applicable)
1400+
legend_colors: list[str] | None = None
1401+
13221402
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
13231403
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
13241404
if render_params.cmap_params.cmap_is_default: # -> use RGB
13251405
stacked = np.clip(np.stack([layers[ch] for ch in layers], axis=-1), 0, 1)
1406+
legend_colors = ["red", "green", "blue"]
13261407
else: # -> use given cmap for each channel
13271408
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
13281409
stacked = (
@@ -1404,6 +1485,8 @@ def _render_images(
14041485
f"multichannel strategy 'stack' to render."
14051486
) # TODO: update when pca is added as strategy
14061487

1488+
legend_colors = seed_colors
1489+
14071490
_ax_show_and_transform(
14081491
colored,
14091492
trans_data,
@@ -1421,6 +1504,8 @@ def _render_images(
14211504
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
14221505
colored = np.clip(colored[:, :, :3], 0, 1)
14231506

1507+
legend_colors = list(palette)
1508+
14241509
_ax_show_and_transform(
14251510
colored,
14261511
trans_data,
@@ -1440,6 +1525,8 @@ def _render_images(
14401525
)
14411526
colored = colored[:, :, :3]
14421527

1528+
legend_colors = [matplotlib.colors.to_hex(cm(0.75)) for cm in channel_cmaps]
1529+
14431530
_ax_show_and_transform(
14441531
colored,
14451532
trans_data,
@@ -1452,6 +1539,17 @@ def _render_images(
14521539
elif palette is not None and got_multiple_cmaps:
14531540
raise ValueError("If 'palette' is provided, 'cmap' must be None.")
14541541

1542+
# Collect channel legend entries (single point for all multi-channel paths)
1543+
if render_params.channels_as_categories and channel_legend_entries is not None:
1544+
if legend_colors is not None:
1545+
_collect_channel_legend_entries(channels, legend_colors, channel_legend_entries)
1546+
else:
1547+
logger.warning(
1548+
"channels_as_categories requires distinct per-channel colors; "
1549+
"ignored when a single cmap is shared across channels. "
1550+
"Use 'palette' or a list of cmaps instead."
1551+
)
1552+
14551553

14561554
def _render_labels(
14571555
sdata: sd.SpatialData,

src/spatialdata_plot/pl/render_params.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,14 @@ class ColorbarSpec:
199199
alpha: float | None = None
200200

201201

202+
@dataclass
203+
class ChannelLegendEntry:
204+
"""A single channel-to-color mapping for the categorical channel legend."""
205+
206+
channel_name: str
207+
color_hex: str
208+
209+
202210
CBAR_DEFAULT_LOCATION = "right"
203211
CBAR_DEFAULT_FRACTION = 0.075
204212
CBAR_DEFAULT_PAD = 0.015
@@ -275,6 +283,7 @@ class ImageRenderParams:
275283
colorbar_params: dict[str, object] | None = None
276284
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None
277285
grayscale: bool = False
286+
channels_as_categories: bool = False
278287

279288

280289
@dataclass
109 KB
Loading
95.9 KB
Loading
70.3 KB
Loading
81.1 KB
Loading
91.1 KB
Loading

tests/pl/test_render_images.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,89 @@ def test_cmap_matches_selected_channels_not_full_image(sdata_blobs: SpatialData)
491491
sdata_blobs.pl.render_images("blobs_image", channel=[0], cmap=["gray"]).pl.show(ax=ax)
492492
assert len(ax.get_images()) == 1
493493
plt.close(fig)
494+
495+
496+
# ---------------------------------------------------------------------------
497+
# channels_as_categories visual tests (#459)
498+
# ---------------------------------------------------------------------------
499+
500+
501+
class TestChannelsAsCategories(PlotTester, metaclass=PlotTesterMeta):
502+
def test_plot_channels_as_categories_two_channels(self, sdata_blobs: SpatialData):
503+
sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=True).pl.show()
504+
505+
def test_plot_channels_as_categories_three_channels_default(self, sdata_blobs: SpatialData):
506+
sdata_blobs.pl.render_images(element="blobs_image", channels_as_categories=True).pl.show()
507+
508+
def test_plot_channels_as_categories_with_palette(self, sdata_blobs_str: SpatialData):
509+
sdata_blobs_str.pl.render_images(
510+
element="blobs_image",
511+
channel=["c1", "c2", "c3"],
512+
palette=["red", "green", "blue"],
513+
channels_as_categories=True,
514+
).pl.show()
515+
516+
def test_plot_channels_as_categories_many_channels(self, sdata_blobs_str: SpatialData):
517+
sdata_blobs_str.pl.render_images(element="blobs_image", channels_as_categories=True).pl.show()
518+
519+
def test_plot_channels_as_categories_with_cmap_list(self, sdata_blobs: SpatialData):
520+
sdata_blobs.pl.render_images(
521+
element="blobs_image",
522+
channel=[0, 1, 2],
523+
cmap=["Reds", "Greens", "Blues"],
524+
channels_as_categories=True,
525+
).pl.show()
526+
527+
528+
class TestChannelsAsCategoriesNonVisual:
529+
"""Non-visual tests for channels_as_categories edge cases."""
530+
531+
def test_channels_as_categories_ignored_for_single_channel(self, sdata_blobs: SpatialData):
532+
fig, ax = plt.subplots()
533+
sdata_blobs.pl.render_images(element="blobs_image", channel=0, channels_as_categories=True).pl.show(ax=ax)
534+
assert ax.get_legend() is None
535+
plt.close("all")
536+
537+
def test_channels_as_categories_false_no_legend(self, sdata_blobs: SpatialData):
538+
fig, ax = plt.subplots()
539+
sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=False).pl.show(ax=ax)
540+
assert ax.get_legend() is None
541+
plt.close("all")
542+
543+
def test_channels_as_categories_chained_renders_combine(self, sdata_blobs: SpatialData):
544+
"""Multiple render_images with channels_as_categories should produce one combined legend."""
545+
fig, ax = plt.subplots()
546+
(
547+
sdata_blobs.pl.render_images(
548+
element="blobs_image", channel=[0, 1], palette=["red", "green"], channels_as_categories=True
549+
)
550+
.pl.render_images(
551+
element="blobs_image", channel=[1, 2], palette=["cyan", "blue"], channels_as_categories=True
552+
)
553+
.pl.show(ax=ax)
554+
)
555+
legend = ax.get_legend()
556+
assert legend is not None
557+
labels = [t.get_text() for t in legend.get_texts()]
558+
# Both render calls contribute: channels 0, 1, 2.
559+
# Channel "1" appears in both calls — dedup keeps the last color.
560+
assert "0" in labels
561+
assert "1" in labels
562+
assert "2" in labels
563+
assert len(labels) == 3
564+
plt.close("all")
565+
566+
def test_channels_as_categories_coexists_with_other_elements(self, sdata_blobs: SpatialData):
567+
"""Channel legend should not crash when combined with other render calls."""
568+
fig, ax = plt.subplots()
569+
(
570+
sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=True)
571+
.pl.render_labels(element="blobs_labels")
572+
.pl.show(ax=ax)
573+
)
574+
legend = ax.get_legend()
575+
assert legend is not None
576+
labels = [t.get_text() for t in legend.get_texts()]
577+
assert "0" in labels
578+
assert "1" in labels
579+
plt.close("all")

0 commit comments

Comments
 (0)