77from collections .abc import Callable , Sequence
88from copy import deepcopy
99from pathlib import Path
10- from typing import Any , Literal , cast
10+ from typing import Any , Literal , cast , get_args
1111
1212import matplotlib
1313import matplotlib .pyplot as plt
2929from xarray import DataArray , DataTree
3030
3131from spatialdata_plot ._accessor import register_spatial_data_accessor
32- from spatialdata_plot ._logging import _log_context
32+ from spatialdata_plot ._logging import _log_context , logger
3333from spatialdata_plot .pl .render import (
3434 _draw_channel_legend ,
3535 _render_graph ,
5252 LegendParams ,
5353 PointsRenderParams ,
5454 ShapesRenderParams ,
55+ _DsReduction ,
5556 _FontSize ,
5657 _FontWeight ,
58+ _ImageDsReduction ,
5759)
5860from spatialdata_plot .pl .utils import (
5961 _RENDER_CMD_TO_CS_FLAG ,
@@ -194,7 +196,7 @@ def render_shapes(
194196 shape : Literal ["circle" , "hex" , "visium_hex" , "square" ] | None = None ,
195197 colorbar : bool | str | None = "auto" ,
196198 colorbar_params : dict [str , object ] | None = None ,
197- datashader_reduction : Literal [ "sum" , "mean" , "any" , "count" , "std" , "var" , "max" , "min" ] | None = None ,
199+ datashader_reduction : _DsReduction | None = None ,
198200 transfunc : Callable [[float ], float ] | None = None ,
199201 ) -> sd .SpatialData :
200202 """
@@ -384,7 +386,7 @@ def render_points(
384386 gene_symbols : str | None = None ,
385387 colorbar : bool | str | None = "auto" ,
386388 colorbar_params : dict [str , object ] | None = None ,
387- datashader_reduction : Literal [ "sum" , "mean" , "any" , "count" , "std" , "var" , "max" , "min" ] | None = None ,
389+ datashader_reduction : _DsReduction | None = None ,
388390 transfunc : Callable [[float ], float ] | None = None ,
389391 ) -> sd .SpatialData :
390392 """
@@ -536,6 +538,8 @@ def render_images(
536538 colorbar : bool | str | None = "auto" ,
537539 colorbar_params : dict [str , object ] | None = None ,
538540 channels_as_legend : bool = False ,
541+ method : Literal ["matplotlib" , "datashader" ] | None = None ,
542+ ds_reduction : _ImageDsReduction | None = None ,
539543 ) -> sd .SpatialData :
540544 """
541545 Render image elements in SpatialData.
@@ -616,6 +620,21 @@ def render_images(
616620 Ignored for single-channel and RGB(A) images. When multiple
617621 ``render_images`` calls use this flag on the same axes, all
618622 channel entries are combined into a single legend.
623+ method : str | None, optional
624+ Whether to use ``'matplotlib'`` (default) or ``'datashader'`` for
625+ the downsampling step. When ``'datashader'`` is selected, the
626+ rasterization-to-canvas step uses
627+ :meth:`datashader.Canvas.raster` with ``ds_reduction`` as the
628+ downsample method (default ``'max'``), and ``imshow`` is rendered
629+ with ``interpolation='nearest'`` so the chosen reduction is not
630+ re-smoothed at display time. Useful for very sparse images
631+ (mostly zeros) where mean aggregation collapses the signal —
632+ ``method='datashader'`` with ``ds_reduction='max'`` preserves the
633+ rare non-zero pixels (``plt.spy``-style).
634+ ds_reduction : {"max", "min", "mean", "mode", "first", "last", "var", "std"} | None, optional
635+ Downsample reduction used by the datashader path. Defaults to
636+ ``'max'`` when ``method='datashader'``. Ignored otherwise (a
637+ warning is emitted if set without ``method='datashader'``).
619638
620639 Notes
621640 -----
@@ -634,6 +653,20 @@ def render_images(
634653 """
635654 if grayscale and palette is not None :
636655 raise ValueError ("Cannot combine grayscale=True with palette." )
656+
657+ if method is not None and not isinstance (method , str ):
658+ raise TypeError ("Parameter 'method' must be a string." )
659+ if method is not None and method not in ("matplotlib" , "datashader" ):
660+ raise ValueError ("Parameter 'method' must be either 'matplotlib' or 'datashader'." )
661+ if ds_reduction is not None and not isinstance (ds_reduction , str ):
662+ raise TypeError ("Parameter 'ds_reduction' must be a string." )
663+ if ds_reduction is not None and ds_reduction not in get_args (_ImageDsReduction ):
664+ raise ValueError (
665+ f"Parameter 'ds_reduction' must be one of { get_args (_ImageDsReduction )} , got { ds_reduction !r} ."
666+ )
667+ if ds_reduction is not None and method != "datashader" :
668+ logger .warning ("Parameter 'ds_reduction' has no effect unless method='datashader'; ignoring." )
669+
637670 params_dict = _validate_image_render_params (
638671 self ._sdata ,
639672 element = element ,
@@ -699,6 +732,8 @@ def render_images(
699732 transfunc = transfunc ,
700733 grayscale = grayscale ,
701734 channels_as_legend = channels_as_legend ,
735+ method = method ,
736+ ds_reduction = ds_reduction ,
702737 )
703738 n_steps += 1
704739
0 commit comments