Skip to content

Commit 75cadce

Browse files
authored
Raise on obs/var key collision in table-based coloring (#678)
1 parent 3ebefe1 commit 75cadce

3 files changed

Lines changed: 55 additions & 0 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
)
5959
from spatialdata_plot.pl.utils import (
6060
_ax_show_and_transform,
61+
_check_obs_var_shadow,
6162
_convert_shapes,
6263
_datashader_canvas_from_dataframe,
6364
_decorate_axs,
@@ -370,6 +371,8 @@ def _render_shapes(
370371
groups = render_params.groups
371372
table_layer = render_params.table_layer
372373

374+
_check_obs_var_shadow(sdata, element, col_for_color, render_params.table_name)
375+
373376
sdata_filt = sdata.filter_by_coordinate_system(
374377
coordinate_system=coordinate_system,
375378
filter_tables=bool(render_params.table_name),
@@ -766,6 +769,8 @@ def _render_points(
766769
groups = render_params.groups
767770
palette = render_params.palette
768771

772+
_check_obs_var_shadow(sdata, element, col_for_color, table_name)
773+
769774
if isinstance(groups, str):
770775
groups = [groups]
771776

@@ -1687,6 +1692,8 @@ def _render_labels(
16871692
groups = render_params.groups
16881693
scale = render_params.scale
16891694

1695+
_check_obs_var_shadow(sdata, element, col_for_color, table_name)
1696+
16901697
sdata_filt = sdata.filter_by_coordinate_system(
16911698
coordinate_system=coordinate_system,
16921699
filter_tables=bool(table_name),

src/spatialdata_plot/pl/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,38 @@
104104
}
105105

106106

107+
def _check_obs_var_shadow(
108+
sdata: SpatialData | None,
109+
element_name: str | None,
110+
value_to_plot: str | None,
111+
table_name: str | None,
112+
) -> None:
113+
"""Raise if ``value_to_plot`` exists in both ``table.obs.columns`` and ``table.var_names``.
114+
115+
Upstream ``_get_table_origins`` uses an ``elif`` chain, so a key that lives in
116+
both locations is silently resolved to ``obs`` — masking the user's likely
117+
intent of plotting gene expression. Catch this here before any value fetch.
118+
Any ``None`` parameter short-circuits the check.
119+
"""
120+
if (
121+
value_to_plot is None
122+
or table_name is None
123+
or element_name is None
124+
or sdata is None
125+
or table_name not in sdata.tables
126+
):
127+
return
128+
if table_name not in get_element_annotators(sdata, element_name):
129+
return
130+
table = sdata.tables[table_name]
131+
if value_to_plot in table.obs.columns and value_to_plot in table.var_names:
132+
raise ValueError(
133+
f"`color={value_to_plot!r}` is ambiguous: it exists in both "
134+
f"`table[{table_name!r}].obs.columns` and `table[{table_name!r}].var_names`. "
135+
"Rename one of them (or drop the obs column) so the intended source is unambiguous."
136+
)
137+
138+
107139
def _gate_palette_and_groups(
108140
element_params: dict[str, Any],
109141
param_dict: dict[str, Any],

tests/pl/test_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,22 @@ def test_color_column_collision_on_annotating_table_raises():
422422
sdata.pl.render_shapes("s", color="#ffa500")
423423

424424

425+
def test_color_key_obs_var_shadow_raises():
426+
# regression test for #621
427+
pts = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0], "y": [1.0, 2.0]}))
428+
obs = pd.DataFrame({"instance_id": [0, 1], "region": ["pts"] * 2, "GeneA": [0.9, 0.6]}, index=["0", "1"])
429+
table = TableModel.parse(
430+
AnnData(X=np.zeros((2, 1)), obs=obs, var=pd.DataFrame(index=["GeneA"])),
431+
region=["pts"],
432+
region_key="region",
433+
instance_key="instance_id",
434+
)
435+
sdata = SpatialData(points={"pts": pts}, tables={"t": table})
436+
437+
with pytest.raises(ValueError, match=r"'GeneA'.*ambiguous.*obs\.columns.*var_names"):
438+
sdata.pl.render_points("pts", color="GeneA", table_name="t").pl.show()
439+
440+
425441
def test_explicit_table_name_honored_when_element_has_same_column():
426442
# regression test for #620: explicit table_name= must not be silently
427443
# discarded when the element has a same-named column with different values.

0 commit comments

Comments
 (0)