Skip to content

Commit f15f0e0

Browse files
timtreisclaude
andcommitted
Simplify: decouple warning from filter, fix edge cases
- Move _warn_missing_groups call before the na_color guard so the warning fires regardless of na_color (not just when na_color is transparent). - Remove col_for_color param from _filter_groups_transparent_na (warning is no longer its responsibility). - Fix fallback column label: "in column." → "in the color column." - Guard sorted() with try/except for non-sortable category types. - Deduplicate set(groups_list) into a single groups_set variable. - Parametrize tests over na_color=[None, "red"] to cover both paths. - Add labels warning test. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2f4af2c commit f15f0e0

4 files changed

Lines changed: 62 additions & 26 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,36 +114,41 @@ def _warn_missing_groups(
114114
col_for_color: str | None = None,
115115
) -> None:
116116
"""Warn when ``groups`` contains values absent from the color column's categories."""
117-
groups_list = [groups] if isinstance(groups, str) else list(groups)
118-
missing = set(groups_list) - set(color_source_vector.categories)
117+
groups_set = {groups} if isinstance(groups, str) else set(groups)
118+
missing = groups_set - set(color_source_vector.categories)
119119
if not missing:
120120
return
121-
col_label = f" '{col_for_color}'" if col_for_color else ""
122-
if missing == set(groups_list):
121+
col_label = f" '{col_for_color}'" if col_for_color else " the color column"
122+
try:
123+
missing_str = str(sorted(missing))
124+
except TypeError:
125+
missing_str = str(list(missing))
126+
if missing == groups_set:
123127
logger.warning(
124-
f"None of the requested groups {sorted(missing)} were found in column{col_label}. "
128+
f"None of the requested groups {missing_str} were found in{col_label}. "
125129
"This usually means `groups` refers to values from a different column than `color`. "
126130
"The `groups` parameter selects categories of the column specified via `color`."
127131
)
128132
else:
133+
try:
134+
cats_str = str(sorted(color_source_vector.categories))
135+
except TypeError:
136+
cats_str = str(list(color_source_vector.categories))
129137
logger.warning(
130-
f"Groups {sorted(missing)} were not found in column{col_label} and will be ignored. "
131-
f"Available categories: {sorted(color_source_vector.categories)}."
138+
f"Groups {missing_str} were not found in{col_label} and will be ignored. Available categories: {cats_str}."
132139
)
133140

134141

135142
def _filter_groups_transparent_na(
136143
groups: str | list[str],
137144
color_source_vector: pd.Categorical,
138145
color_vector: pd.Series | np.ndarray | list[str],
139-
col_for_color: str | None = None,
140146
) -> tuple[np.ndarray, pd.Categorical, np.ndarray]:
141147
"""Return a boolean mask and filtered color vectors for groups filtering.
142148
143149
Used when ``na_color=None`` (fully transparent) so that non-matching
144150
elements are removed entirely instead of rendered invisibly.
145151
"""
146-
_warn_missing_groups(groups, color_source_vector, col_for_color)
147152
keep = color_source_vector.isin(groups)
148153
filtered_csv = color_source_vector[keep]
149154
filtered_cv = np.asarray(color_vector)[keep]
@@ -324,12 +329,15 @@ def _render_shapes(
324329

325330
values_are_categorical = color_source_vector is not None
326331

332+
if groups is not None and color_source_vector is not None:
333+
_warn_missing_groups(groups, color_source_vector, col_for_color)
334+
327335
# When groups are specified, filter out non-matching elements by default.
328336
# Only show non-matching elements if the user explicitly sets na_color.
329337
_na = render_params.cmap_params.na_color
330338
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"):
331339
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
332-
groups, color_source_vector, color_vector, col_for_color=col_for_color
340+
groups, color_source_vector, color_vector
333341
)
334342
shapes = shapes[keep].reset_index(drop=True)
335343
if len(shapes) == 0:
@@ -776,12 +784,15 @@ def _render_points(
776784
if added_color_from_table and col_for_color is not None:
777785
_reparse_points(sdata_filt, element, points_pd_with_color, transformation_in_cs, coordinate_system)
778786

787+
if groups is not None and color_source_vector is not None:
788+
_warn_missing_groups(groups, color_source_vector, col_for_color)
789+
779790
# When groups are specified, filter out non-matching elements by default.
780791
# Only show non-matching elements if the user explicitly sets na_color.
781792
_na = render_params.cmap_params.na_color
782793
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"):
783794
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
784-
groups, color_source_vector, color_vector, col_for_color=col_for_color
795+
groups, color_source_vector, color_vector
785796
)
786797
n_points = int(keep.sum())
787798
if n_points == 0:
@@ -1324,6 +1335,9 @@ def _render_labels(
13241335
else:
13251336
assert color_source_vector is None
13261337

1338+
if groups is not None and color_source_vector is not None:
1339+
_warn_missing_groups(groups, color_source_vector, col_for_color)
1340+
13271341
# When groups are specified, zero out non-matching label IDs so they render as background.
13281342
# Only show non-matching labels if the user explicitly sets na_color.
13291343
_na = render_params.cmap_params.na_color
@@ -1333,7 +1347,6 @@ def _render_labels(
13331347
and color_source_vector is not None
13341348
and (_na.default_color_set or _na.alpha == "00")
13351349
):
1336-
_warn_missing_groups(groups, color_source_vector, col_for_color)
13371350
keep_vec = color_source_vector.isin(groups)
13381351
matching_ids = instance_id[keep_vec]
13391352
keep_mask = np.isin(label.values, matching_ids)

tests/pl/test_render_labels.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from spatialdata.models import Labels2DModel, TableModel
1313

1414
import spatialdata_plot # noqa: F401
15+
from spatialdata_plot._logging import logger, logger_warns
1516
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over, get_standard_RNG
1617

1718
sc.pl.set_rcParams_defaults()
@@ -428,3 +429,21 @@ def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
428429
color="channel_0_sum",
429430
table_name="other_table",
430431
).pl.show()
432+
433+
434+
def test_groups_warns_when_no_groups_match_labels(sdata_blobs: SpatialData, caplog):
435+
"""Warning fires when no groups match label color categories."""
436+
labels_name = "blobs_labels"
437+
instances = get_element_instances(sdata_blobs[labels_name])
438+
n_obs = len(instances)
439+
adata = AnnData(np.zeros((n_obs, 1)))
440+
adata.obs["instance_id"] = instances.values
441+
adata.obs["cat"] = pd.Categorical(["a", "b"] * (n_obs // 2) + ["a"] * (n_obs % 2))
442+
adata.obs["region"] = labels_name
443+
sdata_blobs["label_table"] = TableModel.parse(
444+
adata=adata, region_key="region", instance_key="instance_id", region=labels_name
445+
)
446+
with logger_warns(caplog, logger, match="None of the requested groups"):
447+
sdata_blobs.pl.render_labels(
448+
labels_name, color="cat", groups=["nonexistent"], table_name="label_table", na_color=None
449+
).pl.show()

tests/pl/test_render_points.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -608,21 +608,23 @@ def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData):
608608
).pl.show()
609609

