Skip to content

Commit a16dbca

Browse files
timtreisclaude
andcommitted
Fix span auto-scaling after NaN masking
When span=None, datashader auto-scales from non-NaN values. After masking zeros to NaN, the remaining minimum value (e.g. 2.0) would map to cmap(0.0) — the transparent-white entry — and min_alpha would override the alpha, producing opaque white. Fix: _mask_transparent_cmap_entries now returns the pre-masking span so ds.tf.shade uses the original data range, preventing re-scaling into the transparent bin. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e100145 commit a16dbca

2 files changed

Lines changed: 25 additions & 13 deletions

File tree

src/spatialdata_plot/pl/utils.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3198,21 +3198,26 @@ def _mask_transparent_cmap_entries(
31983198
agg: DataArray,
31993199
cmap: str | list[str] | Colormap,
32003200
span: list[float] | tuple[float, float] | None,
3201-
) -> DataArray:
3201+
) -> tuple[DataArray, list[float] | tuple[float, float] | None]:
32023202
"""Set aggregate values to NaN where the colormap alpha is zero.
32033203
32043204
Datashader ignores the per-entry alpha channel of matplotlib colormaps,
32053205
so entries meant to be fully transparent (alpha=0) must be converted to
32063206
NaN — which datashader already renders as transparent — before shading.
32073207
See :issue:`376`.
3208+
3209+
Returns ``(masked_agg, span)`` where *span* is the data range computed
3210+
**before** masking. The caller must pass this span to ``ds.tf.shade``
3211+
to prevent datashader from auto-scaling the remaining values into the
3212+
transparent bin.
32083213
"""
32093214
if not isinstance(cmap, Colormap):
3210-
return agg
3215+
return agg, span
32113216

32123217
# Only the bottom entry is checked — this is what
32133218
# ``set_zero_in_cmap_to_transparent`` targets.
32143219
if cmap(0.0)[3] >= 1.0:
3215-
return agg
3220+
return agg, span
32163221

32173222
# For a ListedColormap with N colors, index 0 covers [0, 1/N) in
32183223
# normalised space. Compute the data value at the upper boundary of that bin.
@@ -3225,18 +3230,22 @@ def _mask_transparent_cmap_entries(
32253230
hi = float(agg.max())
32263231

32273232
if hi <= lo or not np.isfinite(lo) or not np.isfinite(hi):
3228-
return agg
3233+
return agg, span
3234+
3235+
# Freeze the span before masking so datashader doesn't auto-scale
3236+
# the remaining values into the now-transparent first bin.
3237+
span = [lo, hi]
32293238

32303239
threshold = lo + frac * (hi - lo)
3231-
return agg.where(agg >= threshold)
3240+
return agg.where(agg >= threshold), span
32323241

32333242

32343243
def _datashader_map_aggregate_to_color(
32353244
agg: DataArray,
32363245
cmap: str | list[str] | ListedColormap,
32373246
color_key: list[str] | dict[str, str] | None = None,
32383247
min_alpha: float = 40,
3239-
span: None | list[float] = None,
3248+
span: list[float] | tuple[float, float] | None = None,
32403249
clip: bool = True,
32413250
) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]:
32423251
"""ds.tf.shade() part, ensuring correct clipping behavior.
@@ -3248,7 +3257,8 @@ def _datashader_map_aggregate_to_color(
32483257
# in case we use datashader together with a Normalize object where clip=False
32493258
# why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372
32503259
agg_in = agg.where((agg >= span[0]) & (agg <= span[1]))
3251-
agg_in = _mask_transparent_cmap_entries(agg_in, cmap, span)
3260+
agg_in, span = _mask_transparent_cmap_entries(agg_in, cmap, span)
3261+
assert span is not None # guaranteed: we passed a non-None span above
32523262
img_in = ds.tf.shade(
32533263
agg_in,
32543264
cmap=cmap,
@@ -3285,7 +3295,7 @@ def _datashader_map_aggregate_to_color(
32853295
stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0]
32863296
return stack
32873297

3288-
agg = _mask_transparent_cmap_entries(agg, cmap, span)
3298+
agg, span = _mask_transparent_cmap_entries(agg, cmap, span)
32893299
return ds.tf.shade(
32903300
agg,
32913301
cmap=cmap,

tests/pl/test_utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,26 +125,27 @@ def test_masks_zero_values_when_cmap_has_transparent_entry(self):
125125
data = np.array([[0.0, 1.0, 5.0], [0.0, 2.0, 10.0]])
126126
agg = xr.DataArray(data, dims=["y", "x"])
127127

128-
masked = _mask_transparent_cmap_entries(agg, cmap, span=[0.0, 10.0])
128+
masked, returned_span = _mask_transparent_cmap_entries(agg, cmap, span=[0.0, 10.0])
129129

130130
assert np.isnan(masked.values[0, 0])
131131
assert np.isnan(masked.values[1, 0])
132132
assert masked.values[0, 1] == 1.0
133133
assert masked.values[0, 2] == 5.0
134+
assert returned_span == [0.0, 10.0]
134135

135136
def test_no_effect_for_opaque_cmap(self):
136137
cmap = plt.get_cmap("viridis")
137138
data = np.array([[0.0, 5.0, 10.0]])
138139
agg = xr.DataArray(data, dims=["y", "x"])
139140

140-
masked = _mask_transparent_cmap_entries(agg, cmap, span=[0.0, 10.0])
141+
masked, returned_span = _mask_transparent_cmap_entries(agg, cmap, span=[0.0, 10.0])
141142
np.testing.assert_array_equal(masked.values, data)
142143

143144
def test_no_effect_for_string_cmap(self):
144145
data = np.array([[0.0, 5.0, 10.0]])
145146
agg = xr.DataArray(data, dims=["y", "x"])
146147

147-
masked = _mask_transparent_cmap_entries(agg, "viridis", span=[0.0, 10.0])
148+
masked, _ = _mask_transparent_cmap_entries(agg, "viridis", span=[0.0, 10.0])
148149
np.testing.assert_array_equal(masked.values, data)
149150

150151
def test_datashader_shade_respects_transparent_cmap(self):
@@ -168,11 +169,12 @@ def test_span_none_with_zeros(self):
168169
data = np.array([[0.0, 3.0, 10.0]])
169170
agg = xr.DataArray(data, dims=["y", "x"])
170171

171-
masked = _mask_transparent_cmap_entries(agg, cmap, span=None)
172+
masked, returned_span = _mask_transparent_cmap_entries(agg, cmap, span=None)
172173

173174
assert np.isnan(masked.values[0, 0])
174175
assert masked.values[0, 1] == 3.0
175176
assert masked.values[0, 2] == 10.0
177+
assert returned_span == [0.0, 10.0], "span should be frozen from pre-masking range"
176178

177179
def test_all_nan_aggregate(self):
178180
"""All-NaN aggregate is returned unchanged."""
@@ -181,7 +183,7 @@ def test_all_nan_aggregate(self):
181183
data = np.array([[np.nan, np.nan]])
182184
agg = xr.DataArray(data, dims=["y", "x"])
183185

184-
masked = _mask_transparent_cmap_entries(agg, cmap, span=None)
186+
masked, _ = _mask_transparent_cmap_entries(agg, cmap, span=None)
185187
np.testing.assert_array_equal(np.isnan(masked.values), np.isnan(data))
186188

187189

0 commit comments

Comments
 (0)