Skip to content

Commit 91c5ce3

Browse files
timtreisclaude
andcommitted
Replace NaN-masking with RGBA post-processing for cmap alpha
The NaN-masking approach was fundamentally flawed: masking values to NaN caused datashader to auto-scale the remaining values, shifting other shapes' colors and potentially making them white. New approach: let datashader shade normally (preserving all colors and spans), then post-process the RGBA output to apply the cmap's alpha channel. This: - Does not modify aggregate values or spans - Preserves exact colors for all non-transparent shapes - Works for any cmap with transparent entries at any position - Handles span=None correctly (no auto-scaling issues) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a16dbca commit 91c5ce3

2 files changed

Lines changed: 96 additions & 77 deletions

File tree

src/spatialdata_plot/pl/utils.py

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3194,58 +3194,75 @@ def _prepare_transformation(
31943194
return trans, trans_data
31953195

31963196

3197-
def _mask_transparent_cmap_entries(
3197+
def _apply_cmap_alpha_to_datashader_result(
3198+
result: Any,
31983199
agg: DataArray,
31993200
cmap: str | list[str] | Colormap,
32003201
span: list[float] | tuple[float, float] | None,
3201-
) -> tuple[DataArray, list[float] | tuple[float, float] | None]:
3202-
"""Set aggregate values to NaN where the colormap alpha is zero.
3202+
) -> Any:
3203+
"""Apply the colormap's alpha channel to a datashader RGBA result.
32033204
32043205
Datashader ignores the per-entry alpha channel of matplotlib colormaps,
3205-
so entries meant to be fully transparent (alpha=0) must be converted to
3206-
NaN — which datashader already renders as transparent — before shading.
3207-
See :issue:`376`.
3208-
3209-
Returns ``(masked_agg, span)`` where *span* is the data range computed
3210-
**before** masking. The caller must pass this span to ``ds.tf.shade``
3211-
to prevent datashader from auto-scaling the remaining values into the
3212-
transparent bin.
3206+
so pixels that the cmap marks as transparent (alpha=0) are rendered
3207+
opaque. This function post-processes the shaded RGBA output to restore
3208+
the cmap's intended transparency. See :issue:`376`.
32133209
"""
32143210
if not isinstance(cmap, Colormap):
3215-
return agg, span
3211+
return result
3212+
3213+
# Quick check: does this cmap have any transparent entries?
3214+
test_vals = np.linspace(0, 1, min(cmap.N, 256))
3215+
cmap_alphas = cmap(test_vals)[:, 3]
3216+
if np.all(cmap_alphas >= 1.0):
3217+
return result
3218+
3219+
# Get or ensure we have an (H, W, 4) uint8 array
3220+
if hasattr(result, "values"):
3221+
# datashader Image — uint32 packed, convert via to_numpy()
3222+
rgba = result.to_numpy().base
3223+
if rgba is None:
3224+
return result
3225+
else:
3226+
rgba = result
32163227

3217-
# Only the bottom entry is checked — this is what
3218-
# ``set_zero_in_cmap_to_transparent`` targets.
3219-
if cmap(0.0)[3] >= 1.0:
3220-
return agg, span
3228+
if rgba.ndim != 3 or rgba.shape[2] != 4:
3229+
return result
32213230

3222-
# For a ListedColormap with N colors, index 0 covers [0, 1/N) in
3223-
# normalised space. Compute the data value at the upper boundary of that bin.
3224-
frac = 1.0 / cmap.N
3231+
# Normalise aggregate values to [0, 1] using the same span datashader used
3232+
agg_vals = agg.values.astype(np.float64)
3233+
valid = np.isfinite(agg_vals)
3234+
if not valid.any():
3235+
return result
32253236

32263237
if span is not None:
32273238
lo, hi = float(span[0]), float(span[1])
32283239
else:
3229-
lo = float(agg.min())
3230-
hi = float(agg.max())
3240+
lo = float(np.nanmin(agg_vals))
3241+
hi = float(np.nanmax(agg_vals))
32313242

32323243
if hi <= lo or not np.isfinite(lo) or not np.isfinite(hi):
3233-
return agg, span
3244+
return result
3245+
3246+
normed = np.clip((agg_vals - lo) / (hi - lo), 0.0, 1.0)
32343247

3235-
# Freeze the span before masking so datashader doesn't auto-scale
3236-
# the remaining values into the now-transparent first bin.
3237-
span = [lo, hi]
3248+
# Look up cmap alpha for each pixel
3249+
desired_alpha = cmap(normed)[:, :, 3]
32383250

3239-
threshold = lo + frac * (hi - lo)
3240-
return agg.where(agg >= threshold), span
3251+
# Zero out pixels where the cmap wants transparency
3252+
transparent = valid & (desired_alpha < 1.0)
3253+
if transparent.any():
3254+
# Scale the existing alpha by the cmap's alpha
3255+
rgba[transparent, 3] = (rgba[transparent, 3].astype(np.float32) * desired_alpha[transparent]).astype(np.uint8)
3256+
3257+
return result
32413258

32423259

32433260
def _datashader_map_aggregate_to_color(
32443261
agg: DataArray,
32453262
cmap: str | list[str] | ListedColormap,
32463263
color_key: list[str] | dict[str, str] | None = None,
32473264
min_alpha: float = 40,
3248-
span: list[float] | tuple[float, float] | None = None,
3265+
span: None | list[float] = None,
32493266
clip: bool = True,
32503267
) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]:
32513268
"""ds.tf.shade() part, ensuring correct clipping behavior.
@@ -3257,8 +3274,6 @@ def _datashader_map_aggregate_to_color(
32573274
# in case we use datashader together with a Normalize object where clip=False
32583275
# why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372
32593276
agg_in = agg.where((agg >= span[0]) & (agg <= span[1]))
3260-
agg_in, span = _mask_transparent_cmap_entries(agg_in, cmap, span)
3261-
assert span is not None # guaranteed: we passed a non-None span above
32623277
img_in = ds.tf.shade(
32633278
agg_in,
32643279
cmap=cmap,
@@ -3293,17 +3308,18 @@ def _datashader_map_aggregate_to_color(
32933308
img_over = img_over.to_numpy().base
32943309
if img_over is not None:
32953310
stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0]
3296-
return stack
32973311

3298-
agg, span = _mask_transparent_cmap_entries(agg, cmap, span)
3299-
return ds.tf.shade(
3312+
return _apply_cmap_alpha_to_datashader_result(stack, agg, cmap, span)
3313+
3314+
result = ds.tf.shade(
33003315
agg,
33013316
cmap=cmap,
33023317
color_key=color_key,
33033318
min_alpha=min_alpha,
33043319
span=span,
33053320
how="linear",
33063321
)
3322+
return _apply_cmap_alpha_to_datashader_result(result, agg, cmap, span)
33073323

33083324

33093325
def _hex_no_alpha(hex: str) -> str:

tests/pl/test_utils.py

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
import spatialdata_plot
1111
from spatialdata_plot.pl.utils import (
12+
_apply_cmap_alpha_to_datashader_result,
1213
_datashader_map_aggregate_to_color,
1314
_get_subplots,
14-
_mask_transparent_cmap_entries,
1515
set_zero_in_cmap_to_transparent,
1616
)
1717
from tests.conftest import DPI, PlotTester, PlotTesterMeta
@@ -117,74 +117,77 @@ def test_is_color_like(color_result: tuple[ColorLike, bool]):
117117
assert spatialdata_plot.pl.utils._is_color_like(color) == result
118118

119119

120-
class TestMaskTransparentCmapEntries:
120+
class TestCmapAlphaDatashader:
121121
"""Regression tests for #376: set_zero_in_cmap_to_transparent with datashader."""
122122

123-
def test_masks_zero_values_when_cmap_has_transparent_entry(self):
123+
def test_transparent_pixels_get_alpha_zero(self):
124+
"""Post-processing sets alpha=0 for pixels mapping to transparent cmap entries."""
125+
import datashader as ds
126+
124127
cmap = set_zero_in_cmap_to_transparent("viridis")
125-
data = np.array([[0.0, 1.0, 5.0], [0.0, 2.0, 10.0]])
128+
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
126129
agg = xr.DataArray(data, dims=["y", "x"])
127130

128-
masked, returned_span = _mask_transparent_cmap_entries(agg, cmap, span=[0.0, 10.0])
131+
shaded = ds.tf.shade(agg, cmap=cmap, min_alpha=254, how="linear")
132+
result = _apply_cmap_alpha_to_datashader_result(shaded, agg, cmap, span=[0.0, 10.0])
133+
rgba = result.to_numpy().base if hasattr(result, "to_numpy") else result
134+
135+
assert rgba[0, 0, 3] == 0, f"Expected alpha=0 at value=0.0, got {rgba[0, 0, 3]}"
136+
assert rgba[0, 1, 3] > 0, "Expected non-zero alpha at value=5.0"
137+
assert rgba[0, 2, 3] > 0, "Expected non-zero alpha at value=10.0"
129138

130-
assert np.isnan(masked.values[0, 0])
131-
assert np.isnan(masked.values[1, 0])
132-
assert masked.values[0, 1] == 1.0
133-
assert masked.values[0, 2] == 5.0
134-
assert returned_span == [0.0, 10.0]
139+
def test_opaque_cmap_unchanged(self):
140+
"""Post-processing is a no-op for fully opaque cmaps."""
141+
import datashader as ds
135142

136-
def test_no_effect_for_opaque_cmap(self):
137143
cmap = plt.get_cmap("viridis")
138-
data = np.array([[0.0, 5.0, 10.0]])
144+
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
139145
agg = xr.DataArray(data, dims=["y", "x"])
140146

141-
masked, returned_span = _mask_transparent_cmap_entries(agg, cmap, span=[0.0, 10.0])
142-
np.testing.assert_array_equal(masked.values, data)
147+
shaded = ds.tf.shade(agg, cmap=cmap, min_alpha=254, how="linear")
148+
rgba_before = shaded.to_numpy().base.copy()
149+
result = _apply_cmap_alpha_to_datashader_result(shaded, agg, cmap, span=[0.0, 10.0])
150+
rgba_after = result.to_numpy().base if hasattr(result, "to_numpy") else result
151+
np.testing.assert_array_equal(rgba_before, rgba_after)
143152

144-
def test_no_effect_for_string_cmap(self):
145-
data = np.array([[0.0, 5.0, 10.0]])
153+
def test_string_cmap_passthrough(self):
154+
"""Post-processing is a no-op for string cmaps (early return)."""
155+
dummy_rgba = np.zeros((2, 3, 4), dtype=np.uint8)
156+
dummy_rgba[:, :, 3] = 200
157+
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
146158
agg = xr.DataArray(data, dims=["y", "x"])
147159

148-
masked, _ = _mask_transparent_cmap_entries(agg, "viridis", span=[0.0, 10.0])
149-
np.testing.assert_array_equal(masked.values, data)
160+
result = _apply_cmap_alpha_to_datashader_result(dummy_rgba, agg, "viridis", span=[0.0, 10.0])
161+
np.testing.assert_array_equal(result, dummy_rgba)
150162

151-
def test_datashader_shade_respects_transparent_cmap(self):
152-
"""End-to-end: _datashader_map_aggregate_to_color produces alpha=0 for transparent cmap entries."""
163+
def test_end_to_end_datashader_map(self):
164+
"""_datashader_map_aggregate_to_color produces alpha=0 for transparent cmap entries."""
153165
cmap = set_zero_in_cmap_to_transparent("viridis")
154166
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
155167
agg = xr.DataArray(data, dims=["y", "x"])
156168

157169
result = _datashader_map_aggregate_to_color(agg, cmap=cmap, min_alpha=254, span=[0.0, 10.0])
158-
img = result.values if hasattr(result, "values") else result
159-
160-
alpha_at_zero = (int(img[0, 0]) >> 24) & 0xFF
161-
alpha_at_five = (int(img[0, 1]) >> 24) & 0xFF
170+
img = result.to_numpy().base if hasattr(result, "to_numpy") else result
162171

163-
assert alpha_at_zero == 0, f"Expected alpha=0 at value=0.0, got {alpha_at_zero}"
164-
assert alpha_at_five > 0, f"Expected non-zero alpha at value=5.0, got {alpha_at_five}"
172+
assert img[0, 0, 3] == 0, f"Expected alpha=0 at value=0.0, got {img[0, 0, 3]}"
173+
assert img[0, 1, 3] > 0, "Expected non-zero alpha at value=5.0"
165174

166-
def test_span_none_with_zeros(self):
167-
"""Masking works when span is inferred from the aggregate (span=None)."""
175+
def test_span_none_preserves_colors(self):
176+
"""With span=None, non-transparent shapes keep their correct colors."""
168177
cmap = set_zero_in_cmap_to_transparent("viridis")
169-
data = np.array([[0.0, 3.0, 10.0]])
178+
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
170179
agg = xr.DataArray(data, dims=["y", "x"])
171180

172-
masked, returned_span = _mask_transparent_cmap_entries(agg, cmap, span=None)
173-
174-
assert np.isnan(masked.values[0, 0])
175-
assert masked.values[0, 1] == 3.0
176-
assert masked.values[0, 2] == 10.0
177-
assert returned_span == [0.0, 10.0], "span should be frozen from pre-masking range"
178-
179-
def test_all_nan_aggregate(self):
180-
"""All-NaN aggregate is returned unchanged."""
181-
182-
cmap = set_zero_in_cmap_to_transparent("viridis")
183-
data = np.array([[np.nan, np.nan]])
184-
agg = xr.DataArray(data, dims=["y", "x"])
181+
result = _datashader_map_aggregate_to_color(agg, cmap=cmap, min_alpha=254)
182+
img = result.to_numpy().base if hasattr(result, "to_numpy") else result
185183

186-
masked, _ = _mask_transparent_cmap_entries(agg, cmap, span=None)
187-
np.testing.assert_array_equal(np.isnan(masked.values), np.isnan(data))
184+
# value=0 should be transparent
185+
assert img[0, 0, 3] == 0
186+
# value=5 and value=10 should be opaque with correct viridis colors (not white)
187+
assert img[0, 1, 3] > 0
188+
assert img[0, 2, 3] > 0
189+
# The non-transparent pixels should NOT be white (R=255,G=255,B=255)
190+
assert not (img[0, 1, 0] == 255 and img[0, 1, 1] == 255 and img[0, 1, 2] == 255)
188191

189192

190193
def test_extract_scalar_value():

0 commit comments

Comments
 (0)