Skip to content

Commit acc4c0d

Browse files
committed
Honor explicit table_name= when element has same-named column (#620)
When `table_name=` was passed to a render function but the element already had a column with the same name, the explicit `table_name` was silently discarded and the element column was used instead. Fix `_validate_col_for_column_table` to only fall back to the element column when no explicit table was requested, and update `_set_color_source_vec` to drop the df origin (and bypass the upstream multi-origin check in `get_values`) when the user has explicitly chosen a table.
1 parent 13c3cde commit acc4c0d

2 files changed

Lines changed: 56 additions & 3 deletions

File tree

src/spatialdata_plot/pl/utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,6 +1078,12 @@ def _set_color_source_vec(
10781078
table_name=table_name,
10791079
)
10801080

1081+
# When `table_name=` is explicit, an element column with the same name is
1082+
# shadowed by that choice (#620); drop the df origin so the table wins.
1083+
explicit_table_shadows_df = table_name is not None and any(o.origin == "df" for o in origins)
1084+
if explicit_table_shadows_df:
1085+
origins = [o for o in origins if o.origin != "df"]
1086+
10811087
if len(origins) > 1:
10821088
raise ValueError(
10831089
f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {origins}. "
@@ -1094,6 +1100,15 @@ def _set_color_source_vec(
10941100
)
10951101
if preloaded_color_data is not None:
10961102
color_source_vector = preloaded_color_data
1103+
elif explicit_table_shadows_df:
1104+
# Pass the table as `element` so upstream `get_values` skips the
1105+
# element-column lookup and avoids the multi-origin error.
1106+
color_source_vector = get_values(
1107+
value_key=value_to_plot,
1108+
element=sdata[table_name],
1109+
element_name=element_name,
1110+
table_layer=table_layer,
1111+
)[value_to_plot]
10971112
else:
10981113
color_source_vector = get_values(
10991114
value_key=value_to_plot,
@@ -3170,9 +3185,9 @@ def _validate_col_for_column_table(
31703185
if col_for_color is None:
31713186
return None, None
31723187

3173-
if not labels and col_for_color in sdata[element_name].columns:
3174-
table_name = None
3175-
elif table_name is not None:
3188+
if not labels and col_for_color in sdata[element_name].columns and table_name is None:
3189+
return col_for_color, None
3190+
if table_name is not None:
31763191
tables = get_element_annotators(sdata, element_name)
31773192
if table_name not in tables:
31783193
logger.warning(f"Table '{table_name}' does not annotate element '{element_name}'.")

tests/pl/test_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,41 @@ def test_color_column_collision_on_annotating_table_raises():
420420
sdata.pl.render_shapes("s", color="orange")
421421

422422
sdata.pl.render_shapes("s", color="#ffa500")
423+
424+
425+
def test_explicit_table_name_honored_when_element_has_same_column():
426+
# regression test for #620: explicit table_name= must not be silently
427+
# discarded when the element has a same-named column with different values.
428+
shapes = ShapesModel.parse(
429+
gpd.GeoDataFrame(
430+
{
431+
"geometry": [Point(5, 5), Point(15, 5)],
432+
"radius": [2.0, 2.0],
433+
"cat": pd.Categorical(["X", "Y"]),
434+
}
435+
)
436+
)
437+
obs = pd.DataFrame(
438+
{
439+
"instance_id": [0, 1],
440+
"region": pd.Categorical(["s1", "s1"]),
441+
"cat": pd.Categorical(["A", "B"]),
442+
}
443+
)
444+
table = TableModel.parse(
445+
AnnData(X=np.zeros((2, 1)), obs=obs),
446+
region=["s1"],
447+
region_key="region",
448+
instance_key="instance_id",
449+
)
450+
sdata = SpatialData(shapes={"s1": shapes}, tables={"t": table})
451+
452+
fig, ax = plt.subplots()
453+
sdata.pl.render_shapes("s1", color="cat", table_name="t").pl.show(ax=ax)
454+
assert sorted(t.get_text() for t in ax.get_legend().get_texts()) == ["A", "B"]
455+
plt.close(fig)
456+
457+
fig, ax = plt.subplots()
458+
sdata.pl.render_shapes("s1", color="cat").pl.show(ax=ax)
459+
assert sorted(t.get_text() for t in ax.get_legend().get_texts()) == ["X", "Y"]
460+
plt.close(fig)

0 commit comments

Comments
 (0)