22
33import dataclasses
44from collections import abc
5+ from collections .abc import Sequence
56from copy import copy
67from typing import Any
78
1819import spatialdata as sd
1920import xarray as xr
2021from anndata import AnnData
22+ from matplotlib import patheffects
2123from matplotlib .cm import ScalarMappable
2224from matplotlib .colors import ListedColormap , Normalize
2325from scanpy ._settings import settings as sc_settings
26+ from scanpy .plotting ._tools .scatterplots import _add_categorical_legend
2427from spatialdata import get_extent , get_values , join_spatialelement_table
2528from spatialdata ._core .query .relational_query import match_table_to_element
2629from spatialdata .models import PointsModel , ShapesModel , get_table_keys
4144 _render_ds_outlines ,
4245)
4346from spatialdata_plot .pl .render_params import (
47+ ChannelLegendEntry ,
4448 CmapParams ,
4549 Color ,
4650 ColorbarSpec ,
@@ -1094,6 +1098,78 @@ def _is_rgb_image(channel_coords: list[Any]) -> tuple[bool, bool]:
10941098 return False , False
10951099
10961100
1101+ def _collect_channel_legend_entries (
1102+ channels : Sequence [str | int ],
1103+ seed_colors : Sequence [str | tuple [float , ...]],
1104+ channel_legend_entries : list [ChannelLegendEntry ],
1105+ ) -> None :
1106+ """Accumulate channel-to-color mappings for a deferred combined legend."""
1107+ channel_names = [str (ch ) for ch in channels ]
1108+ if len (set (channel_names )) != len (channel_names ):
1109+ logger .warning ("channels_as_categories: duplicate channel names detected; skipping legend entries." )
1110+ return
1111+
1112+ color_hexes = [matplotlib .colors .to_hex (c , keep_alpha = False ) for c in seed_colors ]
1113+ for name , color in zip (channel_names , color_hexes , strict = True ):
1114+ channel_legend_entries .append (ChannelLegendEntry (channel_name = name , color_hex = color ))
1115+
1116+
1117+ def _draw_channel_legend (
1118+ ax : matplotlib .axes .SubplotBase ,
1119+ entries : list [ChannelLegendEntry ],
1120+ legend_params : LegendParams ,
1121+ fig_params : FigParams ,
1122+ ) -> None :
1123+ """Draw a single combined categorical legend from accumulated channel entries.
1124+
1125+ Because ``_add_categorical_legend`` adds invisible labeled scatter artists,
1126+ calling it here automatically merges with any earlier legend entries
1127+ (e.g. from labels or shapes) on the same axes via ``ax.legend()``.
1128+
1129+ ``multi_panel`` is only set when no prior legend exists on the axis,
1130+ to avoid shrinking the axes twice (once for labels/shapes, once for
1131+ channels).
1132+ """
1133+ # Deduplicate: if the same channel name appears twice, keep the last color
1134+ palette_dict : dict [str , str ] = {}
1135+ for entry in entries :
1136+ palette_dict [entry .channel_name ] = entry .color_hex
1137+
1138+ legend_loc = legend_params .legend_loc
1139+ if legend_loc == "on data" :
1140+ logger .warning (
1141+ "legend_loc='on data' is not supported for channel legends (no scatter coordinates); "
1142+ "falling back to 'right margin'."
1143+ )
1144+ legend_loc = "right margin"
1145+
1146+ categories = pd .Categorical (list (palette_dict ))
1147+
1148+ path_effect = (
1149+ [patheffects .withStroke (linewidth = legend_params .legend_fontoutline , foreground = "w" )]
1150+ if legend_params .legend_fontoutline is not None
1151+ else []
1152+ )
1153+
1154+ # Only apply multi_panel shrink if no legend already exists on this axis
1155+ # (labels/shapes draw their legend during the render loop and already shrink).
1156+ has_existing_legend = ax .get_legend () is not None
1157+ needs_multi_panel = fig_params .axs is not None and not has_existing_legend
1158+
1159+ _add_categorical_legend (
1160+ ax ,
1161+ categories ,
1162+ palette = palette_dict ,
1163+ legend_loc = legend_loc ,
1164+ legend_fontweight = legend_params .legend_fontweight ,
1165+ legend_fontsize = legend_params .legend_fontsize ,
1166+ legend_fontoutline = path_effect ,
1167+ na_color = ["lightgray" ],
1168+ na_in_legend = False ,
1169+ multi_panel = needs_multi_panel ,
1170+ )
1171+
1172+
10971173def _render_images (
10981174 sdata : sd .SpatialData ,
10991175 render_params : ImageRenderParams ,
@@ -1104,6 +1180,7 @@ def _render_images(
11041180 legend_params : LegendParams ,
11051181 rasterize : bool ,
11061182 colorbar_requests : list [ColorbarSpec ] | None = None ,
1183+ channel_legend_entries : list [ChannelLegendEntry ] | None = None ,
11071184) -> None :
11081185 _log_context .set ("render_images" )
11091186 sdata_filt = sdata .filter_by_coordinate_system (
@@ -1319,10 +1396,14 @@ def _render_images(
13191396
13201397 layers [ch ] = ch_norm (layers [ch ])
13211398
1399+ # Colors for the channel legend (set by each branch if applicable)
1400+ legend_colors : list [str ] | None = None
1401+
13221402 # 2A) Image has 3 channels, no palette info, and no/only one cmap was given
13231403 if palette is None and n_channels == 3 and not isinstance (render_params .cmap_params , list ):
13241404 if render_params .cmap_params .cmap_is_default : # -> use RGB
13251405 stacked = np .clip (np .stack ([layers [ch ] for ch in layers ], axis = - 1 ), 0 , 1 )
1406+ legend_colors = ["red" , "green" , "blue" ]
13261407 else : # -> use given cmap for each channel
13271408 channel_cmaps = [render_params .cmap_params .cmap ] * n_channels
13281409 stacked = (
@@ -1404,6 +1485,8 @@ def _render_images(
14041485 f"multichannel strategy 'stack' to render."
14051486 ) # TODO: update when pca is added as strategy
14061487
1488+ legend_colors = seed_colors
1489+
14071490 _ax_show_and_transform (
14081491 colored ,
14091492 trans_data ,
@@ -1421,6 +1504,8 @@ def _render_images(
14211504 colored = np .stack ([channel_cmaps [i ](layers [c ]) for i , c in enumerate (channels )], 0 ).sum (0 )
14221505 colored = np .clip (colored [:, :, :3 ], 0 , 1 )
14231506
1507+ legend_colors = list (palette )
1508+
14241509 _ax_show_and_transform (
14251510 colored ,
14261511 trans_data ,
@@ -1440,6 +1525,8 @@ def _render_images(
14401525 )
14411526 colored = colored [:, :, :3 ]
14421527
1528+ legend_colors = [matplotlib .colors .to_hex (cm (0.75 )) for cm in channel_cmaps ]
1529+
14431530 _ax_show_and_transform (
14441531 colored ,
14451532 trans_data ,
@@ -1452,6 +1539,17 @@ def _render_images(
14521539 elif palette is not None and got_multiple_cmaps :
14531540 raise ValueError ("If 'palette' is provided, 'cmap' must be None." )
14541541
1542+ # Collect channel legend entries (single point for all multi-channel paths)
1543+ if render_params .channels_as_categories and channel_legend_entries is not None :
1544+ if legend_colors is not None :
1545+ _collect_channel_legend_entries (channels , legend_colors , channel_legend_entries )
1546+ else :
1547+ logger .warning (
1548+ "channels_as_categories requires distinct per-channel colors; "
1549+ "ignored when a single cmap is shared across channels. "
1550+ "Use 'palette' or a list of cmaps instead."
1551+ )
1552+
14551553
14561554def _render_labels (
14571555 sdata : sd .SpatialData ,
0 commit comments