Skip to content

Commit 61ce419

Browse files
committed
Add method='datashader' to render_images for sparse images (#449)
Mean-aggregating rasterization + imshow interpolation collapses very sparse images (mostly zeros, rare non-zero pixels) to near-black. Adds method='datashader' + ds_reduction kwargs mirroring the existing render_points/render_shapes API; routes the downsample step through datashader.Canvas.raster with a configurable reduction (default 'max') and forces nearest-neighbor display so the reduction is not re-smoothed. Also centralizes the _DsReduction Literal (previously duplicated across five sites) into render_params.py alongside a new _ImageDsReduction for the image-only set ('mode', 'first', 'last' added; 'sum', 'any', 'count' dropped since they're not valid Canvas.raster downsample methods).
1 parent 3ebefe1 commit 61ce419

6 files changed

Lines changed: 242 additions & 14 deletions

File tree

src/spatialdata_plot/pl/_datashader.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from matplotlib.colors import Normalize
1818

1919
from spatialdata_plot._logging import logger
20-
from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams
20+
from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams, _DsReduction
2121
from spatialdata_plot.pl.utils import (
2222
_ax_show_and_transform,
2323
_convert_alpha_to_datashader_range,
@@ -32,8 +32,6 @@
3232
# Type aliases and constants
3333
# ---------------------------------------------------------------------------
3434

35-
_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"]
36-
3735
# Sentinel category name used in datashader categorical paths to represent
3836
# missing (NaN) values. Must not collide with realistic user category names.
3937
_DS_NAN_CATEGORY = "ds_nan"

src/spatialdata_plot/pl/basic.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections.abc import Callable, Sequence
88
from copy import deepcopy
99
from pathlib import Path
10-
from typing import Any, Literal, cast
10+
from typing import Any, Literal, cast, get_args
1111

1212
import matplotlib
1313
import matplotlib.pyplot as plt
@@ -29,7 +29,7 @@
2929
from xarray import DataArray, DataTree
3030

3131
from spatialdata_plot._accessor import register_spatial_data_accessor
32-
from spatialdata_plot._logging import _log_context
32+
from spatialdata_plot._logging import _log_context, logger
3333
from spatialdata_plot.pl.render import (
3434
_draw_channel_legend,
3535
_render_graph,
@@ -52,8 +52,10 @@
5252
LegendParams,
5353
PointsRenderParams,
5454
ShapesRenderParams,
55+
_DsReduction,
5556
_FontSize,
5657
_FontWeight,
58+
_ImageDsReduction,
5759
)
5860
from spatialdata_plot.pl.utils import (
5961
_RENDER_CMD_TO_CS_FLAG,
@@ -194,7 +196,7 @@ def render_shapes(
194196
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None,
195197
colorbar: bool | str | None = "auto",
196198
colorbar_params: dict[str, object] | None = None,
197-
datashader_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None,
199+
datashader_reduction: _DsReduction | None = None,
198200
transfunc: Callable[[float], float] | None = None,
199201
) -> sd.SpatialData:
200202
"""
@@ -384,7 +386,7 @@ def render_points(
384386
gene_symbols: str | None = None,
385387
colorbar: bool | str | None = "auto",
386388
colorbar_params: dict[str, object] | None = None,
387-
datashader_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None,
389+
datashader_reduction: _DsReduction | None = None,
388390
transfunc: Callable[[float], float] | None = None,
389391
) -> sd.SpatialData:
390392
"""
@@ -536,6 +538,8 @@ def render_images(
536538
colorbar: bool | str | None = "auto",
537539
colorbar_params: dict[str, object] | None = None,
538540
channels_as_legend: bool = False,
541+
method: Literal["matplotlib", "datashader"] | None = None,
542+
ds_reduction: _ImageDsReduction | None = None,
539543
) -> sd.SpatialData:
540544
"""
541545
Render image elements in SpatialData.
@@ -616,6 +620,21 @@ def render_images(
616620
Ignored for single-channel and RGB(A) images. When multiple
617621
``render_images`` calls use this flag on the same axes, all
618622
channel entries are combined into a single legend.
623+
method : str | None, optional
624+
Whether to use ``'matplotlib'`` (default) or ``'datashader'`` for
625+
the downsampling step. When ``'datashader'`` is selected, the
626+
rasterization-to-canvas step uses
627+
:meth:`datashader.Canvas.raster` with ``ds_reduction`` as the
628+
downsample method (default ``'max'``), and ``imshow`` is rendered
629+
with ``interpolation='nearest'`` so the chosen reduction is not
630+
re-smoothed at display time. Useful for very sparse images
631+
(mostly zeros) where mean aggregation collapses the signal —
632+
``method='datashader'`` with ``ds_reduction='max'`` preserves the
633+
rare non-zero pixels (``plt.spy``-style).
634+
ds_reduction : {"max", "min", "mean", "mode", "first", "last", "var", "std"} | None, optional
635+
Downsample reduction used by the datashader path. Defaults to
636+
``'max'`` when ``method='datashader'``. Ignored otherwise (a
637+
warning is emitted if set without ``method='datashader'``).
619638
620639
Notes
621640
-----
@@ -634,6 +653,20 @@ def render_images(
634653
"""
635654
if grayscale and palette is not None:
636655
raise ValueError("Cannot combine grayscale=True with palette.")
656+
657+
if method is not None and not isinstance(method, str):
658+
raise TypeError("Parameter 'method' must be a string.")
659+
if method is not None and method not in ("matplotlib", "datashader"):
660+
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")
661+
if ds_reduction is not None and not isinstance(ds_reduction, str):
662+
raise TypeError("Parameter 'ds_reduction' must be a string.")
663+
if ds_reduction is not None and ds_reduction not in get_args(_ImageDsReduction):
664+
raise ValueError(
665+
f"Parameter 'ds_reduction' must be one of {get_args(_ImageDsReduction)}, got {ds_reduction!r}."
666+
)
667+
if ds_reduction is not None and method != "datashader":
668+
logger.warning("Parameter 'ds_reduction' has no effect unless method='datashader'; ignoring.")
669+
637670
params_dict = _validate_image_render_params(
638671
self._sdata,
639672
element=element,
@@ -699,6 +732,8 @@ def render_images(
699732
transfunc=transfunc,
700733
grayscale=grayscale,
701734
channels_as_legend=channels_as_legend,
735+
method=method,
736+
ds_reduction=ds_reduction,
702737
)
703738
n_steps += 1
704739

src/spatialdata_plot/pl/render.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
_ds_aggregate,
4040
_ds_shade_categorical,
4141
_ds_shade_continuous,
42-
_DsReduction,
4342
_render_ds_image,
4443
_render_ds_outlines,
4544
)
@@ -55,6 +54,7 @@
5554
LegendParams,
5655
PointsRenderParams,
5756
ShapesRenderParams,
57+
_DsReduction,
5858
)
5959
from spatialdata_plot.pl.utils import (
6060
_ax_show_and_transform,
@@ -73,6 +73,7 @@
7373
_prepare_cmap_norm,
7474
_prepare_transformation,
7575
_rasterize_if_necessary,
76+
_rasterize_if_necessary_datashader,
7677
_set_color_source_vec,
7778
_validate_polygons,
7879
)
@@ -1279,7 +1280,24 @@ def _render_images(
12791280
scale=scale,
12801281
)
12811282
# rasterize spatial image if necessary to speed up performance
1282-
if rasterize:
1283+
use_datashader = render_params.method == "datashader"
1284+
if use_datashader:
1285+
downsample_method = render_params.ds_reduction or "max"
1286+
logger.info(
1287+
f"Using 'datashader' backend with '{downsample_method}' as downsample method. "
1288+
"Depending on the reduction, the value range of the plot might change. "
1289+
"Set method to 'matplotlib' to disable this behaviour."
1290+
)
1291+
img = _rasterize_if_necessary_datashader(
1292+
image=img,
1293+
dpi=fig_params.fig.dpi,
1294+
width=fig_params.fig.get_size_inches()[0],
1295+
height=fig_params.fig.get_size_inches()[1],
1296+
coordinate_system=coordinate_system,
1297+
extent=extent,
1298+
downsample_method=downsample_method,
1299+
)
1300+
elif rasterize:
12831301
img = _rasterize_if_necessary(
12841302
image=img,
12851303
dpi=fig_params.fig.dpi,
@@ -1389,6 +1407,10 @@ def _render_images(
13891407
"Consider using 'palette' instead."
13901408
)
13911409

1410+
# Force nearest-neighbor at display time when the datashader reduction picked
1411+
# a non-mean aggregation; otherwise imshow's default interpolation would smear it.
1412+
_interp = "nearest" if use_datashader else None
1413+
13921414
# Detect RGB(A) images by channel names — skip when user overrides with palette/cmap
13931415
is_rgb, has_alpha = _is_rgb_image(channels)
13941416
has_explicit_cmap = (
@@ -1430,7 +1452,7 @@ def _render_images(
14301452
render_params.alpha,
14311453
)
14321454

1433-
_ax_show_and_transform(stacked, trans_data, ax, **show_kwargs)
1455+
_ax_show_and_transform(stacked, trans_data, ax, interpolation=_interp, **show_kwargs)
14341456
if render_params.channels_as_legend:
14351457
logger.warning("channels_as_legend is not supported for true RGB images and will be ignored.")
14361458
return
@@ -1457,6 +1479,7 @@ def _render_images(
14571479
cmap=cmap,
14581480
zorder=render_params.zorder,
14591481
norm=render_params.cmap_params.norm,
1482+
interpolation=_interp,
14601483
)
14611484

14621485
wants_colorbar = _should_request_colorbar(
@@ -1549,6 +1572,7 @@ def _render_images(
15491572
ax,
15501573
render_params.alpha,
15511574
zorder=render_params.zorder,
1575+
interpolation=_interp,
15521576
)
15531577

15541578
# 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
@@ -1613,6 +1637,7 @@ def _render_images(
16131637
ax,
16141638
render_params.alpha,
16151639
zorder=render_params.zorder,
1640+
interpolation=_interp,
16161641
)
16171642

16181643
# 2C) palette set; also covers `palette + norm=list` since synthesized
@@ -1633,6 +1658,7 @@ def _render_images(
16331658
ax,
16341659
render_params.alpha,
16351660
zorder=render_params.zorder,
1661+
interpolation=_interp,
16361662
)
16371663

16381664
elif palette is None and got_multiple_cmaps:
@@ -1654,6 +1680,7 @@ def _render_images(
16541680
ax,
16551681
render_params.alpha,
16561682
zorder=render_params.zorder,
1683+
interpolation=_interp,
16571684
)
16581685

16591686
# Collect channel legend entries (single point for all multi-channel paths)

src/spatialdata_plot/pl/render_params.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
_FontWeight = Literal["light", "normal", "medium", "semibold", "bold", "heavy", "black"]
1414
_FontSize = Literal["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"]
15+
_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"]
16+
_ImageDsReduction = Literal["max", "min", "mean", "mode", "first", "last", "var", "std"]
1517

1618
# replace with
1719
# from spatialdata._types import ColorLike
@@ -243,7 +245,7 @@ class ShapesRenderParams:
243245
table_name: str | None = None
244246
table_layer: str | None = None
245247
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None
246-
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
248+
ds_reduction: _DsReduction | None = None
247249
colorbar: bool | str | None = "auto"
248250
colorbar_params: dict[str, object] | None = None
249251

@@ -265,7 +267,7 @@ class PointsRenderParams:
265267
zorder: int = 0
266268
table_name: str | None = None
267269
table_layer: str | None = None
268-
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
270+
ds_reduction: _DsReduction | None = None
269271
colorbar: bool | str | None = "auto"
270272
colorbar_params: dict[str, object] | None = None
271273

@@ -286,6 +288,8 @@ class ImageRenderParams:
286288
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None
287289
grayscale: bool = False
288290
channels_as_legend: bool = False
291+
method: Literal["matplotlib", "datashader"] | None = None
292+
ds_reduction: _ImageDsReduction | None = None
289293

290294

291295
@dataclass

src/spatialdata_plot/pl/utils.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
PointsRenderParams,
8484
ScalebarParams,
8585
ShapesRenderParams,
86+
_DsReduction,
8687
_FontSize,
8788
_FontWeight,
8889
)
@@ -2048,6 +2049,58 @@ def _rasterize_if_necessary(
20482049
return image
20492050

20502051

2052+
def _rasterize_if_necessary_datashader(
2053+
image: DataArray,
2054+
dpi: float,
2055+
width: float,
2056+
height: float,
2057+
coordinate_system: str,
2058+
extent: dict[str, tuple[float, float]],
2059+
downsample_method: str,
2060+
) -> DataArray:
2061+
"""Downsample to canvas resolution with a configurable datashader reduction.
2062+
2063+
Used by ``render_images(method='datashader')`` so sparse images (mostly
2064+
zeros, rare non-zero pixels) survive the downsample step instead of
2065+
being averaged away by the default mean aggregation.
2066+
"""
2067+
has_c_dim = len(image.shape) == 3
2068+
y_dims, x_dims = (image.shape[1], image.shape[2]) if has_c_dim else image.shape
2069+
2070+
target_y_dims = int(dpi * height)
2071+
target_x_dims = int(dpi * width)
2072+
2073+
if y_dims <= target_y_dims and x_dims <= target_x_dims:
2074+
return image
2075+
2076+
# spatialdata.rasterize is invoked solely to inherit the output coords and
2077+
# spatial transformation; its mean-aggregated values are overwritten below.
2078+
world_x = float(extent["x"][1]) - float(extent["x"][0])
2079+
world_y = float(extent["y"][1]) - float(extent["y"][0])
2080+
target_unit_to_pixels = min(target_y_dims / world_y, target_x_dims / world_x)
2081+
base = rasterize(
2082+
image,
2083+
("y", "x"),
2084+
[extent["y"][0], extent["x"][0]],
2085+
[extent["y"][1], extent["x"][1]],
2086+
coordinate_system,
2087+
target_unit_to_pixels=target_unit_to_pixels,
2088+
)
2089+
2090+
out_y, out_x = (base.shape[1], base.shape[2]) if has_c_dim else base.shape
2091+
# Materialize once: per-chunk reductions across channels would otherwise
2092+
# trigger repeated dask graph evaluations on the same source array.
2093+
src = image.compute() if hasattr(image.data, "compute") else image
2094+
cvs = ds.Canvas(
2095+
plot_width=out_x,
2096+
plot_height=out_y,
2097+
x_range=(float(extent["x"][0]), float(extent["x"][1])),
2098+
y_range=(float(extent["y"][0]), float(extent["y"][1])),
2099+
)
2100+
base.values = np.asarray(cvs.raster(src, downsample_method=downsample_method).values).astype(base.dtype, copy=False)
2101+
return base
2102+
2103+
20512104
def _multiscale_to_spatial_image(
20522105
multiscale_image: DataTree,
20532106
dpi: float,
@@ -3385,6 +3438,7 @@ def _ax_show_and_transform(
33853438
cmap: ListedColormap | LinearSegmentedColormap | None = None,
33863439
zorder: int = 0,
33873440
norm: Normalize | None = None,
3441+
interpolation: str | None = None,
33883442
) -> matplotlib.image.AxesImage:
33893443
# ``extent`` uses mpl's pixel-grid convention; world placement happens via
33903444
# ``set_transform(trans_data)`` afterwards.
@@ -3396,6 +3450,8 @@ def _ax_show_and_transform(
33963450
imshow_kwargs["alpha"] = alpha
33973451
else:
33983452
imshow_kwargs["cmap"] = cmap
3453+
if interpolation is not None:
3454+
imshow_kwargs["interpolation"] = interpolation
33993455
im = ax.imshow(array, **imshow_kwargs)
34003456
im.set_transform(trans_data)
34013457
return im
@@ -3508,7 +3564,7 @@ def _create_image_from_datashader_result(
35083564

35093565

35103566
def _datashader_aggregate_with_function(
3511-
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
3567+
reduction: _DsReduction | None,
35123568
cvs: Canvas,
35133569
spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame,
35143570
col_for_color: str | None,
@@ -3572,7 +3628,7 @@ def _datashader_aggregate_with_function(
35723628

35733629

35743630
def _datshader_get_how_kw_for_spread(
3575-
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
3631+
reduction: _DsReduction | None,
35763632
) -> str:
35773633
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
35783634
reduction = reduction or "sum"

0 commit comments

Comments
 (0)