Skip to content

Commit 73993fb

Browse files
Add network centrality. Update plotting signatures.
1 parent 4ac87b8 commit 73993fb

11 files changed

Lines changed: 8107 additions & 156 deletions

File tree

TODO.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
* Interactive cell selector
1111
* UMAP points have no stroke
1212
* `compute_archetype_feature_specificity()`: `key_added` > `key_prefix`
13+
* Make archetype specificity and network centrality optional in `run_actionet()`
1314

1415
## Secondary
1516
* Consolidate normalization code-paths

src/actionet/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""ACTIONet: Single-cell multi-resolution data analysis toolkit.
22
33
Python bindings for the ACTIONet C++ backend (libactionet) via pybind11.
4-
Uses AnnData as the core data container, integrates with the scanpy ecosystem.
4+
Uses AnnData as the core data container.
55
66
System build requirements: CMake >= 3.19, C++17 compiler, BLAS/LAPACK,
77
HDF5 (C library), and OpenMP.
@@ -18,6 +18,7 @@
1818
from .core import (
1919
run_action,
2020
build_network,
21+
compute_network_centrality,
2122
compute_network_diffusion,
2223
layout_network,
2324
)
@@ -120,6 +121,7 @@
120121
"run_action",
121122
"build_network",
122123
"cluster_network",
124+
"compute_network_centrality",
123125
"compute_network_diffusion",
124126
"compute_feature_specificity",
125127
"compute_archetype_feature_specificity",

src/actionet/core.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Literal, Optional, Union
44
import numpy as np
5+
import pandas as pd
56
import scipy.sparse as sp
67
from anndata import AnnData
78

