Skip to content

Commit a4337eb

Browse files
timtreisclaude
andcommitted
Fix set_zero_in_cmap_to_transparent not working with datashader (#376)
Datashader's tf.shade() strips the alpha channel from matplotlib colormaps, so entries with alpha=0 (from set_zero_in_cmap_to_transparent) rendered as opaque white instead of transparent. Fix by masking aggregate values that map to transparent cmap entries to NaN before shading — NaN is datashader's native transparency mechanism. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 303140c commit a4337eb

2 files changed

Lines changed: 114 additions & 1 deletion

File tree

src/spatialdata_plot/pl/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3194,6 +3194,43 @@ def _prepare_transformation(
31943194
return trans, trans_data
31953195

31963196

3197+
def _mask_transparent_cmap_entries(
3198+
agg: DataArray,
3199+
cmap: str | list[str] | Colormap,
3200+
span: list[float] | tuple[float, float] | None,
3201+
) -> DataArray:
3202+
"""Set aggregate values to NaN where the colormap alpha is zero.
3203+
3204+
Datashader ignores the per-entry alpha channel of matplotlib colormaps,
3205+
so entries meant to be fully transparent (alpha=0) must be converted to
3206+
NaN — which datashader already renders as transparent — before shading.
3207+
See :issue:`376`.
3208+
"""
3209+
if not isinstance(cmap, Colormap):
3210+
return agg
3211+
3212+
# Only the bottom entry is checked — this is what
3213+
# ``set_zero_in_cmap_to_transparent`` targets.
3214+
if cmap(0.0)[3] >= 1.0:
3215+
return agg
3216+
3217+
# For a ListedColormap with N colors, index 0 covers [0, 1/N) in
3218+
# normalised space. Compute the data value at the upper boundary of that bin.
3219+
frac = 1.0 / cmap.N
3220+
3221+
if span is not None:
3222+
lo, hi = float(span[0]), float(span[1])
3223+
else:
3224+
lo = float(agg.min())
3225+
hi = float(agg.max())
3226+
3227+
if hi <= lo or not np.isfinite(lo) or not np.isfinite(hi):
3228+
return agg
3229+
3230+
threshold = lo + frac * (hi - lo)
3231+
return agg.where(agg >= threshold)
3232+
3233+
31973234
def _datashader_map_aggregate_to_color(
31983235
agg: DataArray,
31993236
cmap: str | list[str] | ListedColormap,
@@ -3211,6 +3248,7 @@ def _datashader_map_aggregate_to_color(
32113248
# in case we use datashader together with a Normalize object where clip=False
32123249
# why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372
32133250
agg_in = agg.where((agg >= span[0]) & (agg <= span[1]))
3251+
agg_in = _mask_transparent_cmap_entries(agg_in, cmap, span)
32143252
img_in = ds.tf.shade(
32153253
agg_in,
32163254
cmap=cmap,
@@ -3247,6 +3285,7 @@ def _datashader_map_aggregate_to_color(
32473285
stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0]
32483286
return stack
32493287

3288+
agg = _mask_transparent_cmap_entries(agg, cmap, span)
32503289
return ds.tf.shade(
32513290
agg,
32523291
cmap=cmap,

tests/pl/test_utils.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@
44
import pandas as pd
55
import pytest
66
import scanpy as sc
7+
import xarray as xr
78
from spatialdata import SpatialData
89

910
import spatialdata_plot
10-
from spatialdata_plot.pl.utils import _get_subplots
11+
from spatialdata_plot.pl.utils import (
12+
_datashader_map_aggregate_to_color,
13+
_get_subplots,
14+
_mask_transparent_cmap_entries,
15+
set_zero_in_cmap_to_transparent,
16+
)
1117
from tests.conftest import DPI, PlotTester, PlotTesterMeta
1218

1319
sc.pl.set_rcParams_defaults()
@@ -90,6 +96,74 @@ def test_is_color_like(color_result: tuple[ColorLike, bool]):
9096
assert spatialdata_plot.pl.utils._is_color_like(color) == result
9197

9298

99+
class TestMaskTransparentCmapEntries:
100+
"""Regression tests for #376: set_zero_in_cmap_to_transparent with datashader."""
101+
102+
def test_masks_zero_values_when_cmap_has_transparent_entry(self):
103+
cmap = set_zero_in_cmap_to_transparent("viridis")
104+
data = np.array([[0.0, 1.0, 5.0], [0.0, 2.0, 10.0]])
105+
agg = xr.DataArray(data, dims=["y", "x"])
106+
107+
masked = _mask_transparent_cmap_entries(agg, cmap, span=[0.0, 10.0])
108+
109+
assert np.isnan(masked.values[0, 0])
110+
assert np.isnan(masked.values[1, 0])
111+
assert masked.values[0, 1] == 1.0
112+
assert masked.values[0, 2] == 5.0
113+
114+
def test_no_effect_for_opaque_cmap(self):
115+
cmap = plt.get_cmap("viridis")
116+
data = np.array([[0.0, 5.0, 10.0]])
117+
agg = xr.DataArray(data, dims=["y", "x"])
118+
119+
masked = _mask_transparent_cmap_entries(agg, cmap, span=[0.0, 10.0])
120+
np.testing.assert_array_equal(masked.values, data)
121+
122+
def test_no_effect_for_string_cmap(self):
123+
data = np.array([[0.0, 5.0, 10.0]])
124+
agg = xr.DataArray(data, dims=["y", "x"])
125+
126+
masked = _mask_transparent_cmap_entries(agg, "viridis", span=[0.0, 10.0])
127+
np.testing.assert_array_equal(masked.values, data)
128+
129+
def test_datashader_shade_respects_transparent_cmap(self):
130+
"""End-to-end: _datashader_map_aggregate_to_color produces alpha=0 for transparent cmap entries."""
131+
cmap = set_zero_in_cmap_to_transparent("viridis")
132+
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
133+
agg = xr.DataArray(data, dims=["y", "x"])
134+
135+
result = _datashader_map_aggregate_to_color(agg, cmap=cmap, min_alpha=254, span=[0.0, 10.0])
136+
img = result.values if hasattr(result, "values") else result
137+
138+
alpha_at_zero = (int(img[0, 0]) >> 24) & 0xFF
139+
alpha_at_five = (int(img[0, 1]) >> 24) & 0xFF
140+
141+
assert alpha_at_zero == 0, f"Expected alpha=0 at value=0.0, got {alpha_at_zero}"
142+
assert alpha_at_five > 0, f"Expected non-zero alpha at value=5.0, got {alpha_at_five}"
143+
144+
def test_span_none_with_zeros(self):
145+
"""Masking works when span is inferred from the aggregate (span=None)."""
146+
cmap = set_zero_in_cmap_to_transparent("viridis")
147+
data = np.array([[0.0, 3.0, 10.0]])
148+
agg = xr.DataArray(data, dims=["y", "x"])
149+
150+
masked = _mask_transparent_cmap_entries(agg, cmap, span=None)
151+
152+
assert np.isnan(masked.values[0, 0])
153+
assert masked.values[0, 1] == 3.0
154+
assert masked.values[0, 2] == 10.0
155+
156+
def test_all_nan_aggregate(self):
157+
"""All-NaN aggregate is returned unchanged."""
158+
159+
cmap = set_zero_in_cmap_to_transparent("viridis")
160+
data = np.array([[np.nan, np.nan]])
161+
agg = xr.DataArray(data, dims=["y", "x"])
162+
163+
masked = _mask_transparent_cmap_entries(agg, cmap, span=None)
164+
np.testing.assert_array_equal(np.isnan(masked.values), np.isnan(data))
165+
166+
93167
def test_extract_scalar_value():
94168
"""Test the new _extract_scalar_value function for robust numeric conversion."""
95169

0 commit comments

Comments
 (0)