Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
)
from spatialdata_plot.pl.utils import (
_ax_show_and_transform,
_check_obs_var_shadow,
_convert_shapes,
_datashader_canvas_from_dataframe,
_decorate_axs,
Expand Down Expand Up @@ -370,6 +371,8 @@ def _render_shapes(
groups = render_params.groups
table_layer = render_params.table_layer

_check_obs_var_shadow(sdata, element, col_for_color, render_params.table_name)

sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
filter_tables=bool(render_params.table_name),
Expand Down Expand Up @@ -766,6 +769,8 @@ def _render_points(
groups = render_params.groups
palette = render_params.palette

_check_obs_var_shadow(sdata, element, col_for_color, table_name)

if isinstance(groups, str):
groups = [groups]

Expand Down Expand Up @@ -1687,6 +1692,8 @@ def _render_labels(
groups = render_params.groups
scale = render_params.scale

_check_obs_var_shadow(sdata, element, col_for_color, table_name)

sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
filter_tables=bool(table_name),
Expand Down
32 changes: 32 additions & 0 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,38 @@
}


def _check_obs_var_shadow(
sdata: SpatialData | None,
element_name: str | None,
value_to_plot: str | None,
table_name: str | None,
) -> None:
"""Raise if ``value_to_plot`` exists in both ``table.obs.columns`` and ``table.var_names``.

Upstream ``_get_table_origins`` uses an ``elif`` chain, so a key that lives in
both locations is silently resolved to ``obs`` — masking the user's likely
intent of plotting gene expression. Catch this here before any value fetch.
Any ``None`` parameter short-circuits the check.
"""
if (
value_to_plot is None
or table_name is None
or element_name is None
or sdata is None
or table_name not in sdata.tables
):
return
if table_name not in get_element_annotators(sdata, element_name):
return
table = sdata.tables[table_name]
if value_to_plot in table.obs.columns and value_to_plot in table.var_names:
raise ValueError(
f"`color={value_to_plot!r}` is ambiguous: it exists in both "
f"`table[{table_name!r}].obs.columns` and `table[{table_name!r}].var_names`. "
"Rename one of them (or drop the obs column) so the intended source is unambiguous."
)


def _gate_palette_and_groups(
element_params: dict[str, Any],
param_dict: dict[str, Any],
Expand Down
16 changes: 16 additions & 0 deletions tests/pl/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,22 @@ def test_color_column_collision_on_annotating_table_raises():
sdata.pl.render_shapes("s", color="#ffa500")


def test_color_key_obs_var_shadow_raises():
# regression test for #621
pts = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0], "y": [1.0, 2.0]}))
obs = pd.DataFrame({"instance_id": [0, 1], "region": ["pts"] * 2, "GeneA": [0.9, 0.6]}, index=["0", "1"])
table = TableModel.parse(
AnnData(X=np.zeros((2, 1)), obs=obs, var=pd.DataFrame(index=["GeneA"])),
region=["pts"],
region_key="region",
instance_key="instance_id",
)
sdata = SpatialData(points={"pts": pts}, tables={"t": table})

with pytest.raises(ValueError, match=r"'GeneA'.*ambiguous.*obs\.columns.*var_names"):
sdata.pl.render_points("pts", color="GeneA", table_name="t").pl.show()


def test_explicit_table_name_honored_when_element_has_same_column():
# regression test for #620: explicit table_name= must not be silently
# discarded when the element has a same-named column with different values.
Expand Down
Loading