@@ -319,6 +320,182 @@ def compute_network_diffusion(
319320
return None
320321

321322

323+
def compute_network_centrality(
324+
X: Union[AnnData, sp.spmatrix],
325+
algorithm: Literal["coreness", "pagerank", "local_coreness", "local_pagerank"] = "coreness",
326+
labels: Union[str, np.ndarray, None] = None,
327+
alpha: float = 0.9,
328+
max_iter: int = 5,
329+
tol: float = 1e-8,
330+
n_threads: int = 0,
331+
network_key: str = "actionet",
332+
key_added: Optional[str] = None,
333+
return_raw: bool = False,
334+
inplace: bool = True,
335+
) -> Optional[Union[AnnData, np.ndarray]]:
336+
"""
337+
Compute network centrality scores.
338+
339+
Supports global measures (coreness, PageRank) and label-aware local
340+
variants (local coreness, local PageRank).
341+
342+
Parameters
343+
----------
344+
X
345+
Either an AnnData object (network looked up from ``X.obsp[network_key]``)
346+
or a raw sparse graph matrix. When a sparse matrix is passed, the
347+
result is always returned as a raw array regardless of ``return_raw``;
348+
``inplace``, ``key_added``, and ``network_key`` are ignored.
349+
algorithm
350+
Centrality algorithm:
351+
352+
* ``"coreness"`` -- k-shell decomposition.
353+
* ``"pagerank"`` -- global PageRank (uniform teleport).
354+
* ``"local_coreness"`` -- per-label subgraph coreness (requires
355+
``labels``).
356+
* ``"local_pagerank"`` -- per-label PageRank, taking the maximum
357+
across labels per cell (requires ``labels``).
358+
labels
359+
Cell-level labels required for ``local_coreness`` and
360+
``local_pagerank``. When ``X`` is an AnnData, this may be a string
361+
key into ``X.obs``; otherwise it must be an array-like of length
362+
``n_cells``. Ignored for global algorithms.
363+
alpha
364+
Damping factor for PageRank variants (clamped to [0, 0.99]).
365+
max_iter
366+
Maximum diffusion iterations (PageRank variants only).
367+
tol
368+
Convergence tolerance (PageRank variants only).
369+
n_threads
370+
Number of threads (0 = auto).
371+
network_key
372+
Key in ``X.obsp`` containing the graph. Ignored when ``X`` is a
373+
sparse matrix.
374+
key_added
375+
Key to store centrality in ``X.obs``. Defaults to
376+
``"{algorithm}_{network_key}"``. Ignored when ``return_raw=True``
377+
or when ``X`` is a sparse matrix.
378+
return_raw
379+
If ``True``, return the centrality array directly without writing
380+
to the AnnData. Always treated as ``True`` when ``X`` is a sparse
381+
matrix.
382+
inplace
383+
If ``True``, modifies the AnnData object in place. Ignored when
384+
``return_raw=True`` or when ``X`` is a sparse matrix.
385+
386+
Returns
387+
-------
388+
None
389+
When ``X`` is AnnData, ``inplace=True``, and ``return_raw=False``.
390+
AnnData
391+
A modified copy when ``inplace=False`` and ``return_raw=False``.
392+
np.ndarray
393+
Centrality array when ``return_raw=True`` or ``X`` is a sparse
394+
matrix.
395+
396+
Updates AnnData
397+
---------------
398+
adata.obs[key_added] : np.ndarray
399+
Per-cell centrality scores.
400+
"""
401+
valid_algorithms = {"coreness", "pagerank", "local_coreness", "local_pagerank"}
402+
if algorithm not in valid_algorithms:
403+
raise ValueError(
404+
f"Invalid algorithm '{algorithm}'. Must be one of {sorted(valid_algorithms)}."
405+
)
406+
407+
is_anndata = isinstance(X, AnnData)
408+
409+
if is_anndata:
410+
if network_key not in X.obsp:
411+
raise ValueError(f"Network '{network_key}' not found. Run build_network first.")
412+
G = X.obsp[network_key]
413+
else:
414+
G = X
415+
416+
n_cells = G.shape[1]
417+
418+
if algorithm in ("local_coreness", "local_pagerank"):
419+
if labels is None:
420+
raise ValueError(
421+
f"'labels' is required when algorithm='{algorithm}'."
422+
)
423+
424+
# Resolve labels to a numeric assignment vector when needed.
425+
assignments: Optional[np.ndarray] = None
426+
if labels is not None:
427+
if isinstance(labels, str):
428+
if not is_anndata:
429+
raise ValueError(
430+
"`labels` must be an array when `X` is a sparse matrix, not a string key."
431+
)
432+
if labels not in X.obs:
433+
raise ValueError(f"Labels column '{labels}' not found in adata.obs.")
434+
raw_labels = X.obs[labels].values
435+
else:
436+
raw_labels = np.asarray(labels)
437+
assignments = pd.Categorical(raw_labels).codes.astype(np.int32)
438+
439+
# Clamp alpha for PageRank variants.
440+
if algorithm in ("pagerank", "local_pagerank"):
441+
alpha = float(np.clip(alpha, 0.0, 0.99))
442+
443+
# --- Compute centrality ------------------------------------------------
444+
if algorithm == "coreness":
445+
centrality = np.asarray(_core.compute_coreness(G), dtype=np.float64)
446+
447+
elif algorithm == "pagerank":
448+
uniform = np.full((n_cells, 1), 1.0 / n_cells)
449+
centrality = np.asarray(
450+
_core.compute_network_diffusion(
451+
G=G, X0=uniform, alpha=alpha, max_it=max_iter,
452+
thread_no=n_threads, approx=True, norm_method=0, tol=tol,
453+
)
454+
).ravel()
455+
456+
elif algorithm == "local_coreness":
457+
centrality = np.asarray(
458+
_core.compute_archetype_centrality(G, assignments)
459+
)
460+
461+
elif algorithm == "local_pagerank":
462+
unique_codes = np.unique(assignments)
463+
n_groups = len(unique_codes)
464+
design = np.zeros((n_cells, n_groups), dtype=np.float64)
465+
for col_idx, code in enumerate(unique_codes):
466+
mask = assignments == code
467+
design[mask, col_idx] = 1.0
468+
col_sums = design.sum(axis=0)
469+
col_sums[col_sums == 0] = 1.0
470+
design /= col_sums
471+
472+
design = np.ascontiguousarray(design)
473+
scores = np.asarray(
474+
_core.compute_network_diffusion(
475+
G=G, X0=design, alpha=alpha, max_it=max_iter,
476+
thread_no=n_threads, approx=True, norm_method=0, tol=tol,
477+
)
478+
)
479+
col_max = scores.max(axis=0)
480+
col_max[col_max == 0] = 1.0
481+
scores /= col_max
482+
centrality = scores.max(axis=1)
483+
484+
centrality = centrality.ravel()
485+
486+
# --- Return / persist --------------------------------------------------
487+
if not is_anndata or return_raw:
488+
return centrality
489+
490+
adata = X if inplace else X.copy()
491+
if key_added is None:
492+
key_added = f"{algorithm}_{network_key}"
493+
persist_updates(adata, obs={key_added: centrality})
494+
if not inplace:
495+
return adata
496+
return None
497+
498+
322499
def layout_network(
323500
adata: AnnData,
324501
network_key: str = "actionet",

src/actionet/pipeline.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .core import (
88
run_action,
99
build_network,
10+
compute_network_centrality,
1011
compute_network_diffusion,
1112
layout_network,
1213
)
@@ -47,10 +48,9 @@ def run_actionet(
4748
This function executes the full ACTIONet workflow including:
4849
1. ACTION archetypal analysis
4950
2. Network construction
50-
3. Network centrality (TODO)
51-
4. Network-based diffusion
52-
5. 2D/3D layout generation
53-
6. Node color computation
51+
3. Network-based diffusion
52+
4. 2D/3D layout generation
53+
5. Node color computation
5454
5555
Parameters
5656
----------
@@ -146,10 +146,10 @@ def run_actionet(
146146
147147
Examples
148148
--------
149-
>>> import actionet as act
150-
>>> import scanpy as sc
151-
>>> adata = sc.read_h5ad("data.h5ad")
152-
>>> adata = act.run_actionet(adata, k_max=50, inplace=False)
149+
>>> import actionet as an
150+
>>> import anndata as ad
151+
>>> adata = ad.read_h5ad("data.h5ad")
152+
>>> adata = an.run_actionet(adata, k_max=50, inplace=False)
153153
>>> print(adata.obs['assigned_archetype'])
154154
155155
See Also
@@ -195,8 +195,16 @@ def run_actionet(
195195
inplace=True,
196196
)
197197

198-
# Step 3: Compute network centrality
199-
# TODO: Implement network_centrality
198+
# # Step 3: Compute network centrality
199+
# print("Computing network centrality...")
200+
# compute_network_centrality(
201+
# adata,
202+
# algorithm="pagerank",
203+
# network_key=network_key,
204+
# key_added="node_centrality",
205+
# n_threads=n_threads,
206+
# inplace=True,
207+
# )
200208

201209
# Step 4: Smooth archetype footprints via network diffusion
202210
print("Computing archetype footprints via diffusion...")

src/actionet/plotting/feature_expression.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def plot_feature_expression(
121121
trans_fac: float = 3.0,
122122
cmap: Union[str, Sequence[str]] = "magma",
123123
size: float = 1.0,
124-
net_slot: str = "actionet",
124+
network_key: str = "actionet",
125125
basis: str = "umap_2d_actionet",
126126
single_plot: bool = False,
127127
legend: bool = False,
@@ -171,7 +171,7 @@ def plot_feature_expression(
171171
Continuous palette for expression values.
172172
size
173173
Marker size for UMAP scatter.
174-
net_slot
174+
network_key
175175
Key in ``adata.obsp`` for the ACTIONet network.
176176
basis
177177
Key in ``adata.obsm`` containing 2D coordinates.
@@ -221,7 +221,7 @@ def plot_feature_expression(
221221
features=requested,
222222
method=method,
223223
features_use=features_use,
224-
network_key=net_slot,
224+
network_key=network_key,
225225
layer=layer,
226226
alpha=alpha,
227227
n_threads=n_threads,
@@ -296,7 +296,7 @@ def plot_feature_expression_raster(
296296
trans_fac: float = 3.0,
297297
cmap: Union[str, Sequence[str]] = "magma",
298298
size: float = 1.0,
299-
net_slot: str = "actionet",
299+
network_key: str = "actionet",
300300
basis: str = "umap_2d_actionet",
301301
single_plot: bool = False,
302302
legend: bool = False,
@@ -346,7 +346,7 @@ def plot_feature_expression_raster(
346346
Continuous palette for expression values.
347347
size
348348
Marker size for UMAP scatter.
349-
net_slot
349+
network_key
350350
Key in ``adata.obsp`` for the ACTIONet network.
351351
basis
352352
Key in ``adata.obsm`` containing 2D coordinates.
@@ -391,7 +391,7 @@ def plot_feature_expression_raster(
391391
features=requested,
392392
method=method,
393393
features_use=features_use,
394-
network_key=net_slot,
394+
network_key=network_key,
395395
layer=layer,
396396
alpha=alpha,
397397
n_threads=n_threads,
@@ -448,7 +448,7 @@ def plot_feature_expression_raster(
448448
basis=basis,
449449
alpha=1.0,
450450
fig_dpi=100.0,
451-
figsize=(panel_width, panel_height),
451+
fig_size=(panel_width, panel_height),
452452
trans_attr=trans_attr,
453453
trans_fac=trans_fac,
454454
trans_th=trans_th,
@@ -467,9 +467,9 @@ def plot_feature_expression_raster(
467467
order=None,
468468
na_color="#cccccc",
469469
hide_na=False,
470-
add_text_labels=False,
471-
label_text_size=9.0,
472-
nudge_text_labels=False,
470+
text_labels=False,
471+
text_label_size=9.0,
472+
nudge_text=False,
473473
)
474474

475475
total_cells = nrow * ncol

0 commit comments

Comments
 (0)