Skip to content

Commit 19e79ff

Browse files
timtreisclaude
andcommitted
Reduce code bloat: deduplicate helpers, fix latent bug, unify logic
- Extract _inject_ds_nan_sentinel(): fixes latent bug where shapes path did unguarded add_categories() (would crash if sentinel already exists). Both shapes and points now use the same guarded helper. - Extract _want_decorations(): unifies the shapes (3 lines, compared with alpha) and points (17 lines, stripped alpha) logic into one consistent helper that normalizes hex colors before comparing. - Extract _reparse_points(): deduplicates the PointsModel.parse + set_transformation pattern that appeared 3 times. - Replace inline hex-stripping with existing _hex_no_alpha() utility. - Add warning when color vector length mismatches element count (was silent). - Drop unnecessary .copy() on adata boolean indexing. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 40206b9 commit 19e79ff

1 file changed

Lines changed: 68 additions & 68 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 68 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,60 @@ def _build_datashader_color_key(
118118
return color_key
119119

120120

121+
def _inject_ds_nan_sentinel(series: pd.Series, sentinel: str = _DS_NAN_CATEGORY) -> pd.Series:
122+
"""Add a sentinel category for NaN values in a categorical series.
123+
124+
Safely handles series that are not yet categorical, dask-backed
125+
categoricals that need ``as_known()``, and series that already
126+
contain the sentinel.
127+
"""
128+
if not isinstance(series.dtype, pd.CategoricalDtype):
129+
series = series.astype("category")
130+
if hasattr(series.cat, "as_known"):
131+
series = series.cat.as_known()
132+
if sentinel not in series.cat.categories:
133+
series = series.cat.add_categories(sentinel)
134+
return series.fillna(sentinel)
135+
136+
137+
def _want_decorations(color_vector: Any, na_color: Color) -> bool:
138+
"""Return whether legend/colorbar decorations should be shown.
139+
140+
Decorations are suppressed when all colors equal the NA color
141+
(i.e., nothing informative to display).
142+
"""
143+
if color_vector is None:
144+
return False
145+
cv = np.asarray(color_vector)
146+
if cv.size == 0:
147+
return False
148+
unique_vals = set(cv.tolist())
149+
if len(unique_vals) != 1:
150+
return True
151+
only_val = next(iter(unique_vals))
152+
na_hex = na_color.get_hex()
153+
if isinstance(only_val, str) and only_val.startswith("#") and na_hex.startswith("#"):
154+
return _hex_no_alpha(only_val) != _hex_no_alpha(na_hex)
155+
return bool(only_val != na_hex)
156+
157+
158+
def _reparse_points(
159+
sdata_filt: sd.SpatialData,
160+
element: str,
161+
df: pd.DataFrame,
162+
transformation: Any,
163+
coordinate_system: str,
164+
) -> None:
165+
"""Re-register a points DataFrame in *sdata_filt* with its transformation."""
166+
dd_frame = dask.dataframe.from_pandas(df, npartitions=1)
167+
sdata_filt.points[element] = PointsModel.parse(dd_frame, coordinates={"x": "x", "y": "y"})
168+
set_transformation(
169+
element=sdata_filt.points[element],
170+
transformation=transformation,
171+
to_coordinate_system=coordinate_system,
172+
)
173+
174+
121175
def _filter_groups_transparent_na(
122176
groups: str | list[str],
123177
color_source_vector: pd.Categorical,
@@ -364,11 +418,13 @@ def _render_shapes(
364418
# If single color, broadcast to all shapes
365419
color_vector = [color_vector[0]] * len(transformed_element)
366420
else:
367-
# If lengths don't match, pad or truncate to match
421+
logger.warning(
422+
f"Color vector length ({len(color_vector)}) does not match element count "
423+
f"({len(transformed_element)}). This may indicate a bug."
424+
)
368425
if len(color_vector) > len(transformed_element):
369426
color_vector = color_vector[: len(transformed_element)]
370427
else:
371-
# Pad with the last color or na_color
372428
na_color = render_params.cmap_params.na_color.get_hex_with_alpha()
373429
color_vector = list(color_vector) + [na_color] * (len(transformed_element) - len(color_vector))
374430

@@ -386,9 +442,7 @@ def _render_shapes(
386442
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
387443
if color_by_categorical:
388444
# add a sentinel category so that shapes with NaN value are colored in the na_color
389-
transformed_element[col_for_color] = (
390-
transformed_element[col_for_color].cat.add_categories(_DS_NAN_CATEGORY).fillna(_DS_NAN_CATEGORY)
391-
)
445+
transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color])
392446
agg = cvs.polygons(
393447
transformed_element,
394448
geometry="geometry",
@@ -493,9 +547,7 @@ def _render_shapes(
493547

494548
if continuous_nan_shapes is not None:
495549
# for coloring by continuous variable: render nan shapes separately
496-
nan_color_hex = render_params.cmap_params.na_color.get_hex()
497-
if nan_color_hex.startswith("#") and len(nan_color_hex) == 9:
498-
nan_color_hex = nan_color_hex[:7]
550+
nan_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex())
499551
continuous_nan_shapes = ds.tf.shade(
500552
continuous_nan_shapes,
501553
cmap=nan_color_hex,
@@ -664,10 +716,7 @@ def _render_shapes(
664716
vmax = 1.0
665717
_cax.set_clim(vmin=vmin, vmax=vmax)
666718

667-
if (
668-
len(set(color_vector)) != 1
669-
or list(set(color_vector))[0] != render_params.cmap_params.na_color.get_hex_with_alpha()
670-
):
719+
if _want_decorations(color_vector, render_params.cmap_params.na_color):
671720
# necessary in case different shapes elements are annotated with one table
672721
if color_source_vector is not None and render_params.col_for_color is not None:
673722
color_source_vector = color_source_vector.remove_unused_categories()
@@ -808,14 +857,7 @@ def _render_points(
808857

809858
# Convert back to dask dataframe to modify sdata
810859
transformation_in_cs = sdata_filt.points[element].attrs["transform"][coordinate_system]
811-
points_dd = dask.dataframe.from_pandas(points_for_model, npartitions=1)
812-
sdata_filt.points[element] = PointsModel.parse(points_dd, coordinates={"x": "x", "y": "y"})
813-
# restore transformation in coordinate system of interest
814-
set_transformation(
815-
element=sdata_filt.points[element],
816-
transformation=transformation_in_cs,
817-
to_coordinate_system=coordinate_system,
818-
)
860+
_reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system)
819861

820862
if col_for_color is not None:
821863
assert isinstance(col_for_color, str)
@@ -858,14 +900,7 @@ def _render_points(
858900
)
859901

860902
if added_color_from_table and col_for_color is not None:
861-
points_with_color_dd = dask.dataframe.from_pandas(points_pd_with_color, npartitions=1)
862-
sdata_filt.points[element] = PointsModel.parse(points_with_color_dd, coordinates={"x": "x", "y": "y"})
863-
set_transformation(
864-
element=sdata_filt.points[element],
865-
transformation=transformation_in_cs,
866-
to_coordinate_system=coordinate_system,
867-
)
868-
points_dd = points_with_color_dd
903+
_reparse_points(sdata_filt, element, points_pd_with_color, transformation_in_cs, coordinate_system)
869904

870905
# When groups are specified and na_color is fully transparent (na_color=None),
871906
# filter out non-matching points instead of rendering invisible geometry.
@@ -878,14 +913,8 @@ def _render_points(
878913
return
879914
# filter the materialized points, adata, and re-register in sdata_filt
880915
points = points[keep].reset_index(drop=True)
881-
adata = adata[keep].copy()
882-
filtered_dd = dask.dataframe.from_pandas(points, npartitions=1)
883-
sdata_filt.points[element] = PointsModel.parse(filtered_dd, coordinates={"x": "x", "y": "y"})
884-
set_transformation(
885-
element=sdata_filt.points[element],
886-
transformation=transformation_in_cs,
887-
to_coordinate_system=coordinate_system,
888-
)
916+
adata = adata[keep]
917+
_reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system)
889918

890919
# color_source_vector is None when the values aren't categorical
891920
if color_source_vector is None and render_params.transfunc is not None:
@@ -965,14 +994,7 @@ def _render_points(
965994
if col_for_color is not None:
966995
if color_by_categorical:
967996
# add nan as category so that nan points are shown in the nan color
968-
cat_series = transformed_element[col_for_color]
969-
if not isinstance(cat_series.dtype, pd.CategoricalDtype):
970-
cat_series = cat_series.astype("category")
971-
if hasattr(cat_series.cat, "as_known"):
972-
cat_series = cat_series.cat.as_known()
973-
if _DS_NAN_CATEGORY not in cat_series.cat.categories:
974-
cat_series = cat_series.cat.add_categories(_DS_NAN_CATEGORY)
975-
transformed_element[col_for_color] = cat_series.fillna(_DS_NAN_CATEGORY)
997+
transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color])
976998
agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count()))
977999
else:
9781000
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum"
@@ -1062,9 +1084,7 @@ def _render_points(
10621084

10631085
if continuous_nan_points is not None:
10641086
# for coloring by continuous variable: render nan points separately
1065-
nan_color_hex = render_params.cmap_params.na_color.get_hex()
1066-
if nan_color_hex.startswith("#") and len(nan_color_hex) == 9:
1067-
nan_color_hex = nan_color_hex[:7]
1087+
nan_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex())
10681088
continuous_nan_points = ds.tf.spread(continuous_nan_points, px=px, how="max")
10691089
continuous_nan_points = ds.tf.shade(
10701090
continuous_nan_points,
@@ -1132,27 +1152,7 @@ def _render_points(
11321152
ax.set_xbound(extent["x"])
11331153
ax.set_ybound(extent["y"])
11341154

1135-
# Decide whether there is any informative color variation.
1136-
# We skip legend/colorbar only if all colors are equal to the NA color.
1137-
want_decorations = True
1138-
if color_vector is None:
1139-
want_decorations = False
1140-
else:
1141-
cv = np.asarray(color_vector)
1142-
if cv.size == 0:
1143-
want_decorations = False
1144-
else:
1145-
unique_vals = set(cv.tolist())
1146-
if len(unique_vals) == 1:
1147-
only_val = next(iter(unique_vals))
1148-
na_hex = render_params.cmap_params.na_color.get_hex()
1149-
if isinstance(only_val, str) and only_val.startswith("#") and na_hex.startswith("#"):
1150-
only_norm = _hex_no_alpha(only_val)
1151-
na_norm = _hex_no_alpha(na_hex)
1152-
if only_norm == na_norm:
1153-
want_decorations = False
1154-
1155-
if want_decorations:
1155+
if _want_decorations(color_vector, render_params.cmap_params.na_color):
11561156
if color_source_vector is None:
11571157
palette = ListedColormap(dict.fromkeys(color_vector))
11581158
else:

0 commit comments

Comments
 (0)