Skip to content

Commit 44ffab6

Browse files
timtreisclaude
andcommitted
Fix NaN sentinel, groups filtering, and code quality issues
- Replace hardcoded "nan" sentinel with _DS_NAN_CATEGORY = "ds_nan" to avoid collision with user category names - Fix groups filtering crash: wrap pd.Categorical in pd.Series before calling .reset_index() (shapes and points paths) - Add groups + na_color=None filtering: when na_color is fully transparent, non-matching elements are removed instead of rendered invisibly (shapes and points) - Extract _build_datashader_color_key() to deduplicate identical color_key dict building loops in shapes and points - Extract _build_alignment_dtype_hint() from inline 57-line diagnostic block in _set_color_source_vec - Fix "cola" typo in shape NaN test (should be "col_a") - Remove dead duplicate instance_id assignments in test fixtures - Update groups docstring to document na_color=None filtering behavior - Add tests for groups + na_color=None filtering path Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 86e9b53 commit 44ffab6

6 files changed

Lines changed: 120 additions & 81 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,9 @@ def render_shapes(
213213
`fill_alpha` will overwrite the value present in the cmap.
214214
groups : list[str] | str | None
215215
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
216-
them. Other values are set to NA. If elment is None, broadcasting behaviour is attempted (use the same
217-
values for all elements).
216+
them. Other values are set to NA. When ``na_color=None``, non-matching elements are filtered out entirely
217+
(shapes and points only). If element is None, broadcasting behaviour is attempted (use the same values for
218+
all elements).
218219
palette : list[str] | str | None
219220
Palette for discrete annotations. List of valid color names that should be used for the categories. Must
220221
match the number of groups. If element is None, broadcasting behaviour is attempted (use the same values for

src/spatialdata_plot/pl/render.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@
6464

6565
_Normalize = Normalize | abc.Sequence[Normalize]
6666

67+
# Sentinel category name used in datashader categorical paths to represent
68+
# missing (NaN) values. Must not collide with realistic user category names.
69+
_DS_NAN_CATEGORY = "ds_nan"
70+
6771

6872
def _coerce_categorical_source(cat_source: Any) -> pd.Categorical:
6973
"""Return a pandas Categorical from known, concrete sources only."""
@@ -82,6 +86,26 @@ def _coerce_categorical_source(cat_source: Any) -> pd.Categorical:
8286
return pd.Categorical(pd.Series(cat_source))
8387

8488

89+
def _build_datashader_color_key(
90+
cat_series: pd.Categorical,
91+
color_vector: Any,
92+
na_color_hex: str,
93+
) -> dict[str, str]:
94+
"""Build a datashader ``color_key`` dict from a categorical series and its color vector."""
95+
colors_arr = np.asarray(color_vector, dtype=object)
96+
color_key: dict[str, str] = {}
97+
for cat in cat_series.categories:
98+
if cat == _DS_NAN_CATEGORY:
99+
key_color = na_color_hex
100+
else:
101+
idx = np.flatnonzero(cat_series == cat)
102+
key_color = colors_arr[idx[0]] if idx.size else na_color_hex
103+
if isinstance(key_color, str) and key_color.startswith("#"):
104+
key_color = _hex_no_alpha(key_color)
105+
color_key[str(cat)] = key_color
106+
return color_key
107+
108+
85109
def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]:
86110
"""Split colorbar params into layout hints, Matplotlib kwargs, and label override."""
87111
layout: dict[str, object] = {}
@@ -185,6 +209,20 @@ def _render_shapes(
185209

186210
values_are_categorical = color_source_vector is not None
187211

212+
# When groups are specified and na_color is fully transparent (na_color=None),
213+
# filter out non-matching elements instead of showing them as invisible geometry.
214+
if groups is not None and values_are_categorical and render_params.cmap_params.na_color.alpha == "00":
215+
csv_series = pd.Series(color_source_vector)
216+
keep = csv_series.isin(groups).values
217+
shapes = shapes[keep].reset_index(drop=True)
218+
sdata_filt[element] = shapes
219+
color_source_vector = pd.Categorical(csv_series[keep].reset_index(drop=True))
220+
color_vector = (
221+
np.asarray(color_vector)[keep]
222+
if not hasattr(color_vector, "reset_index")
223+
else (color_vector[keep].reset_index(drop=True))
224+
)
225+
188226
# color_source_vector is None when the values aren't categorical
189227
if values_are_categorical and render_params.transfunc is not None:
190228
color_vector = render_params.transfunc(color_vector)
@@ -322,9 +360,9 @@ def _render_shapes(
322360
continuous_nan_shapes = None
323361
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
324362
if color_by_categorical:
325-
# add nan as a category so that shapes with nan value are colored in the nan color
363+
# add a sentinel category so that shapes with NaN value are colored in the na_color
326364
transformed_element[col_for_color] = (
327-
transformed_element[col_for_color].cat.add_categories("nan").fillna("nan")
365+
transformed_element[col_for_color].cat.add_categories(_DS_NAN_CATEGORY).fillna(_DS_NAN_CATEGORY)
328366
)
329367
agg = cvs.polygons(
330368
transformed_element,
@@ -391,17 +429,9 @@ def _render_shapes(
391429
color_key: dict[str, str] | None = None
392430
if color_by_categorical and col_for_color is not None:
393431
cat_series = _coerce_categorical_source(transformed_element[col_for_color])
394-
colors_arr = np.asarray(color_vector, dtype=object)
395-
color_key = {}
396-
for cat in cat_series.categories:
397-
if cat == "nan":
398-
key_color = render_params.cmap_params.na_color.get_hex()
399-
else:
400-
idx = np.flatnonzero(cat_series == cat)
401-
key_color = colors_arr[idx[0]] if idx.size else render_params.cmap_params.na_color.get_hex()
402-
if isinstance(key_color, str) and key_color.startswith("#"):
403-
key_color = _hex_no_alpha(key_color)
404-
color_key[str(cat)] = key_color
432+
color_key = _build_datashader_color_key(
433+
cat_series, color_vector, render_params.cmap_params.na_color.get_hex()
434+
)
405435

406436
if color_by_categorical or col_for_color is None:
407437
ds_cmap = None
@@ -812,6 +842,27 @@ def _render_points(
812842
)
813843
points_dd = points_with_color_dd
814844

845+
# When groups are specified and na_color is fully transparent (na_color=None),
846+
# filter out non-matching points instead of rendering invisible geometry.
847+
if groups is not None and color_source_vector is not None and render_params.cmap_params.na_color.alpha == "00":
848+
csv_series = pd.Series(color_source_vector)
849+
keep = csv_series.isin(groups).values
850+
color_source_vector = pd.Categorical(csv_series[keep].reset_index(drop=True))
851+
color_vector = (
852+
np.asarray(color_vector)[keep]
853+
if not hasattr(color_vector, "reset_index")
854+
else (color_vector[keep].reset_index(drop=True))
855+
)
856+
# re-register filtered points in sdata_filt
857+
points_dd = dask.dataframe.from_pandas(points[keep].reset_index(drop=True), npartitions=1)
858+
sdata_filt.points[element] = PointsModel.parse(points_dd, coordinates={"x": "x", "y": "y"})
859+
set_transformation(
860+
element=sdata_filt.points[element],
861+
transformation=transformation_in_cs,
862+
to_coordinate_system=coordinate_system,
863+
)
864+
n_points = int(keep.sum())
865+
815866
# color_source_vector is None when the values aren't categorical
816867
if color_source_vector is None and render_params.transfunc is not None:
817868
color_vector = render_params.transfunc(color_vector)
@@ -895,9 +946,9 @@ def _render_points(
895946
cat_series = cat_series.astype("category")
896947
if hasattr(cat_series.cat, "as_known"):
897948
cat_series = cat_series.cat.as_known()
898-
if "nan" not in cat_series.cat.categories:
899-
cat_series = cat_series.cat.add_categories("nan")
900-
transformed_element[col_for_color] = cat_series.fillna("nan")
949+
if _DS_NAN_CATEGORY not in cat_series.cat.categories:
950+
cat_series = cat_series.cat.add_categories(_DS_NAN_CATEGORY)
951+
transformed_element[col_for_color] = cat_series.fillna(_DS_NAN_CATEGORY)
901952
agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count()))
902953
else:
903954
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum"
@@ -942,17 +993,9 @@ def _render_points(
942993
color_key: dict[str, str] | None = None
943994
if color_by_categorical and col_for_color is not None:
944995
cat_series = _coerce_categorical_source(transformed_element[col_for_color])
945-
colors_arr = np.asarray(color_vector, dtype=object)
946-
color_key = {}
947-
for cat in cat_series.categories:
948-
if cat == "nan":
949-
key_color = render_params.cmap_params.na_color.get_hex()
950-
else:
951-
idx = np.flatnonzero(cat_series == cat)
952-
key_color = colors_arr[idx[0]] if idx.size else render_params.cmap_params.na_color.get_hex()
953-
if isinstance(key_color, str) and key_color.startswith("#"):
954-
key_color = _hex_no_alpha(key_color)
955-
color_key[str(cat)] = key_color
996+
color_key = _build_datashader_color_key(
997+
cat_series, color_vector, render_params.cmap_params.na_color.get_hex()
998+
)
956999

9571000
if (
9581001
color_vector is not None

src/spatialdata_plot/pl/utils.py

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,40 @@ def _infer_color_data_kind(
983983
return "numeric", pd.to_numeric(series, errors="coerce")
984984

985985

986+
def _build_alignment_dtype_hint(
987+
sdata: sd.SpatialData | None,
988+
element: object,
989+
color_series: pd.Series,
990+
table_name: str | None,
991+
) -> str:
992+
"""Build a diagnostic hint string for dtype mismatches between element and table indices."""
993+
hints: list[str] = []
994+
color_index_dtype = getattr(color_series.index, "dtype", None)
995+
element_index_dtype = getattr(getattr(element, "index", None), "dtype", None) if element is not None else None
996+
997+
table_instance_dtype = None
998+
instance_key = None
999+
if table_name is not None and sdata is not None and table_name in sdata.tables:
1000+
table = sdata.tables[table_name]
1001+
try:
1002+
_, _, instance_key = get_table_keys(table)
1003+
except (KeyError, ValueError, TypeError, AttributeError):
1004+
instance_key = None
1005+
if instance_key is not None and hasattr(table, "obs") and instance_key in table.obs:
1006+
table_instance_dtype = table.obs[instance_key].dtype
1007+
1008+
if (
1009+
element_index_dtype is not None
1010+
and table_instance_dtype is not None
1011+
and element_index_dtype != table_instance_dtype
1012+
):
1013+
hints.append(f"element index dtype is {element_index_dtype}, '{instance_key}' dtype is {table_instance_dtype}")
1014+
if color_index_dtype is not None and element_index_dtype is not None and color_index_dtype != element_index_dtype:
1015+
hints.append(f"color index dtype is {color_index_dtype}, element index dtype is {element_index_dtype}")
1016+
1017+
return f" (hint: {'; '.join(hints)})" if hints else ""
1018+
1019+
9861020
def _set_color_source_vec(
9871021
sdata: sd.SpatialData,
9881022
element: SpatialElement | None,
@@ -1039,54 +1073,7 @@ def _set_color_source_vec(
10391073
if color_series.isna().all():
10401074
element_label = _format_element_name(element_name)
10411075
location = f"table '{table_name}'" if table_name is not None else "the element"
1042-
# Provide dtype hints to help diagnose index alignment issues
1043-
dtype_hints: list[str] = []
1044-
color_index_dtype = getattr(color_series.index, "dtype", None)
1045-
element_index_dtype = (
1046-
getattr(getattr(element, "index", None), "dtype", None) if element is not None else None
1047-
)
1048-
1049-
table_instance_dtype = None
1050-
table_index_dtype = None
1051-
instance_key = None
1052-
if table_name is not None and sdata is not None and table_name in sdata.tables:
1053-
table = sdata.tables[table_name]
1054-
table_index_dtype = getattr(getattr(table, "obs", None), "index", None)
1055-
if table_index_dtype is not None:
1056-
table_index_dtype = getattr(table_index_dtype, "dtype", None)
1057-
try:
1058-
_, _, instance_key = get_table_keys(table)
1059-
except (KeyError, ValueError, TypeError, AttributeError):
1060-
instance_key = None
1061-
if instance_key is not None and hasattr(table, "obs") and instance_key in table.obs:
1062-
table_instance_dtype = table.obs[instance_key].dtype
1063-
1064-
if (
1065-
element_index_dtype is not None
1066-
and table_instance_dtype is not None
1067-
and element_index_dtype != table_instance_dtype
1068-
):
1069-
dtype_hints.append(
1070-
f"element index dtype is {element_index_dtype}, '{instance_key}' dtype is {table_instance_dtype}"
1071-
)
1072-
if (
1073-
table_index_dtype is not None
1074-
and table_instance_dtype is not None
1075-
and table_index_dtype != table_instance_dtype
1076-
):
1077-
dtype_hints.append(
1078-
f"table index dtype is {table_index_dtype}, '{instance_key}' dtype is {table_instance_dtype}"
1079-
)
1080-
if (
1081-
color_index_dtype is not None
1082-
and element_index_dtype is not None
1083-
and color_index_dtype != element_index_dtype
1084-
):
1085-
dtype_hints.append(
1086-
f"color index dtype is {color_index_dtype}, element index dtype is {element_index_dtype}"
1087-
)
1088-
1089-
dtype_hint = f" (hint: {'; '.join(dtype_hints)})" if dtype_hints else ""
1076+
dtype_hint = _build_alignment_dtype_hint(sdata, element, color_series, table_name)
10901077
raise ValueError(
10911078
f"Column '{value_to_plot}' for element '{element_label}' contains only missing values after aligning "
10921079
f"with {location}. This usually means the instance ids/indices could not be aligned or converted, so "

tests/conftest.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,6 @@ def sdata_blobs_points_with_nans_in_table() -> SpatialData:
200200
adata.var = pd.DataFrame({}, index=["col1", "col2"])
201201
adata.obs = pd.DataFrame(get_standard_RNG().normal(size=(n_obs, 3)), columns=["col_a", "col_b", "col_c"])
202202
adata.obs.iloc[0:30, adata.obs.columns.get_loc("col_a")] = np.nan
203-
adata.obs["instance_id"] = np.arange(adata.n_obs)
204203
cat_pattern = ["a", "b", np.nan]
205204
repeats = (n_obs + len(cat_pattern) - 1) // len(cat_pattern)
206205
adata.obs["category"] = pd.Categorical((cat_pattern * repeats)[:n_obs])
@@ -221,7 +220,6 @@ def sdata_blobs_shapes_with_nans_in_table() -> SpatialData:
221220
adata.var = pd.DataFrame({}, index=["col1", "col2"])
222221
adata.obs = pd.DataFrame(get_standard_RNG().normal(size=(n_obs, 3)), columns=["col_a", "col_b", "col_c"])
223222
adata.obs.iloc[0, adata.obs.columns.get_loc("col_a")] = np.nan
224-
adata.obs["instance_id"] = np.arange(adata.n_obs)
225223
cat_pattern = ["a", "b", np.nan, "c", "a"]
226224
repeats = (n_obs + len(cat_pattern) - 1) // len(cat_pattern)
227225
adata.obs["category"] = pd.Categorical((cat_pattern * repeats)[:n_obs])

tests/pl/test_render_points.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,11 @@ def test_plot_can_annotate_points_with_nan_in_df_continuous_datashader(self, sda
576576
sdata_blobs["blobs_points"]["cont_color"] = pd.Series([np.nan, 2, 9, 13] * 50)
577577
sdata_blobs.pl.render_points("blobs_points", color="cont_color", size=40, method="datashader").pl.show()
578578

579+
def test_plot_groups_na_color_none_filters_points(self, sdata_blobs: SpatialData):
580+
"""When groups is set and na_color=None, non-matching points are filtered out entirely."""
581+
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
582+
sdata_blobs.pl.render_points("blobs_points", color="cat_color", groups=["a"], na_color=None, size=30).pl.show()
583+
579584

580585
def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
581586
# Work on an independent copy since we mutate tables

tests/pl/test_render_shapes.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -946,13 +946,13 @@ def test_plot_can_annotate_shapes_with_nan_in_table_obs_categorical_datashader(
946946
def test_plot_can_annotate_shapes_with_nan_in_table_obs_continuous(
947947
self, sdata_blobs_shapes_with_nans_in_table: SpatialData
948948
):
949-
sdata_blobs_shapes_with_nans_in_table.pl.render_shapes("blobs_polygons", color="cola").pl.show()
949+
sdata_blobs_shapes_with_nans_in_table.pl.render_shapes("blobs_polygons", color="col_a").pl.show()
950950

951951
def test_plot_can_annotate_shapes_with_nan_in_table_obs_continuous_datashader(
952952
self, sdata_blobs_shapes_with_nans_in_table: SpatialData
953953
):
954954
sdata_blobs_shapes_with_nans_in_table.pl.render_shapes(
955-
"blobs_polygons", color="cola", method="datashader"
955+
"blobs_polygons", color="col_a", method="datashader"
956956
).pl.show()
957957

958958
def test_plot_can_annotate_shapes_with_nan_in_table_X_continuous(
@@ -983,6 +983,11 @@ def test_plot_can_annotate_shapes_with_nan_in_df_continuous_datashader(self, sda
983983
sdata_blobs["blobs_polygons"]["cont_color"] = [np.nan, 2, 3, 4, 5]
984984
sdata_blobs.pl.render_shapes("blobs_polygons", color="cont_color", method="datashader").pl.show()
985985

986+
def test_plot_groups_na_color_none_filters_shapes(self, sdata_blobs: SpatialData):
987+
"""When groups is set and na_color=None, non-matching shapes are filtered out entirely."""
988+
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
989+
sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["a"], na_color=None).pl.show()
990+
986991

987992
def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData, caplog):
988993
"""Test that NaN values in color data are handled gracefully and logged."""

0 commit comments

Comments
 (0)