Skip to content

Commit 33e1d4c

Browse files
committed
Add render_graph for spatial connectivity visualization
Adds sdata.pl.render_graph() to draw spatial graph edges from adjacency matrices stored in table.obsp, using element centroids as node positions. Supports shapes, points, and labels. Features: - Scalar / obs-categorical / obs-continuous / obsp-matrix edge coloring - Per-edge widths and alphas (scalar or from a weight matrix) - Group filtering (both endpoints must be in the requested groups) - Self-loop rendering as a CircleCollection - Integrated colorbar via the existing ColorbarSpec pipeline - Legend sharing with chained render_shapes / render_points - No new dependencies (no networkx)
1 parent 84e5da7 commit 33e1d4c

7 files changed

Lines changed: 949 additions & 26 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 151 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sys
55
import warnings
66
from collections import OrderedDict
7-
from collections.abc import Callable
7+
from collections.abc import Callable, Sequence
88
from copy import deepcopy
99
from pathlib import Path
1010
from typing import Any, Literal, cast
@@ -25,12 +25,14 @@
2525
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
2626
from spatialdata import get_extent
2727
from spatialdata._utils import _deprecation_alias
28+
from spatialdata.transformations.operations import get_transformation
2829
from xarray import DataArray, DataTree
2930

3031
from spatialdata_plot._accessor import register_spatial_data_accessor
3132
from spatialdata_plot._logging import _log_context
3233
from spatialdata_plot.pl.render import (
3334
_draw_channel_legend,
35+
_render_graph,
3436
_render_images,
3537
_render_labels,
3638
_render_points,
@@ -44,6 +46,7 @@
4446
ChannelLegendEntry,
4547
CmapParams,
4648
ColorbarSpec,
49+
GraphRenderParams,
4750
ImageRenderParams,
4851
LabelsRenderParams,
4952
LegendParams,
@@ -64,6 +67,7 @@
6467
_prepare_cmap_norm,
6568
_prepare_params_plot,
6669
_set_outline,
70+
_validate_graph_render_params,
6771
_validate_image_render_params,
6872
_validate_label_render_params,
6973
_validate_points_render_params,
@@ -856,6 +860,136 @@ def render_labels(
856860
n_steps += 1
857861
return sdata
858862

863+
def render_graph(
864+
self,
865+
element: str | None = None,
866+
color: ColorLike | None = None,
867+
*,
868+
connectivity_key: str = "spatial",
869+
obsp_key: str | None = None,
870+
palette: dict[str, str] | list[str] | str | None = None,
871+
na_color: ColorLike | None = "default",
872+
cmap: Colormap | str | None = None,
873+
norm: Normalize | None = None,
874+
groups: list[str] | str | None = None,
875+
group_key: str | None = None,
876+
edge_width: float | Literal["weight"] = 1.0,
877+
edge_alpha: float | Literal["weight"] = 1.0,
878+
weight_key: str | None = None,
879+
linestyle: str | Sequence[str] = "solid",
880+
rasterize: bool = True,
881+
include_self_loops: bool = False,
882+
colorbar: bool | str | None = "auto",
883+
colorbar_params: dict[str, object] | None = None,
884+
table_name: str | None = None,
885+
) -> sd.SpatialData:
886+
"""Render spatial graph edges between observations.
887+
888+
Draws edges from a connectivity matrix in ``table.obsp`` using
889+
centroid coordinates of the linked spatial element.
890+
891+
Parameters
892+
----------
893+
element : str | None
894+
Name of the shapes/points/labels element the graph connects.
895+
Auto-resolved from the table if omitted.
896+
color : ColorLike | None
897+
A color-like value applied to every edge, or the name of a
898+
``table.obs`` column. Categorical columns colour same-category
899+
edges by the shared value and cross-category edges by
900+
``na_color``. Continuous columns colour edges by the mean of
901+
their endpoint values. Defaults to grey when unset.
902+
connectivity_key : str, default "spatial"
903+
``table.obsp`` key. Tries ``key`` first, then ``f"{key}_connectivities"``.
904+
obsp_key : str | None
905+
``table.obsp`` matrix used as per-edge scalar; coloured via
906+
``cmap``/``norm``. Mutually exclusive with ``color``.
907+
palette : dict[str, str] | list[str] | str | None
908+
Palette for categorical obs coloring. Same as :meth:`render_shapes`.
909+
na_color : ColorLike | None, default "default"
910+
Colour for cross-category edges. ``None`` makes them transparent.
911+
cmap : Colormap | str | None
912+
Colormap for continuous edge coloring.
913+
norm : Normalize | None
914+
Pass ``Normalize(vmin=..., vmax=...)`` to clamp the colormap range.
915+
groups : list[str] | str | None
916+
Show only edges where **both** endpoints fall in these groups.
917+
Requires ``group_key``.
918+
group_key : str | None
919+
``table.obs`` column used for group filtering.
920+
edge_width : float | Literal["weight"], default 1.0
921+
Line width. Pass ``"weight"`` to scale by ``weight_key`` values
922+
into ``[0.5, 3.0]``.
923+
edge_alpha : float | Literal["weight"], default 1.0
924+
Transparency. Pass ``"weight"`` to scale into ``[0.2, 1.0]``.
925+
weight_key : str | None
926+
``table.obsp`` matrix providing per-edge weights. Defaults to
927+
``connectivity_key`` when omitted.
928+
linestyle : str | Sequence[str], default "solid"
929+
``LineCollection`` linestyle (scalar or per-edge).
930+
rasterize : bool, default True
931+
Rasterize the edge collection. Set ``False`` for vector output.
932+
include_self_loops : bool, default False
933+
Render diagonal entries of the connectivity matrix as circles.
934+
table_name : str | None
935+
Table containing the graph. Auto-discovered if omitted.
936+
937+
Returns
938+
-------
939+
sd.SpatialData
940+
Copy with rendering parameters stored in the plotting tree.
941+
942+
Notes
943+
-----
944+
Chaining with ``render_shapes``/``render_points`` on the same
945+
categorical column shares the legend; no dedicated edge legend is drawn.
946+
"""
947+
params = _validate_graph_render_params(
948+
self._sdata,
949+
element=element,
950+
connectivity_key=connectivity_key,
951+
obsp_key=obsp_key,
952+
weight_key=weight_key,
953+
palette=palette,
954+
na_color=na_color,
955+
cmap=cmap,
956+
norm=norm,
957+
table_name=table_name,
958+
color=color,
959+
edge_width=edge_width,
960+
edge_alpha=edge_alpha,
961+
groups=groups,
962+
group_key=group_key,
963+
)
964+
965+
sdata = self._copy()
966+
sdata = _verify_plotting_tree(sdata)
967+
n_steps = len(sdata.plotting_tree.keys())
968+
sdata.plotting_tree[f"{n_steps + 1}_render_graph"] = GraphRenderParams(
969+
element=params["element"],
970+
connectivity_obsp_key=params["connectivity_obsp_key"],
971+
table_name=params["table_name"],
972+
color=params["color"],
973+
obs_col=params["obs_col"],
974+
obsp_key=params["obsp_key"],
975+
cmap_params=params["cmap_params"],
976+
palette_map=params["palette_map"],
977+
na_color=params["na_color"],
978+
color_source=params["color_source"],
979+
groups=params["groups"],
980+
group_key=params["group_key"],
981+
edge_width=params["edge_width"],
982+
edge_alpha=params["edge_alpha"],
983+
weight_key=params["weight_key"],
984+
linestyle=linestyle,
985+
rasterize=rasterize,
986+
include_self_loops=include_self_loops,
987+
zorder=n_steps,
988+
colorbar=colorbar,
989+
colorbar_params=colorbar_params,
990+
)
991+
return sdata
992+
859993
def show(
860994
self,
861995
coordinate_systems: list[str] | str | None = None,
@@ -1020,6 +1154,7 @@ def show(
10201154
"render_shapes",
10211155
"render_labels",
10221156
"render_points",
1157+
"render_graph",
10231158
]
10241159

10251160
# prepare rendering params
@@ -1340,6 +1475,21 @@ def _draw_colorbar(
13401475
rasterize=rasterize,
13411476
)
13421477

1478+
elif cmd == "render_graph":
1479+
graph_element = params_copy.element
1480+
element_in_cs = graph_element in sdata and cs in set(
1481+
get_transformation(sdata[graph_element], get_all=True).keys()
1482+
)
1483+
if element_in_cs:
1484+
_render_graph(
1485+
sdata=sdata,
1486+
render_params=params_copy,
1487+
coordinate_system=cs,
1488+
ax=ax,
1489+
legend_params=legend_params_obj,
1490+
colorbar_requests=axis_colorbar_requests,
1491+
)
1492+
13431493
if title is None:
13441494
t = cs
13451495
elif len(title) == 1:

0 commit comments

Comments
 (0)