Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@
)


def _reset_index_preserving_existing_columns(df: pd.DataFrame) -> pd.DataFrame:
"""Reset the index of a DataFrame, dropping the index if its name already exists as a column.

Avoids ``ValueError: cannot insert <name>, already exists`` when the index
name collides with an existing column (e.g. "EntityID" in Merfish data).
"""
if df.index.name is not None and df.index.name in df.columns:
return df.reset_index(drop=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So here we are dropping the index instead of a column. This could be a problem if the column is different from the index.

What's the origin behind this problem? Is it that for the merfish data we should have had sdata['table'].index.name = None? That would be a better fix, while still raising an exception here.

return df.reset_index()


def get_element_annotators(sdata: SpatialData, element_name: str) -> set[str]:
"""
Retrieve names of tables that annotate a SpatialElement in a SpatialData object.
Expand Down Expand Up @@ -388,7 +399,7 @@ def _inner_join_spatialelement_table(
regions, region_column_name, instance_key = get_table_keys(table)
if isinstance(regions, str):
regions = [regions]
obs = table.obs.reset_index()
obs = _reset_index_preserving_existing_columns(table.obs)
groups_df = obs.groupby(by=region_column_name, observed=False)
joined_indices = None
for element_type, name_element in element_dict.items():
Expand Down Expand Up @@ -469,7 +480,7 @@ def _left_join_spatialelement_table(
regions, region_column_name, instance_key = get_table_keys(table)
if isinstance(regions, str):
regions = [regions]
obs = table.obs.reset_index()
obs = _reset_index_preserving_existing_columns(table.obs)
groups_df = obs.groupby(by=region_column_name, observed=False)
joined_indices = None
for element_type, name_element in element_dict.items():
Expand Down
29 changes: 28 additions & 1 deletion tests/core/query/test_relational_query.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import annsel as an
import geopandas as gpd
import numpy as np
import pandas as pd
import pytest
import shapely
from anndata import AnnData

from spatialdata import SpatialData, get_values, match_table_to_element
Expand All @@ -14,7 +16,7 @@
get_element_annotators,
join_spatialelement_table,
)
from spatialdata.models.models import TableModel
from spatialdata.models.models import ShapesModel, TableModel
from spatialdata.testing import assert_anndata_equal, assert_geodataframe_equal


Expand Down Expand Up @@ -1262,3 +1264,28 @@ def test_filter_by_table_query_complex_combination(complex_sdata):
assert ("circles", idx) in table_instance_ids
for idx in result["poly"].index:
assert ("poly", idx) in table_instance_ids


@pytest.mark.parametrize("how", ["inner", "left"])
def test_join_spatialelement_table_obs_index_name_collision(how):
    """Check that joining works when the obs index is named like an existing obs column.

    The table is built with an integer "EntityID" instance column, and afterwards its obs
    index is replaced by the string form of the same ids, named "EntityID" as well.

    NOTE(review): a maintainer remarked that the join should in general fail in this
    situation and only succeed when the index and the column hold the same values; the
    data used here matches up to the string conversion.

    Regression test for https://github.com/scverse/spatialdata/issues/1099.
    """
    n_cells = 5

    # One circle of radius 1 per cell, placed on the diagonal.
    points = [shapely.Point(i, i) for i in range(n_cells)]
    circles = ShapesModel.parse(gpd.GeoDataFrame({"geometry": points, "radius": np.ones(n_cells)}))

    obs_columns = {
        "region": pd.Categorical(["shapes"] * n_cells),
        "EntityID": np.arange(n_cells),
        "cell_type": list("AABBC"),
    }
    annotating_table = TableModel.parse(
        AnnData(obs=pd.DataFrame(obs_columns)),
        region="shapes",
        region_key="region",
        instance_key="EntityID",
    )
    sdata = SpatialData(shapes={"shapes": circles}, tables={"table": annotating_table})

    # Create the collision only after parsing: the index name now equals a column name.
    colliding_index = pd.Index(list(map(str, range(n_cells))), name="EntityID")
    sdata["table"].obs.index = colliding_index

    element_dict, joined_table = join_spatialelement_table(
        sdata=sdata,
        spatial_element_names="shapes",
        table_name="table",
        how=how,
    )
    assert joined_table.n_obs == n_cells
    assert "shapes" in element_dict