Skip to content

Commit db17a4f

Browse files
authored
Honor explicit table_name= when element has same-named column (#666)
1 parent 13c3cde commit db17a4f

2 files changed

Lines changed: 57 additions & 3 deletions

File tree

src/spatialdata_plot/pl/utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,6 +1078,13 @@ def _set_color_source_vec(
10781078
table_name=table_name,
10791079
)
10801080

1081+
# When both the element's own dataframe and the chosen table contain a
1082+
# column with this name, an explicit `table_name=` resolves the ambiguity —
1083+
# keep only the table origin and skip the multi-origin error below.
1084+
explicit_table_shadows_df = table_name is not None and any(o.origin == "df" for o in origins)
1085+
if explicit_table_shadows_df:
1086+
origins = [o for o in origins if o.origin != "df"]
1087+
10811088
if len(origins) > 1:
10821089
raise ValueError(
10831090
f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {origins}. "
@@ -1094,6 +1101,15 @@ def _set_color_source_vec(
10941101
)
10951102
if preloaded_color_data is not None:
10961103
color_source_vector = preloaded_color_data
1104+
elif explicit_table_shadows_df:
1105+
# Pass the table as `element` so upstream `get_values` skips the
1106+
# element-column lookup and avoids the multi-origin error.
1107+
color_source_vector = get_values(
1108+
value_key=value_to_plot,
1109+
element=sdata[table_name],
1110+
element_name=element_name,
1111+
table_layer=table_layer,
1112+
)[value_to_plot]
10971113
else:
10981114
color_source_vector = get_values(
10991115
value_key=value_to_plot,
@@ -3170,9 +3186,9 @@ def _validate_col_for_column_table(
31703186
if col_for_color is None:
31713187
return None, None
31723188

3173-
if not labels and col_for_color in sdata[element_name].columns:
3174-
table_name = None
3175-
elif table_name is not None:
3189+
if not labels and col_for_color in sdata[element_name].columns and table_name is None:
3190+
return col_for_color, None
3191+
if table_name is not None:
31763192
tables = get_element_annotators(sdata, element_name)
31773193
if table_name not in tables:
31783194
logger.warning(f"Table '{table_name}' does not annotate element '{element_name}'.")

tests/pl/test_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,41 @@ def test_color_column_collision_on_annotating_table_raises():
420420
sdata.pl.render_shapes("s", color="orange")
421421

422422
sdata.pl.render_shapes("s", color="#ffa500")
423+
424+
425+
def test_explicit_table_name_honored_when_element_has_same_column():
426+
# regression test for #620: explicit table_name= must not be silently
427+
# discarded when the element has a same-named column with different values.
428+
shapes = ShapesModel.parse(
429+
gpd.GeoDataFrame(
430+
{
431+
"geometry": [Point(5, 5), Point(15, 5)],
432+
"radius": [2.0, 2.0],
433+
"cat": pd.Categorical(["X", "Y"]),
434+
}
435+
)
436+
)
437+
obs = pd.DataFrame(
438+
{
439+
"instance_id": [0, 1],
440+
"region": pd.Categorical(["s1", "s1"]),
441+
"cat": pd.Categorical(["A", "B"]),
442+
}
443+
)
444+
table = TableModel.parse(
445+
AnnData(X=np.zeros((2, 1)), obs=obs),
446+
region=["s1"],
447+
region_key="region",
448+
instance_key="instance_id",
449+
)
450+
sdata = SpatialData(shapes={"s1": shapes}, tables={"t": table})
451+
452+
fig, ax = plt.subplots()
453+
sdata.pl.render_shapes("s1", color="cat", table_name="t").pl.show(ax=ax)
454+
assert sorted(t.get_text() for t in ax.get_legend().get_texts()) == ["A", "B"]
455+
plt.close(fig)
456+
457+
fig, ax = plt.subplots()
458+
sdata.pl.render_shapes("s1", color="cat").pl.show(ax=ax)
459+
assert sorted(t.get_text() for t in ax.get_legend().get_texts()) == ["X", "Y"]
460+
plt.close(fig)

0 commit comments

Comments
 (0)