Skip to content

Commit ed78508

Browse files
timtreisclaude
andcommitted
Address adversarial review findings
- Drop **kwargs from render_graph() and _render_graph() — the deferred execution pattern can't forward arbitrary kwargs, matching how other render functions use kwargs.get() for named options only - Add adjacency matrix shape validation against table.n_obs to catch subset graphs stored in multi-region tables - Fix CS dispatch to use get_transformation() for element-level CS membership instead of type-level has_shapes/has_labels flags - Vectorize edge filtering: replace Python-loop set lookups with numpy boolean arrays (has_coord + np.isin for groups) - Remove dead line (dict overwritten immediately) - Use triu(k=1) to skip self-loops at the sparse level Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 855ed32 commit ed78508

2 files changed

Lines changed: 32 additions & 52 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
2626
from spatialdata import get_extent
2727
from spatialdata._utils import _deprecation_alias
28+
from spatialdata.transformations.operations import get_transformation
2829
from xarray import DataArray, DataTree
2930

3031
from spatialdata_plot._accessor import register_spatial_data_accessor
@@ -875,7 +876,6 @@ def render_graph(
875876
edge_width: float = 1.0,
876877
edge_alpha: float = 1.0,
877878
table_name: str | None = None,
878-
**kwargs: Any,
879879
) -> sd.SpatialData:
880880
"""Render spatial graph edges between observations.
881881
@@ -904,8 +904,6 @@ def render_graph(
904904
Transparency for edges (0 = invisible, 1 = opaque).
905905
table_name : str | None, optional
906906
Table containing the graph. Auto-discovered if not given.
907-
**kwargs
908-
Forwarded to :class:`matplotlib.collections.LineCollection`.
909907
910908
Returns
911909
-------
@@ -1392,13 +1390,9 @@ def _draw_colorbar(
13921390
)
13931391

13941392
elif cmd == "render_graph":
1395-
# Graph rendering: resolve which element the graph connects,
1396-
# check if that element exists in this CS.
13971393
graph_element = params_copy.element
1398-
element_in_cs = (
1399-
(graph_element in sdata.shapes and has_shapes)
1400-
or (graph_element in sdata.points and has_points)
1401-
or (graph_element in sdata.labels and has_labels)
1394+
element_in_cs = graph_element in sdata and cs in set(
1395+
get_transformation(sdata[graph_element], get_all=True).keys()
14021396
)
14031397
if element_in_cs:
14041398
_render_graph(

src/spatialdata_plot/pl/render.py

Lines changed: 29 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1823,7 +1823,6 @@ def _render_graph(
18231823
render_params: GraphRenderParams,
18241824
coordinate_system: str,
18251825
ax: matplotlib.axes.SubplotBase,
1826-
**kwargs: Any,
18271826
) -> None:
18281827
"""Render spatial graph edges as a LineCollection on the given axes."""
18291828
from matplotlib.collections import LineCollection
@@ -1836,18 +1835,18 @@ def _render_graph(
18361835
# Get table and adjacency matrix
18371836
table = sdata[table_name]
18381837
obsp_key = render_params.connectivity_key
1839-
# _validate_graph_render_params already resolved the actual obsp key,
1840-
# but we stored the prefix in render_params — re-resolve here
18411838
if obsp_key not in table.obsp:
1842-
suffixed = f"{obsp_key}_connectivities"
1843-
if suffixed in table.obsp:
1844-
obsp_key = suffixed
1845-
else:
1846-
logger.warning(f"Connectivity key '{obsp_key}' not found in table obsp. Skipping graph rendering.")
1847-
return
1839+
logger.warning(f"Connectivity key '{obsp_key}' not found in table obsp. Skipping graph rendering.")
1840+
return
18481841

18491842
adjacency = table.obsp[obsp_key]
18501843

1844+
if adjacency.shape[0] != table.n_obs:
1845+
raise ValueError(
1846+
f"Adjacency matrix shape {adjacency.shape} does not match table.n_obs ({table.n_obs}). "
1847+
"The graph must be computed on the full table, not a subset."
1848+
)
1849+
18511850
# Get the spatial element
18521851
if element_name in sdata.shapes:
18531852
element = sdata.shapes[element_name]
@@ -1866,67 +1865,54 @@ def _render_graph(
18661865

18671866
centroid_coords = np.column_stack([centroids_df["x"].values, centroids_df["y"].values])
18681867

1869-
# Align table observations to centroid positions
1870-
# The table's instance_key maps obs rows to spatial element instances.
1871-
# Centroids are ordered by element instance (e.g., label ID or GeoDataFrame index).
1868+
# Align table observations to centroid positions via instance_key.
1869+
# Build a coordinate array indexed by full-table row so edge lookups are O(1).
18721870
_, region_key, instance_key = get_table_keys(table)
18731871

1874-
# Filter table to only rows annotating this element
18751872
element_mask = table.obs[region_key] == element_name if region_key is not None else np.ones(table.n_obs, dtype=bool)
1876-
table_subset_indices = np.where(element_mask)[0]
18771873
instance_ids = table.obs[instance_key].values[element_mask]
1874+
table_subset_indices = np.where(element_mask)[0]
18781875

1879-
# Build mapping from instance_id to centroid row index
1880-
# For shapes/points, centroids follow the GeoDataFrame/DataFrame index order.
1881-
# For labels, centroids follow unique label IDs (excluding background).
18821876
centroid_ids = centroids_df.index.values if hasattr(centroids_df, "index") else np.arange(len(centroids_df))
1877+
id_to_centroid_row = {cid: row for row, cid in enumerate(centroid_ids)}
18831878

1884-
id_to_centroid_row = {}
1885-
for row, cid in enumerate(centroid_ids):
1886-
id_to_centroid_row[cid] = row
1887-
1888-
# Map each table obs (that annotates this element) to a centroid coordinate
1889-
obs_to_coord = {}
1879+
# has_coord[i] is True if table row i has a valid centroid
1880+
has_coord = np.zeros(table.n_obs, dtype=bool)
1881+
coord_lookup = np.full((table.n_obs, 2), np.nan)
18901882
for table_row, iid in zip(table_subset_indices, instance_ids, strict=True):
18911883
if iid in id_to_centroid_row:
1892-
obs_to_coord[table_row] = centroid_coords[id_to_centroid_row[iid]]
1884+
has_coord[table_row] = True
1885+
coord_lookup[table_row] = centroid_coords[id_to_centroid_row[iid]]
18931886

1894-
# Apply group filtering
1887+
# Apply group filtering: narrow has_coord to only rows in requested groups
18951888
groups = render_params.groups
18961889
group_key = render_params.group_key
18971890
if groups is not None and group_key is not None:
18981891
group_values = table.obs[group_key].values
1899-
group_set = set(groups)
1900-
obs_in_groups = {idx for idx in obs_to_coord if group_values[idx] in group_set}
1901-
else:
1902-
obs_in_groups = set(obs_to_coord.keys())
1892+
in_groups = np.isin(group_values, groups)
1893+
has_coord &= in_groups
19031894

1904-
# Extract edges from upper triangle (undirected graph — draw each edge once)
1905-
adj_upper = triu(adjacency, k=0)
1895+
# Extract edges from upper triangle (undirected — draw each edge once, skip self-loops)
1896+
adj_upper = triu(adjacency, k=1)
19061897
rows, cols = adj_upper.nonzero()
19071898

1908-
# Build line segments for edges where both endpoints are valid
1909-
segments = []
1910-
for r, c in zip(rows, cols, strict=True):
1911-
if r == c:
1912-
continue # skip self-loops
1913-
if r in obs_in_groups and c in obs_in_groups and r in obs_to_coord and c in obs_to_coord:
1914-
segments.append([obs_to_coord[r], obs_to_coord[c]])
1915-
1916-
if not segments:
1899+
# Vectorized filter: keep edges where both endpoints are valid
1900+
edge_mask = has_coord[rows] & has_coord[cols]
1901+
if not edge_mask.any():
19171902
return
19181903

1919-
segments_arr = np.array(segments)
1904+
valid_rows = rows[edge_mask]
1905+
valid_cols = cols[edge_mask]
1906+
segments = np.stack([coord_lookup[valid_rows], coord_lookup[valid_cols]], axis=1)
19201907

19211908
edge_color = render_params.color.get_hex() if render_params.color is not None else "#808080"
19221909

19231910
lc = LineCollection(
1924-
segments_arr,
1911+
segments,
19251912
linewidths=render_params.edge_width,
19261913
colors=edge_color,
19271914
alpha=render_params.edge_alpha,
19281915
zorder=render_params.zorder,
1929-
**kwargs,
19301916
)
19311917
lc.set_rasterized(True)
19321918
ax.add_collection(lc)

0 commit comments

Comments
 (0)