Skip to content

Commit 8a6a33f

Browse files
authored
Add density mode to render_points (#679)
1 parent 75cadce commit 8a6a33f

10 files changed

Lines changed: 246 additions & 6 deletions

File tree

src/spatialdata_plot/pl/_datashader.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def _ds_shade_continuous(
227227
na_color_hex: str,
228228
spread_px: int | None = None,
229229
ds_reduction: _DsReduction | None = None,
230+
how: str = "linear",
230231
) -> tuple[Any, Any | None, tuple[Any, Any] | None]:
231232
"""Shade a continuous datashader aggregate, optionally applying spread and NaN coloring.
232233
@@ -255,6 +256,7 @@ def _ds_shade_continuous(
255256
min_alpha=_convert_alpha_to_datashader_range(alpha),
256257
span=color_span,
257258
clip=norm.clip,
259+
how=how,
258260
)
259261
shaded = _apply_user_alpha(shaded, alpha)
260262

@@ -278,6 +280,8 @@ def _ds_shade_categorical(
278280
color_vector: Any,
279281
alpha: float,
280282
spread_px: int | None = None,
283+
how: str = "linear",
284+
density: bool = False,
281285
) -> Any:
282286
"""Shade a categorical or no-color datashader aggregate."""
283287
ds_cmap = None
@@ -286,12 +290,20 @@ def _ds_shade_categorical(
286290
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
287291
ds_cmap = _hex_no_alpha(ds_cmap)
288292

293+
# The default min_alpha (~254) is a near-full-opacity floor — right for scatter
294+
# plots, but it collapses the count-driven alpha range and makes categorical
295+
# density read as a flat hue cloud. Drop the floor under density so per-pixel
296+
# alpha can actually encode count. A small non-zero floor (~15%) keeps the
297+
# sparse edges visible under density_how="linear" instead of vanishing.
298+
min_alpha = 40.0 if density else _convert_alpha_to_datashader_range(alpha)
299+
289300
agg_to_shade = ds.tf.spread(agg, px=spread_px) if spread_px is not None else agg
290301
shaded = _datashader_map_aggregate_to_color(
291302
agg_to_shade,
292303
cmap=ds_cmap,
293304
color_key=color_key,
294-
min_alpha=_convert_alpha_to_datashader_range(alpha),
305+
min_alpha=min_alpha,
306+
how=how,
295307
)
296308
return _apply_user_alpha(shaded, alpha)
297309

src/spatialdata_plot/pl/basic.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,8 @@ def render_points(
385385
colorbar: bool | str | None = "auto",
386386
colorbar_params: dict[str, object] | None = None,
387387
datashader_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None,
388+
density: bool = False,
389+
density_how: Literal["linear", "log", "cbrt", "eq_hist"] = "linear",
388390
transfunc: Callable[[float], float] | None = None,
389391
) -> sd.SpatialData:
390392
"""
@@ -455,13 +457,38 @@ def render_points(
455457
in another column of ``var``. Mimics scanpy's ``gene_symbols`` parameter.
456458
datashader_reduction : Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, optional
457459
Reduction method for datashader when coloring by continuous values. When ``None``, defaults to ``"sum"``.
460+
density : bool, default False
461+
Render the points as a 2-D count density via datashader instead of plotting individual markers.
462+
When ``True``, ``method`` is forced to ``"datashader"`` (passing ``method="matplotlib"`` raises).
463+
Density supports ``color=None`` (plain density) or a categorical ``color`` column (per-category
464+
density via :func:`datashader.by`). A continuous ``color`` column or a literal color value
465+
(e.g. ``"red"``) raises an error. Under ``density=True`` the following parameters are ignored
466+
(with a warning if explicitly set): ``size``, ``transfunc``, ``norm.vmin/vmax``, and
467+
``datashader_reduction``.
468+
density_how : Literal["linear", "log", "cbrt", "eq_hist"], default "linear"
469+
How datashader maps aggregated counts to color intensity. ``"linear"`` (default) keeps the
470+
colorbar axis as a count; ``"log"`` and ``"cbrt"`` compress dynamic range; ``"eq_hist"``
471+
equalizes the histogram (rank-based, surfaces the most structure but the colorbar axis is
472+
no longer a count). Ignored when ``density=False``.
458473
transfunc : Callable[[float], float] | None, optional
459474
Optional transformation applied to the continuous color vector before normalization and colormap mapping.
460475
461476
Returns
462477
-------
463478
sd.SpatialData
464479
A copy of the SpatialData object with the rendering parameters stored in its plotting tree.
480+
481+
Examples
482+
--------
483+
Plain density of all transcripts:
484+
485+
>>> sdata.pl.render_points("transcripts", density=True).pl.show()
486+
487+
Per-gene density with a categorical palette:
488+
489+
>>> sdata.pl.render_points(
490+
... "transcripts", color="gene", groups=["Gad1", "Slc17a7"], palette="tab20", density=True
491+
... ).pl.show()
465492
"""
466493
params_dict = _validate_points_render_params(
467494
self._sdata,
@@ -480,6 +507,10 @@ def render_points(
480507
colorbar=colorbar,
481508
colorbar_params=colorbar_params,
482509
gene_symbols=gene_symbols,
510+
density=density,
511+
density_how=density_how,
512+
transfunc=transfunc,
513+
method=method,
483514
)
484515

485516
if method is not None:
@@ -488,6 +519,9 @@ def render_points(
488519
if method not in ["matplotlib", "datashader"]:
489520
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")
490521

522+
if density and method is None:
523+
method = "datashader"
524+
491525
sdata = self._copy()
492526
sdata = _verify_plotting_tree(sdata)
493527
n_steps = len(sdata.plotting_tree.keys())
@@ -515,6 +549,8 @@ def render_points(
515549
ds_reduction=param_values["ds_reduction"],
516550
colorbar=param_values["colorbar"],
517551
colorbar_params=param_values["colorbar_params"],
552+
density=density,
553+
density_how=density_how,
518554
)
519555
n_steps += 1
520556

src/spatialdata_plot/pl/render.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,44 @@ def _warn_groups_ignored_continuous(
170170
)
171171

172172

173+
def _is_categorical_like_dtype(dtype: Any) -> bool:
174+
return (
175+
isinstance(dtype, pd.CategoricalDtype)
176+
or pd.api.types.is_object_dtype(dtype)
177+
or pd.api.types.is_string_dtype(dtype)
178+
)
179+
180+
181+
def _reject_continuous_color_under_density(
182+
sdata_filt: sd.SpatialData,
183+
element: str,
184+
col_for_color: str | None,
185+
color_source_vector: Any,
186+
color_vector: Any,
187+
) -> None:
188+
"""Raise before any materialization if density+continuous-color was requested.
189+
190+
``color_source_vector`` is only populated by ``_set_color_source_vec`` for the categorical
191+
branch, so a non-None value is sufficient to accept the call. Otherwise we read the dtype
192+
from the dask source (points element column) or the pre-computed color vector — neither
193+
forces a ``.compute()``.
194+
"""
195+
if col_for_color is None or color_source_vector is not None:
196+
return
197+
points_columns = sdata_filt.points[element].columns
198+
if col_for_color in points_columns:
199+
dtype = sdata_filt.points[element][col_for_color].dtype
200+
else:
201+
dtype = getattr(color_vector, "dtype", None)
202+
if dtype is None or _is_categorical_like_dtype(dtype):
203+
return
204+
raise ValueError(
205+
f"density=True is only supported with no color or a categorical color column; "
206+
f"got continuous column {col_for_color!r}. To color a density plot by a continuous "
207+
f"variable, set density=False and use method='datashader' with datashader_reduction=."
208+
)
209+
210+
173211
def _warn_missing_groups(
174212
groups: str | list[str],
175213
color_source_vector: pd.Categorical,
@@ -950,7 +988,10 @@ def _render_points(
950988

951989
method = render_params.method
952990

953-
if method is None:
991+
if render_params.density:
992+
method = "datashader"
993+
_reject_continuous_color_under_density(sdata_filt, element, col_for_color, color_source_vector, color_vector)
994+
elif method is None:
954995
method = "datashader" if n_points > 10000 else "matplotlib"
955996

956997
_default_reduction: _DsReduction = "sum"
@@ -960,7 +1001,11 @@ def _render_points(
9601001

9611002
# NOTE: s in matplotlib is in units of points**2
9621003
# use dpi/100 as a factor for cases where dpi!=100
963-
px = int(np.round(np.sqrt(render_params.size) * (fig_params.fig.dpi / 100)))
1004+
# Under density, spreading would smear the count signal across pixels and
1005+
# distort apparent density at sparse edges, so disable it unconditionally.
1006+
px: int | None = (
1007+
None if render_params.density else int(np.round(np.sqrt(render_params.size) * (fig_params.fig.dpi / 100)))
1008+
)
9641009

9651010
# Apply transformations and materialize to pandas immediately so
9661011
# datashader aggregates without dask scheduler overhead. See #379.
@@ -1045,14 +1090,22 @@ def _render_points(
10451090
):
10461091
color_vector = np.asarray([_hex_no_alpha(c) for c in color_vector])
10471092

1093+
shade_how = render_params.density_how if render_params.density else "linear"
1094+
# Plain density (no color column) must use the user-facing cmap as a sequential
1095+
# gradient over counts; the categorical path collapses to a single color and only
1096+
# modulates alpha, which renders as a flat hue regardless of density.
1097+
plain_density = render_params.density and col_for_color is None
1098+
10481099
nan_shaded = None
1049-
if color_by_categorical or col_for_color is None:
1100+
if not plain_density and (color_by_categorical or col_for_color is None):
10501101
shaded = _ds_shade_categorical(
10511102
agg,
10521103
color_key,
10531104
color_vector,
10541105
render_params.alpha,
10551106
spread_px=px,
1107+
how=shade_how,
1108+
density=render_params.density,
10561109
)
10571110
else:
10581111
shaded, nan_shaded, reduction_bounds = _ds_shade_continuous(
@@ -1066,6 +1119,7 @@ def _render_points(
10661119
na_color_hex,
10671120
spread_px=px,
10681121
ds_reduction=render_params.ds_reduction,
1122+
how=shade_how,
10691123
)
10701124

10711125
_render_ds_image(

src/spatialdata_plot/pl/render_params.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,8 @@ class PointsRenderParams:
268268
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
269269
colorbar: bool | str | None = "auto"
270270
colorbar_params: dict[str, object] | None = None
271+
density: bool = False
272+
density_how: Literal["linear", "log", "cbrt", "eq_hist"] = "linear"
271273

272274

273275
@dataclass

src/spatialdata_plot/pl/utils.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import math
44
import os
5+
import warnings
56
from collections import OrderedDict
6-
from collections.abc import Iterable, Mapping, Sequence
7+
from collections.abc import Callable, Iterable, Mapping, Sequence
78
from copy import copy
89
from functools import partial
910
from pathlib import Path
@@ -2820,7 +2821,17 @@ def _validate_points_render_params(
28202821
colorbar: bool | str | None,
28212822
colorbar_params: dict[str, object] | None,
28222823
gene_symbols: str | None = None,
2824+
density: bool = False,
2825+
density_how: Literal["linear", "log", "cbrt", "eq_hist"] = "linear",
2826+
transfunc: Callable[[float], float] | None = None,
2827+
method: str | None = None,
28232828
) -> dict[str, dict[str, Any]]:
2829+
if not isinstance(density, bool):
2830+
raise TypeError("Parameter 'density' must be a bool.")
2831+
allowed_how = ("linear", "log", "cbrt", "eq_hist")
2832+
if density_how not in allowed_how:
2833+
raise ValueError(f"Parameter 'density_how' must be one of {allowed_how}; got {density_how!r}.")
2834+
28242835
param_dict: dict[str, Any] = {
28252836
"sdata": sdata,
28262837
"element": element,
@@ -2840,6 +2851,47 @@ def _validate_points_render_params(
28402851
}
28412852
param_dict = _type_check_params(param_dict, "points")
28422853

2854+
if density:
2855+
if method == "matplotlib":
2856+
raise ValueError(
2857+
"density=True requires the datashader backend; got method='matplotlib'. "
2858+
"Either drop method= or set method='datashader'."
2859+
)
2860+
# Literal color (resolved into param_dict["color"] as a Color instance, with
2861+
# col_for_color set to None) is ambiguous with density: it could mean a
2862+
# single-hue cmap or a one-entry palette. Force the user to choose.
2863+
if param_dict["color"] is not None and param_dict["col_for_color"] is None:
2864+
raise ValueError(
2865+
"density=True with a literal color is ambiguous. Pass cmap= to recolor the "
2866+
"density, or palette= to assign a categorical color, but not color=<literal>."
2867+
)
2868+
# Warn-and-ignore: these parameters do not interact meaningfully with a
2869+
# count-based density and are silently dropped to keep the API consistent.
2870+
if size != 1.0:
2871+
warnings.warn(
2872+
"size is ignored when density=True; spreading would distort the count signal.",
2873+
UserWarning,
2874+
stacklevel=3,
2875+
)
2876+
if transfunc is not None:
2877+
warnings.warn(
2878+
"transfunc is ignored when density=True (no continuous color vector to transform).",
2879+
UserWarning,
2880+
stacklevel=3,
2881+
)
2882+
if isinstance(norm, Normalize) and (norm.vmin is not None or norm.vmax is not None):
2883+
warnings.warn(
2884+
"norm.vmin/vmax are ignored when density=True; use density_how= to control intensity mapping.",
2885+
UserWarning,
2886+
stacklevel=3,
2887+
)
2888+
if ds_reduction is not None:
2889+
warnings.warn(
2890+
"datashader_reduction is ignored when density=True; counts are forced.",
2891+
UserWarning,
2892+
stacklevel=3,
2893+
)
2894+
28432895
element_params: dict[str, dict[str, Any]] = {}
28442896
for el in param_dict["element"]:
28452897
# ensure that the element exists in the SpatialData object
@@ -3715,11 +3767,17 @@ def _datashader_map_aggregate_to_color(
37153767
min_alpha: float = 40,
37163768
span: None | list[float] = None,
37173769
clip: bool = True,
3770+
how: str = "linear",
37183771
) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]:
37193772
"""ds.tf.shade() part, ensuring correct clipping behavior.
37203773
37213774
If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results.
37223775
This ensures the correct clipping behavior, because else datashader would always automatically clip.
3776+
3777+
``how`` controls the count-to-color mapping passed to :func:`datashader.transfer_functions.shade`
3778+
(``"linear"`` by default; ``"log"``/``"cbrt"``/``"eq_hist"`` compress dynamic range). The split-shade
3779+
branch used for ``norm.clip=False`` always uses ``"linear"`` since per-segment shading would otherwise
3780+
interact poorly with rank-based mappings.
37233781
"""
37243782
if not clip and isinstance(cmap, Colormap) and span is not None:
37253783
# in case we use datashader together with a Normalize object where clip=False
@@ -3768,7 +3826,7 @@ def _datashader_map_aggregate_to_color(
37683826
color_key=color_key,
37693827
min_alpha=min_alpha,
37703828
span=span,
3771-
how="linear",
3829+
how=how,
37723830
)
37733831
return _apply_cmap_alpha_to_datashader_result(result, agg, cmap, span)
37743832

47.3 KB
Loading
55.6 KB
Loading
54.6 KB
Loading

tests/conftest.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,34 @@ def sdata_blobs() -> SpatialData:
8080
return blobs()
8181

8282

83+
@pytest.fixture()
84+
def sdata_dense_points() -> SpatialData:
85+
"""Dense (~20k) multi-cluster points dataset for density-rendering visual tests.
86+
87+
The blobs fixture is too sparse (~200 points across 500x500) for density to render
88+
meaningfully without spreading; this fixture provides a Gaussian-cluster cloud with
89+
a categorical ``gene`` column so the per-category density branch is exercised too.
90+
"""
91+
rng = get_standard_RNG()
92+
n_per_cluster = 20000
93+
centers = [(120, 120), (380, 150), (250, 380)]
94+
genes = ["gene_a", "gene_b", "gene_c"]
95+
xs, ys, gs = [], [], []
96+
for (cx, cy), gene in zip(centers, genes, strict=True):
97+
xs.append(rng.normal(loc=cx, scale=18, size=n_per_cluster))
98+
ys.append(rng.normal(loc=cy, scale=18, size=n_per_cluster))
99+
gs.extend([gene] * n_per_cluster)
100+
df = pd.DataFrame(
101+
{
102+
"x": np.concatenate(xs).clip(0, 500),
103+
"y": np.concatenate(ys).clip(0, 500),
104+
"gene": pd.Categorical(gs, categories=genes),
105+
}
106+
)
107+
points = PointsModel.parse(df)
108+
return SpatialData(points={"dense_points": points})
109+
110+
83111
@pytest.fixture()
84112
def sdata_blobs_str() -> SpatialData:
85113
return blobs(n_channels=5, c_coords=["c1", "c2", "c3", "c4", "c5"])

0 commit comments

Comments
 (0)