Skip to content

Commit f4dd082

Browse files
authored
fix: surface gene_symbols column typo on auto-detect path (#658)
Signed-off-by: Sai Asish Y <say.apm35@gmail.com>
1 parent 53abe71 commit f4dd082

2 files changed

Lines changed: 18 additions & 0 deletions

File tree

src/spatialdata_plot/pl/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3161,6 +3161,13 @@ def _validate_col_for_column_table(
31613161
)
31623162
# Now check which tables contain the column
31633163
resolved_var_name: str | None = None
3164+
if gene_symbols is not None and not any(gene_symbols in sdata[t].var.columns for t in tables):
3165+
available = sorted({c for t in tables for c in sdata[t].var.columns})
3166+
raise KeyError(
3167+
f"Column '{gene_symbols}' specified in `gene_symbols=` was not found in "
3168+
f"`adata.var` of any table annotating element '{element_name}'. "
3169+
f"Available var columns: {available}"
3170+
)
31643171
for annotates in tables.copy():
31653172
if col_for_color not in sdata[annotates].obs.columns and col_for_color not in sdata[annotates].var_names:
31663173
if gene_symbols is not None:

tests/pl/test_render_shapes.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,17 @@ def test_gene_symbols_missing_column_raises(sdata_blobs: SpatialData):
10531053
).pl.show()
10541054

10551055

1056+
def test_gene_symbols_missing_column_raises_auto_detect(sdata_blobs: SpatialData):
1057+
"""Typo in gene_symbols= must surface on the auto-detect path, not be swallowed."""
1058+
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
1059+
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles"
1060+
sdata_blobs["table"].var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]
1061+
with pytest.raises(KeyError, match="`gene_symbols=`"):
1062+
sdata_blobs.pl.render_shapes(
1063+
"blobs_circles", color="GeneA", gene_symbols="WRONGCOL"
1064+
).pl.show()
1065+
1066+
10561067
def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData):
10571068
"""When no elements match the groups, the plot should render without error."""
10581069
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")

0 commit comments

Comments
 (0)