Skip to content

Commit 2f4af2c

Browse files
timtreisclaude
andcommitted
Address review: deduplicate warning, add tests, cover labels
- Extract `_warn_missing_groups` helper and call it from `_filter_groups_transparent_na` (shapes/points) and inline in labels, removing duplicated logic from both call sites. - Distinguish "none matched" (likely wrong column) from "some missing" (likely typo) with different warning messages. - Replace assert with explicit `color_source_vector is not None` guard in the shapes condition, matching the points pattern. - Add 4 tests (shapes + points × all-missing + partial-missing) using `logger_warns` to lock the warning behavior. - Drop `**kwargs` from empty `PatchCollection` fallback. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8717705 commit 2f4af2c

4 files changed

Lines changed: 68 additions & 19 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,42 @@ def _reparse_points(
108108
)
109109

110110

111+
def _warn_missing_groups(
112+
groups: str | list[str],
113+
color_source_vector: pd.Categorical,
114+
col_for_color: str | None = None,
115+
) -> None:
116+
"""Warn when ``groups`` contains values absent from the color column's categories."""
117+
groups_list = [groups] if isinstance(groups, str) else list(groups)
118+
missing = set(groups_list) - set(color_source_vector.categories)
119+
if not missing:
120+
return
121+
col_label = f" '{col_for_color}'" if col_for_color else ""
122+
if missing == set(groups_list):
123+
logger.warning(
124+
f"None of the requested groups {sorted(missing)} were found in column{col_label}. "
125+
"This usually means `groups` refers to values from a different column than `color`. "
126+
"The `groups` parameter selects categories of the column specified via `color`."
127+
)
128+
else:
129+
logger.warning(
130+
f"Groups {sorted(missing)} were not found in column{col_label} and will be ignored. "
131+
f"Available categories: {sorted(color_source_vector.categories)}."
132+
)
133+
134+
111135
def _filter_groups_transparent_na(
112136
groups: str | list[str],
113137
color_source_vector: pd.Categorical,
114138
color_vector: pd.Series | np.ndarray | list[str],
139+
col_for_color: str | None = None,
115140
) -> tuple[np.ndarray, pd.Categorical, np.ndarray]:
116141
"""Return a boolean mask and filtered color vectors for groups filtering.
117142
118143
Used when ``na_color=None`` (fully transparent) so that non-matching
119144
elements are removed entirely instead of rendered invisibly.
120145
"""
146+
_warn_missing_groups(groups, color_source_vector, col_for_color)
121147
keep = color_source_vector.isin(groups)
122148
filtered_csv = color_source_vector[keep]
123149
filtered_cv = np.asarray(color_vector)[keep]
@@ -301,17 +327,9 @@ def _render_shapes(
301327
# When groups are specified, filter out non-matching elements by default.
302328
# Only show non-matching elements if the user explicitly sets na_color.
303329
_na = render_params.cmap_params.na_color
304-
if groups is not None and values_are_categorical and (_na.default_color_set or _na.alpha == "00"):
305-
assert color_source_vector is not None # guaranteed by values_are_categorical
306-
_groups_list = [groups] if isinstance(groups, str) else groups
307-
_missing = set(_groups_list) - set(color_source_vector.categories)
308-
if _missing:
309-
logger.warning(
310-
f"Groups {sorted(_missing)} not found in the values of '{col_for_color}'. "
311-
"The `groups` parameter filters values of the `color` column."
312-
)
330+
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"):
313331
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
314-
groups, color_source_vector, color_vector
332+
groups, color_source_vector, color_vector, col_for_color=col_for_color
315333
)
316334
shapes = shapes[keep].reset_index(drop=True)
317335
if len(shapes) == 0:
@@ -762,15 +780,8 @@ def _render_points(
762780
# Only show non-matching elements if the user explicitly sets na_color.
763781
_na = render_params.cmap_params.na_color
764782
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"):
765-
_groups_list = [groups] if isinstance(groups, str) else groups
766-
_missing = set(_groups_list) - set(color_source_vector.categories)
767-
if _missing:
768-
logger.warning(
769-
f"Groups {sorted(_missing)} not found in the values of '{col_for_color}'. "
770-
"The `groups` parameter filters values of the `color` column."
771-
)
772783
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
773-
groups, color_source_vector, color_vector
784+
groups, color_source_vector, color_vector, col_for_color=col_for_color
774785
)
775786
n_points = int(keep.sum())
776787
if n_points == 0:
@@ -1322,6 +1333,7 @@ def _render_labels(
13221333
and color_source_vector is not None
13231334
and (_na.default_color_set or _na.alpha == "00")
13241335
):
1336+
_warn_missing_groups(groups, color_source_vector, col_for_color)
13251337
keep_vec = color_source_vector.isin(groups)
13261338
matching_ids = instance_id[keep_vec]
13271339
keep_mask = np.isin(label.values, matching_ids)

src/spatialdata_plot/pl/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def _create_patches(
554554
)
555555

556556
if patches.empty:
557-
return PatchCollection([], **kwargs)
557+
return PatchCollection([])
558558

559559
return PatchCollection(
560560
patches["geometry"].values.tolist(),

tests/pl/test_render_points.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from spatialdata.transformations._utils import _set_transformations
2323

2424
import spatialdata_plot # noqa: F401
25+
from spatialdata_plot._logging import logger, logger_warns
2526
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over, get_standard_RNG
2627

2728
sc.pl.set_rcParams_defaults()
@@ -607,6 +608,24 @@ def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData):
607608
).pl.show()
608609

609610

611+
def test_groups_warns_when_no_groups_match_points(sdata_blobs: SpatialData, caplog):
612+
"""When none of the groups match color categories, a warning should be emitted."""
613+
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
614+
with logger_warns(caplog, logger, match="None of the requested groups"):
615+
sdata_blobs.pl.render_points(
616+
"blobs_points", color="cat_color", groups=["nonexistent"], na_color=None, size=30
617+
).pl.show()
618+
619+
620+
def test_groups_warns_when_some_groups_missing_points(sdata_blobs: SpatialData, caplog):
621+
"""When some groups match but others don't, a warning should list the missing ones."""
622+
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
623+
with logger_warns(caplog, logger, match="were not found in column"):
624+
sdata_blobs.pl.render_points(
625+
"blobs_points", color="cat_color", groups=["a", "nonexistent"], na_color=None, size=30
626+
).pl.show()
627+
628+
610629
def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
611630
# Work on an independent copy since we mutate tables
612631
sdata_blobs_local = deepcopy(sdata_blobs)

tests/pl/test_render_shapes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,24 @@ def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData):
10161016
sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=None).pl.show()
10171017

10181018

1019+
def test_groups_warns_when_no_groups_match(sdata_blobs: SpatialData, caplog):
1020+
"""When none of the groups match color categories, a warning should be emitted."""
1021+
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
1022+
with logger_warns(caplog, logger, match="None of the requested groups"):
1023+
sdata_blobs.pl.render_shapes(
1024+
"blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=None
1025+
).pl.show()
1026+
1027+
1028+
def test_groups_warns_when_some_groups_missing(sdata_blobs: SpatialData, caplog):
1029+
"""When some groups match but others don't, a warning should list the missing ones."""
1030+
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
1031+
with logger_warns(caplog, logger, match="were not found in column"):
1032+
sdata_blobs.pl.render_shapes(
1033+
"blobs_polygons", color="cat_color", groups=["a", "nonexistent"], na_color=None
1034+
).pl.show()
1035+
1036+
10191037
def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData, caplog):
10201038
"""Test that NaN values in color data are handled gracefully and logged."""
10211039
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)

0 commit comments

Comments
 (0)