610610

611-
def test_groups_warns_when_no_groups_match_points(sdata_blobs: SpatialData, caplog):
612-
"""When none of the groups match color categories, a warning should be emitted."""
611+
@pytest.mark.parametrize("na_color", [None, "red"])
612+
def test_groups_warns_when_no_groups_match_points(sdata_blobs: SpatialData, caplog, na_color):
613+
"""Warning fires regardless of na_color when no groups match."""
613614
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
614615
with logger_warns(caplog, logger, match="None of the requested groups"):
615616
sdata_blobs.pl.render_points(
616-
"blobs_points", color="cat_color", groups=["nonexistent"], na_color=None, size=30
617+
"blobs_points", color="cat_color", groups=["nonexistent"], na_color=na_color, size=30
617618
).pl.show()
618619

619620

620-
def test_groups_warns_when_some_groups_missing_points(sdata_blobs: SpatialData, caplog):
621-
"""When some groups match but others don't, a warning should list the missing ones."""
621+
@pytest.mark.parametrize("na_color", [None, "red"])
622+
def test_groups_warns_when_some_groups_missing_points(sdata_blobs: SpatialData, caplog, na_color):
623+
"""Warning fires regardless of na_color when some groups are missing."""
622624
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
623-
with logger_warns(caplog, logger, match="were not found in column"):
625+
with logger_warns(caplog, logger, match="were not found in"):
624626
sdata_blobs.pl.render_points(
625-
"blobs_points", color="cat_color", groups=["a", "nonexistent"], na_color=None, size=30
627+
"blobs_points", color="cat_color", groups=["a", "nonexistent"], na_color=na_color, size=30
626628
).pl.show()
627629

628630

tests/pl/test_render_shapes.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,21 +1016,23 @@ def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData):
10161016
sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=None).pl.show()
10171017

10181018

1019-
def test_groups_warns_when_no_groups_match(sdata_blobs: SpatialData, caplog):
1020-
"""When none of the groups match color categories, a warning should be emitted."""
1019+
@pytest.mark.parametrize("na_color", [None, "red"])
1020+
def test_groups_warns_when_no_groups_match(sdata_blobs: SpatialData, caplog, na_color):
1021+
"""Warning fires regardless of na_color when no groups match."""
10211022
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
10221023
with logger_warns(caplog, logger, match="None of the requested groups"):
10231024
sdata_blobs.pl.render_shapes(
1024-
"blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=None
1025+
"blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=na_color
10251026
).pl.show()
10261027

10271028

1028-
def test_groups_warns_when_some_groups_missing(sdata_blobs: SpatialData, caplog):
1029-
"""When some groups match but others don't, a warning should list the missing ones."""
1029+
@pytest.mark.parametrize("na_color", [None, "red"])
1030+
def test_groups_warns_when_some_groups_missing(sdata_blobs: SpatialData, caplog, na_color):
1031+
"""Warning fires regardless of na_color when some groups are missing."""
10301032
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
1031-
with logger_warns(caplog, logger, match="were not found in column"):
1033+
with logger_warns(caplog, logger, match="were not found in"):
10321034
sdata_blobs.pl.render_shapes(
1033-
"blobs_polygons", color="cat_color", groups=["a", "nonexistent"], na_color=None
1035+
"blobs_polygons", color="cat_color", groups=["a", "nonexistent"], na_color=na_color
10341036
).pl.show()
10351037

10361038

0 commit comments

Comments
 (0)