@@ -1823,7 +1823,6 @@ def _render_graph(
18231823 render_params : GraphRenderParams ,
18241824 coordinate_system : str ,
18251825 ax : matplotlib .axes .SubplotBase ,
1826- ** kwargs : Any ,
18271826) -> None :
18281827 """Render spatial graph edges as a LineCollection on the given axes."""
18291828 from matplotlib .collections import LineCollection
@@ -1836,18 +1835,18 @@ def _render_graph(
18361835 # Get table and adjacency matrix
18371836 table = sdata [table_name ]
18381837 obsp_key = render_params .connectivity_key
1839- # _validate_graph_render_params already resolved the actual obsp key,
1840- # but we stored the prefix in render_params — re-resolve here
18411838 if obsp_key not in table .obsp :
1842- suffixed = f"{ obsp_key } _connectivities"
1843- if suffixed in table .obsp :
1844- obsp_key = suffixed
1845- else :
1846- logger .warning (f"Connectivity key '{ obsp_key } ' not found in table obsp. Skipping graph rendering." )
1847- return
1839+ logger .warning (f"Connectivity key '{ obsp_key } ' not found in table obsp. Skipping graph rendering." )
1840+ return
18481841
18491842 adjacency = table .obsp [obsp_key ]
18501843
1844+ if adjacency .shape [0 ] != table .n_obs :
1845+ raise ValueError (
1846+ f"Adjacency matrix shape { adjacency .shape } does not match table.n_obs ({ table .n_obs } ). "
1847+ "The graph must be computed on the full table, not a subset."
1848+ )
1849+
18511850 # Get the spatial element
18521851 if element_name in sdata .shapes :
18531852 element = sdata .shapes [element_name ]
@@ -1866,67 +1865,54 @@ def _render_graph(
18661865
18671866 centroid_coords = np .column_stack ([centroids_df ["x" ].values , centroids_df ["y" ].values ])
18681867
1869- # Align table observations to centroid positions
1870- # The table's instance_key maps obs rows to spatial element instances.
1871- # Centroids are ordered by element instance (e.g., label ID or GeoDataFrame index).
1868+ # Align table observations to centroid positions via instance_key.
1869+ # Build a coordinate array indexed by full-table row so edge lookups are O(1).
18721870 _ , region_key , instance_key = get_table_keys (table )
18731871
1874- # Filter table to only rows annotating this element
18751872 element_mask = table .obs [region_key ] == element_name if region_key is not None else np .ones (table .n_obs , dtype = bool )
1876- table_subset_indices = np .where (element_mask )[0 ]
18771873 instance_ids = table .obs [instance_key ].values [element_mask ]
1874+ table_subset_indices = np .where (element_mask )[0 ]
18781875
1879- # Build mapping from instance_id to centroid row index
1880- # For shapes/points, centroids follow the GeoDataFrame/DataFrame index order.
1881- # For labels, centroids follow unique label IDs (excluding background).
18821876 centroid_ids = centroids_df .index .values if hasattr (centroids_df , "index" ) else np .arange (len (centroids_df ))
1877+ id_to_centroid_row = {cid : row for row , cid in enumerate (centroid_ids )}
18831878
1884- id_to_centroid_row = {}
1885- for row , cid in enumerate (centroid_ids ):
1886- id_to_centroid_row [cid ] = row
1887-
1888- # Map each table obs (that annotates this element) to a centroid coordinate
1889- obs_to_coord = {}
1879+ # has_coord[i] is True if table row i has a valid centroid
1880+ has_coord = np .zeros (table .n_obs , dtype = bool )
1881+ coord_lookup = np .full ((table .n_obs , 2 ), np .nan )
18901882 for table_row , iid in zip (table_subset_indices , instance_ids , strict = True ):
18911883 if iid in id_to_centroid_row :
1892- obs_to_coord [table_row ] = centroid_coords [id_to_centroid_row [iid ]]
1884+ has_coord [table_row ] = True
1885+ coord_lookup [table_row ] = centroid_coords [id_to_centroid_row [iid ]]
18931886
1894- # Apply group filtering
1887+ # Apply group filtering: narrow has_coord to only rows in requested groups
18951888 groups = render_params .groups
18961889 group_key = render_params .group_key
18971890 if groups is not None and group_key is not None :
18981891 group_values = table .obs [group_key ].values
1899- group_set = set (groups )
1900- obs_in_groups = {idx for idx in obs_to_coord if group_values [idx ] in group_set }
1901- else :
1902- obs_in_groups = set (obs_to_coord .keys ())
1892+ in_groups = np .isin (group_values , groups )
1893+ has_coord &= in_groups
19031894
1904- # Extract edges from upper triangle (undirected graph — draw each edge once)
1905- adj_upper = triu (adjacency , k = 0 )
1895+ # Extract edges from upper triangle (undirected — draw each edge once, skip self-loops )
1896+ adj_upper = triu (adjacency , k = 1 )
19061897 rows , cols = adj_upper .nonzero ()
19071898
1908- # Build line segments for edges where both endpoints are valid
1909- segments = []
1910- for r , c in zip (rows , cols , strict = True ):
1911- if r == c :
1912- continue # skip self-loops
1913- if r in obs_in_groups and c in obs_in_groups and r in obs_to_coord and c in obs_to_coord :
1914- segments .append ([obs_to_coord [r ], obs_to_coord [c ]])
1915-
1916- if not segments :
1899+ # Vectorized filter: keep edges where both endpoints are valid
1900+ edge_mask = has_coord [rows ] & has_coord [cols ]
1901+ if not edge_mask .any ():
19171902 return
19181903
1919- segments_arr = np .array (segments )
1904+ valid_rows = rows [edge_mask ]
1905+ valid_cols = cols [edge_mask ]
1906+ segments = np .stack ([coord_lookup [valid_rows ], coord_lookup [valid_cols ]], axis = 1 )
19201907
19211908 edge_color = render_params .color .get_hex () if render_params .color is not None else "#808080"
19221909
19231910 lc = LineCollection (
1924- segments_arr ,
1911+ segments ,
19251912 linewidths = render_params .edge_width ,
19261913 colors = edge_color ,
19271914 alpha = render_params .edge_alpha ,
19281915 zorder = render_params .zorder ,
1929- ** kwargs ,
19301916 )
19311917 lc .set_rasterized (True )
19321918 ax .add_collection (lc )
0 commit comments