Skip to content

Commit bb978fb

Browse files
timtreisclaude
andcommitted
Fix render_points datashader pipeline: dead code, silent failures, and fragile alignment
- Forward `default_reduction` to aggregation instead of ignoring it (_datashader.py) - Warn when `ds_reduction` is set for categorical data (always uses count) - Warn when `groups` is set with continuous color column (silently ignored) - Warn on color_vector/cat_series length mismatch in `_build_datashader_color_key` - Only set cmap fallback in `_ds_shade_categorical` when no color_key is present - Add 54 unit tests covering all fixes and datashader reduction behavior Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent edca5a5 commit bb978fb

3 files changed

Lines changed: 539 additions & 2 deletions

File tree

src/spatialdata_plot/pl/_datashader.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def _build_datashader_color_key(
6262
"""Build a datashader ``color_key`` dict from a categorical series and its color vector."""
6363
na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex
6464
colors_arr = np.asarray(color_vector, dtype=object)
65+
if len(colors_arr) != len(cat_series.codes):
66+
logger.warning(
67+
f"color_vector length ({len(colors_arr)}) does not match categorical series length "
68+
f"({len(cat_series.codes)}); some categories may receive the na_color fallback."
69+
)
6570
first_color: dict[str, str] = {}
6671
for code, color in zip(cat_series.codes, colors_arr, strict=False):
6772
if code < 0:
@@ -119,6 +124,11 @@ def _agg_call(element: Any, agg_func: Any) -> Any:
119124

120125
if col_for_color is not None:
121126
if color_by_categorical:
127+
if ds_reduction is not None:
128+
logger.warning(
129+
f'ds_reduction="{ds_reduction}" is ignored for categorical data; '
130+
"categorical aggregation always uses count."
131+
)
122132
transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color])
123133
agg = _agg_call(transformed_element, ds.by(col_for_color, ds.count()))
124134
else:
@@ -127,7 +137,9 @@ def _agg_call(element: Any, agg_func: Any) -> Any:
127137
f'Using the datashader reduction "{reduction_name}". "max" will give an output '
128138
"very close to the matplotlib result."
129139
)
130-
agg = _datashader_aggregate_with_function(ds_reduction, cvs, transformed_element, col_for_color, geom_type)
140+
agg = _datashader_aggregate_with_function(
141+
reduction_name, cvs, transformed_element, col_for_color, geom_type
142+
)
131143
reduction_bounds = (agg.min(), agg.max())
132144

133145
nan_elements = transformed_element[transformed_element[col_for_color].isnull()]
@@ -244,7 +256,7 @@ def _ds_shade_categorical(
244256
) -> Any:
245257
"""Shade a categorical or no-color datashader aggregate."""
246258
ds_cmap = None
247-
if color_vector is not None:
259+
if color_key is None and color_vector is not None:
248260
ds_cmap = color_vector[0]
249261
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
250262
ds_cmap = _hex_no_alpha(ds_cmap)

src/spatialdata_plot/pl/render.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,20 @@ def _reparse_points(
108108
)
109109

110110

111+
def _warn_groups_ignored_continuous(
112+
groups: str | list[str] | None,
113+
color_source_vector: pd.Categorical | None,
114+
col_for_color: str | None,
115+
) -> None:
116+
"""Warn when ``groups`` is set but coloring is continuous (no categorical source)."""
117+
if groups is not None and color_source_vector is None and col_for_color is not None:
118+
logger.warning(
119+
f"`groups` is ignored when coloring by continuous column '{col_for_color}'. "
120+
"`groups` filters categories of the column specified via `color`; "
121+
"it has no effect on continuous data."
122+
)
123+
124+
111125
def _warn_missing_groups(
112126
groups: str | list[str],
113127
color_source_vector: pd.Categorical,
@@ -329,6 +343,8 @@ def _render_shapes(
329343

330344
values_are_categorical = color_source_vector is not None
331345

346+
_warn_groups_ignored_continuous(groups, color_source_vector, col_for_color)
347+
332348
if groups is not None and color_source_vector is not None:
333349
_warn_missing_groups(groups, color_source_vector, col_for_color)
334350

@@ -784,6 +800,8 @@ def _render_points(
784800
if added_color_from_table and col_for_color is not None:
785801
_reparse_points(sdata_filt, element, points_pd_with_color, transformation_in_cs, coordinate_system)
786802

803+
_warn_groups_ignored_continuous(groups, color_source_vector, col_for_color)
804+
787805
if groups is not None and color_source_vector is not None:
788806
_warn_missing_groups(groups, color_source_vector, col_for_color)
789807

@@ -1335,6 +1353,8 @@ def _render_labels(
13351353
else:
13361354
assert color_source_vector is None
13371355

1356+
_warn_groups_ignored_continuous(groups, color_source_vector, col_for_color)
1357+
13381358
if groups is not None and color_source_vector is not None:
13391359
_warn_missing_groups(groups, color_source_vector, col_for_color)
13401360

0 commit comments

Comments
 (0)