|
4 | 4 | import sys |
5 | 5 | import warnings |
6 | 6 | from collections import OrderedDict |
7 | | -from collections.abc import Callable |
| 7 | +from collections.abc import Callable, Sequence |
8 | 8 | from copy import deepcopy |
9 | 9 | from pathlib import Path |
10 | 10 | from typing import Any, Literal, cast |
|
25 | 25 | from mpl_toolkits.axes_grid1.inset_locator import inset_axes |
26 | 26 | from spatialdata import get_extent |
27 | 27 | from spatialdata._utils import _deprecation_alias |
| 28 | +from spatialdata.transformations.operations import get_transformation |
28 | 29 | from xarray import DataArray, DataTree |
29 | 30 |
|
30 | 31 | from spatialdata_plot._accessor import register_spatial_data_accessor |
31 | 32 | from spatialdata_plot._logging import _log_context |
32 | 33 | from spatialdata_plot.pl.render import ( |
33 | 34 | _draw_channel_legend, |
| 35 | + _render_graph, |
34 | 36 | _render_images, |
35 | 37 | _render_labels, |
36 | 38 | _render_points, |
|
44 | 46 | ChannelLegendEntry, |
45 | 47 | CmapParams, |
46 | 48 | ColorbarSpec, |
| 49 | + GraphRenderParams, |
47 | 50 | ImageRenderParams, |
48 | 51 | LabelsRenderParams, |
49 | 52 | LegendParams, |
|
64 | 67 | _prepare_cmap_norm, |
65 | 68 | _prepare_params_plot, |
66 | 69 | _set_outline, |
| 70 | + _validate_graph_render_params, |
67 | 71 | _validate_image_render_params, |
68 | 72 | _validate_label_render_params, |
69 | 73 | _validate_points_render_params, |
@@ -856,6 +860,136 @@ def render_labels( |
856 | 860 | n_steps += 1 |
857 | 861 | return sdata |
858 | 862 |
|
| 863 | + def render_graph( |
| 864 | + self, |
| 865 | + element: str | None = None, |
| 866 | + color: ColorLike | None = None, |
| 867 | + *, |
| 868 | + connectivity_key: str = "spatial", |
| 869 | + obsp_key: str | None = None, |
| 870 | + palette: dict[str, str] | list[str] | str | None = None, |
| 871 | + na_color: ColorLike | None = "default", |
| 872 | + cmap: Colormap | str | None = None, |
| 873 | + norm: Normalize | None = None, |
| 874 | + groups: list[str] | str | None = None, |
| 875 | + group_key: str | None = None, |
| 876 | + edge_width: float | Literal["weight"] = 1.0, |
| 877 | + edge_alpha: float | Literal["weight"] = 1.0, |
| 878 | + weight_key: str | None = None, |
| 879 | + linestyle: str | Sequence[str] = "solid", |
| 880 | + rasterize: bool = True, |
| 881 | + include_self_loops: bool = False, |
| 882 | + colorbar: bool | str | None = "auto", |
| 883 | + colorbar_params: dict[str, object] | None = None, |
| 884 | + table_name: str | None = None, |
| 885 | + ) -> sd.SpatialData: |
| 886 | + """Render spatial graph edges between observations. |
| 887 | +
|
| 888 | + Draws edges from a connectivity matrix in ``table.obsp`` using |
| 889 | + centroid coordinates of the linked spatial element. |
| 890 | +
|
| 891 | + Parameters |
| 892 | + ---------- |
| 893 | + element : str | None |
| 894 | + Name of the shapes/points/labels element the graph connects. |
| 895 | + Auto-resolved from the table if omitted. |
| 896 | + color : ColorLike | None |
| 897 | + A color-like value applied to every edge, or the name of a |
| 898 | + ``table.obs`` column. Categorical columns colour same-category |
| 899 | + edges by the shared value and cross-category edges by |
| 900 | + ``na_color``. Continuous columns colour edges by the mean of |
| 901 | + their endpoint values. Defaults to grey when unset. |
| 902 | + connectivity_key : str, default "spatial" |
| 903 | + ``table.obsp`` key. Tries ``key`` first, then ``f"{key}_connectivities"``. |
| 904 | + obsp_key : str | None |
| 905 | + ``table.obsp`` matrix used as per-edge scalar; coloured via |
| 906 | + ``cmap``/``norm``. Mutually exclusive with ``color``. |
| 907 | + palette : dict[str, str] | list[str] | str | None |
| 908 | + Palette for categorical obs coloring. Same as :meth:`render_shapes`. |
| 909 | + na_color : ColorLike | None, default "default" |
| 910 | + Colour for cross-category edges. ``None`` makes them transparent. |
| 911 | + cmap : Colormap | str | None |
| 912 | + Colormap for continuous edge coloring. |
| 913 | + norm : Normalize | None |
| 914 | + Pass ``Normalize(vmin=..., vmax=...)`` to clamp the colormap range. |
| 915 | + groups : list[str] | str | None |
| 916 | + Show only edges where **both** endpoints fall in these groups. |
| 917 | + Requires ``group_key``. |
| 918 | + group_key : str | None |
| 919 | + ``table.obs`` column used for group filtering. |
| 920 | + edge_width : float | Literal["weight"], default 1.0 |
| 921 | + Line width. Pass ``"weight"`` to scale by ``weight_key`` values |
| 922 | + into ``[0.5, 3.0]``. |
| 923 | + edge_alpha : float | Literal["weight"], default 1.0 |
| 924 | + Transparency. Pass ``"weight"`` to scale into ``[0.2, 1.0]``. |
| 925 | + weight_key : str | None |
| 926 | + ``table.obsp`` matrix providing per-edge weights. Defaults to |
| 927 | + ``connectivity_key`` when omitted. |
| 928 | + linestyle : str | Sequence[str], default "solid" |
| 929 | + ``LineCollection`` linestyle (scalar or per-edge). |
| 930 | + rasterize : bool, default True |
| 931 | + Rasterize the edge collection. Set ``False`` for vector output. |
| 932 | + include_self_loops : bool, default False |
| 933 | + Render diagonal entries of the connectivity matrix as circles. |
| 934 | + table_name : str | None |
| 935 | + Table containing the graph. Auto-discovered if omitted. |
| 936 | +
|
| 937 | + Returns |
| 938 | + ------- |
| 939 | + sd.SpatialData |
| 940 | + Copy with rendering parameters stored in the plotting tree. |
| 941 | +
|
| 942 | + Notes |
| 943 | + ----- |
| 944 | + Chaining with ``render_shapes``/``render_points`` on the same |
| 945 | + categorical column shares the legend; no dedicated edge legend is drawn. |
| 946 | + """ |
| 947 | + params = _validate_graph_render_params( |
| 948 | + self._sdata, |
| 949 | + element=element, |
| 950 | + connectivity_key=connectivity_key, |
| 951 | + obsp_key=obsp_key, |
| 952 | + weight_key=weight_key, |
| 953 | + palette=palette, |
| 954 | + na_color=na_color, |
| 955 | + cmap=cmap, |
| 956 | + norm=norm, |
| 957 | + table_name=table_name, |
| 958 | + color=color, |
| 959 | + edge_width=edge_width, |
| 960 | + edge_alpha=edge_alpha, |
| 961 | + groups=groups, |
| 962 | + group_key=group_key, |
| 963 | + ) |
| 964 | + |
| 965 | + sdata = self._copy() |
| 966 | + sdata = _verify_plotting_tree(sdata) |
| 967 | + n_steps = len(sdata.plotting_tree.keys()) |
| 968 | + sdata.plotting_tree[f"{n_steps + 1}_render_graph"] = GraphRenderParams( |
| 969 | + element=params["element"], |
| 970 | + connectivity_obsp_key=params["connectivity_obsp_key"], |
| 971 | + table_name=params["table_name"], |
| 972 | + color=params["color"], |
| 973 | + obs_col=params["obs_col"], |
| 974 | + obsp_key=params["obsp_key"], |
| 975 | + cmap_params=params["cmap_params"], |
| 976 | + palette_map=params["palette_map"], |
| 977 | + na_color=params["na_color"], |
| 978 | + color_source=params["color_source"], |
| 979 | + groups=params["groups"], |
| 980 | + group_key=params["group_key"], |
| 981 | + edge_width=params["edge_width"], |
| 982 | + edge_alpha=params["edge_alpha"], |
| 983 | + weight_key=params["weight_key"], |
| 984 | + linestyle=linestyle, |
| 985 | + rasterize=rasterize, |
| 986 | + include_self_loops=include_self_loops, |
| 987 | + zorder=n_steps, |
| 988 | + colorbar=colorbar, |
| 989 | + colorbar_params=colorbar_params, |
| 990 | + ) |
| 991 | + return sdata |
| 992 | + |
859 | 993 | def show( |
860 | 994 | self, |
861 | 995 | coordinate_systems: list[str] | str | None = None, |
@@ -1020,6 +1154,7 @@ def show( |
1020 | 1154 | "render_shapes", |
1021 | 1155 | "render_labels", |
1022 | 1156 | "render_points", |
| 1157 | + "render_graph", |
1023 | 1158 | ] |
1024 | 1159 |
|
1025 | 1160 | # prepare rendering params |
@@ -1340,6 +1475,21 @@ def _draw_colorbar( |
1340 | 1475 | rasterize=rasterize, |
1341 | 1476 | ) |
1342 | 1477 |
|
| 1478 | + elif cmd == "render_graph": |
| 1479 | + graph_element = params_copy.element |
| 1480 | + element_in_cs = graph_element in sdata and cs in set( |
| 1481 | + get_transformation(sdata[graph_element], get_all=True).keys() |
| 1482 | + ) |
| 1483 | + if element_in_cs: |
| 1484 | + _render_graph( |
| 1485 | + sdata=sdata, |
| 1486 | + render_params=params_copy, |
| 1487 | + coordinate_system=cs, |
| 1488 | + ax=ax, |
| 1489 | + legend_params=legend_params_obj, |
| 1490 | + colorbar_requests=axis_colorbar_requests, |
| 1491 | + ) |
| 1492 | + |
1343 | 1493 | if title is None: |
1344 | 1494 | t = cs |
1345 | 1495 | elif len(title) == 1: |
|
0 commit comments