|
2 | 2 |
|
3 | 3 | from typing import Literal, Optional, Union |
4 | 4 | import numpy as np |
| 5 | +import pandas as pd |
5 | 6 | import scipy.sparse as sp |
6 | 7 | from anndata import AnnData |
7 | 8 |
|
@@ -319,6 +320,182 @@ def compute_network_diffusion( |
319 | 320 | return None |
320 | 321 |
|
321 | 322 |
|
def compute_network_centrality(
    X: Union[AnnData, sp.spmatrix],
    algorithm: Literal["coreness", "pagerank", "local_coreness", "local_pagerank"] = "coreness",
    labels: Union[str, np.ndarray, None] = None,
    alpha: float = 0.9,
    max_iter: int = 5,
    tol: float = 1e-8,
    n_threads: int = 0,
    network_key: str = "actionet",
    key_added: Optional[str] = None,
    return_raw: bool = False,
    inplace: bool = True,
) -> Optional[Union[AnnData, np.ndarray]]:
    """
    Compute network centrality scores.

    Supports global measures (coreness, PageRank) and label-aware local
    variants (local coreness, local PageRank).

    Parameters
    ----------
    X
        Either an AnnData object (network looked up from ``X.obsp[network_key]``)
        or a raw sparse graph matrix. When a sparse matrix is passed, the
        result is always returned as a raw array regardless of ``return_raw``;
        ``inplace``, ``key_added``, and ``network_key`` are ignored.
    algorithm
        Centrality algorithm:

        * ``"coreness"`` -- k-shell decomposition.
        * ``"pagerank"`` -- global PageRank (uniform teleport).
        * ``"local_coreness"`` -- per-label subgraph coreness (requires
          ``labels``).
        * ``"local_pagerank"`` -- per-label PageRank, taking the maximum
          across labels per cell (requires ``labels``).
    labels
        Cell-level labels required for ``local_coreness`` and
        ``local_pagerank``. When ``X`` is an AnnData, this may be a string
        key into ``X.obs``; otherwise it must be an array-like of length
        ``n_cells``. Ignored (neither resolved nor validated) for global
        algorithms.
    alpha
        Damping factor for PageRank variants (clamped to [0, 0.99]).
    max_iter
        Maximum diffusion iterations (PageRank variants only).
    tol
        Convergence tolerance (PageRank variants only).
    n_threads
        Number of threads (0 = auto).
    network_key
        Key in ``X.obsp`` containing the graph. Ignored when ``X`` is a
        sparse matrix.
    key_added
        Key to store centrality in ``X.obs``. Defaults to
        ``"{algorithm}_{network_key}"``. Ignored when ``return_raw=True``
        or when ``X`` is a sparse matrix.
    return_raw
        If ``True``, return the centrality array directly without writing
        to the AnnData. Always treated as ``True`` when ``X`` is a sparse
        matrix.
    inplace
        If ``True``, modifies the AnnData object in place. Ignored when
        ``return_raw=True`` or when ``X`` is a sparse matrix.

    Returns
    -------
    None
        When ``X`` is AnnData, ``inplace=True``, and ``return_raw=False``.
    AnnData
        A modified copy when ``inplace=False`` and ``return_raw=False``.
    np.ndarray
        Centrality array when ``return_raw=True`` or ``X`` is a sparse
        matrix.

    Raises
    ------
    ValueError
        If ``algorithm`` is not recognised, if ``network_key`` is missing
        from ``X.obsp``, if a local algorithm is requested without
        ``labels``, if ``labels`` is a string key while ``X`` is a sparse
        matrix or the key is absent from ``X.obs``, or if the number of
        labels does not match the number of cells in the graph.

    Updates AnnData
    ---------------
    adata.obs[key_added] : np.ndarray
        Per-cell centrality scores.
    """
    valid_algorithms = {"coreness", "pagerank", "local_coreness", "local_pagerank"}
    if algorithm not in valid_algorithms:
        raise ValueError(
            f"Invalid algorithm '{algorithm}'. Must be one of {sorted(valid_algorithms)}."
        )

    is_anndata = isinstance(X, AnnData)
    if is_anndata:
        if network_key not in X.obsp:
            raise ValueError(f"Network '{network_key}' not found. Run build_network first.")
        G = X.obsp[network_key]
    else:
        G = X

    n_cells = G.shape[1]
    is_local = algorithm in ("local_coreness", "local_pagerank")

    # Labels are only resolved for the label-aware algorithms, so a stale or
    # irrelevant ``labels`` argument cannot make a global computation fail.
    assignments: Optional[np.ndarray] = None
    if is_local:
        if labels is None:
            raise ValueError(
                f"'labels' is required when algorithm='{algorithm}'."
            )
        if isinstance(labels, str):
            if not is_anndata:
                raise ValueError(
                    "`labels` must be an array when `X` is a sparse matrix, not a string key."
                )
            if labels not in X.obs:
                raise ValueError(f"Labels column '{labels}' not found in adata.obs.")
            raw_labels = X.obs[labels].values
        else:
            raw_labels = np.asarray(labels)

        # Encode arbitrary label values as integer codes.
        # NOTE(review): pd.Categorical encodes missing labels as -1; they are
        # passed through unchanged here (and form their own group in
        # local_pagerank) -- confirm that this is the intended handling.
        assignments = pd.Categorical(raw_labels).codes.astype(np.int32)

        # Fail early with a clear message instead of an opaque indexing error
        # (or undefined behaviour inside the native routine) further down.
        if assignments.shape[0] != n_cells:
            raise ValueError(
                f"`labels` has {assignments.shape[0]} entries but the network "
                f"has {n_cells} cells."
            )

    # Clamp alpha for PageRank variants.
    if algorithm in ("pagerank", "local_pagerank"):
        alpha = float(np.clip(alpha, 0.0, 0.99))

    # --- Compute centrality ------------------------------------------------
    if algorithm == "coreness":
        centrality = np.asarray(_core.compute_coreness(G), dtype=np.float64)

    elif algorithm == "pagerank":
        # Single diffusion seeded with a uniform distribution over all cells.
        uniform = np.full((n_cells, 1), 1.0 / n_cells)
        centrality = np.asarray(
            _core.compute_network_diffusion(
                G=G, X0=uniform, alpha=alpha, max_it=max_iter,
                thread_no=n_threads, approx=True, norm_method=0, tol=tol,
            )
        )

    elif algorithm == "local_coreness":
        centrality = np.asarray(
            _core.compute_archetype_centrality(G, assignments)
        )

    else:  # algorithm == "local_pagerank"
        # One seed column per observed label: a uniform distribution over the
        # cells carrying that label (one-hot membership / group size).
        _, group_index = np.unique(assignments, return_inverse=True)
        group_index = group_index.ravel()
        group_sizes = np.bincount(group_index).astype(np.float64)
        design = np.zeros((n_cells, group_sizes.shape[0]), dtype=np.float64)
        design[np.arange(n_cells), group_index] = 1.0
        # Every group comes from an observed code, so no size can be zero.
        design /= group_sizes
        design = np.ascontiguousarray(design)

        scores = np.asarray(
            _core.compute_network_diffusion(
                G=G, X0=design, alpha=alpha, max_it=max_iter,
                thread_no=n_threads, approx=True, norm_method=0, tol=tol,
            )
        )
        # Rescale each label's scores by that column's maximum, then keep the
        # best rescaled score per cell. The division is done out of place so
        # the buffer returned by the native routine is never mutated.
        col_max = scores.max(axis=0)
        col_max[col_max == 0] = 1.0  # avoid 0/0 for all-zero columns
        centrality = (scores / col_max).max(axis=1)

    centrality = centrality.ravel()

    # --- Return / persist --------------------------------------------------
    if not is_anndata or return_raw:
        return centrality

    adata = X if inplace else X.copy()
    if key_added is None:
        key_added = f"{algorithm}_{network_key}"
    persist_updates(adata, obs={key_added: centrality})
    return None if inplace else adata
| 497 | + |
| 498 | + |
322 | 499 | def layout_network( |
323 | 500 | adata: AnnData, |
324 | 501 | network_key: str = "actionet", |
|
0 commit comments