Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ filterwarnings = [
"ignore:The `igraph` implementation of leiden clustering:UserWarning",
# everybody uses this zarr 3 feature, including us, XArray, lots of data out there …
"ignore:Consolidated metadata is currently not part:UserWarning",
# joblib fallback to serial mode in restricted multiprocessing environments
"ignore:.*joblib will operate in serial mode:UserWarning",
]

[tool.coverage]
Expand Down
126 changes: 99 additions & 27 deletions src/scanpy/_utils/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from typing import TYPE_CHECKING

import numpy as np
from sklearn.utils import check_random_state

from . import ensure_igraph

if TYPE_CHECKING:
from collections.abc import Generator
from typing import Self

from numpy.typing import NDArray

Expand All @@ -21,8 +21,10 @@
"RNGLike",
"SeedLike",
"_LegacyRandom",
"_if_legacy_apply_global",
"accepts_legacy_random_state",
"ith_k_tuple",
"legacy_numpy_gen",
"legacy_random_state",
"random_k_tuples",
"random_str",
]
Expand All @@ -43,29 +45,29 @@ class _RNGIgraph:
See :func:`igraph.set_random_number_generator` for the requirements.
"""

def __init__(self, random_state: int | np.random.RandomState = 0) -> None:
self._rng = check_random_state(random_state)
def __init__(self, rng: SeedLike | RNGLike | None) -> None:
self._rng = np.random.default_rng(rng)

def getrandbits(self, k: int) -> int:
return self._rng.tomaxint() & ((1 << k) - 1)
lims = np.iinfo(np.uint64)
i = int(self._rng.integers(0, lims.max, dtype=np.uint64))
return i & ((1 << k) - 1)

def randint(self, a: int, b: int) -> int:
return self._rng.randint(a, b + 1)
def randint(self, a: int, b: int) -> np.int64:
return self._rng.integers(a, b + 1)

def __getattr__(self, attr: str):
return getattr(self._rng, "normal" if attr == "gauss" else attr)


@contextmanager
def set_igraph_random_state(
random_state: int | np.random.RandomState,
) -> Generator[None, None, None]:
def set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]:
ensure_igraph()
import igraph

rng = _RNGIgraph(random_state)
ig_rng = _RNGIgraph(rng)
try:
igraph.set_random_number_generator(rng)
igraph.set_random_number_generator(ig_rng)
yield None
finally:
igraph.set_random_number_generator(random)
Expand All @@ -76,26 +78,42 @@ def set_igraph_random_state(
###################################


def legacy_numpy_gen(
random_state: _LegacyRandom | None = None,
) -> np.random.Generator:
"""Return a random generator that behaves like the legacy one."""
if random_state is not None:
if isinstance(random_state, np.random.RandomState):
np.random.set_state(random_state.get_state(legacy=False))
return _FakeRandomGen(random_state)
np.random.seed(random_state)
return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator()))


class _FakeRandomGen(np.random.Generator):
_arg: _LegacyRandom
_state: np.random.RandomState

def __init__(self, random_state: np.random.RandomState) -> None:
self._state = random_state
def __init__(
self, arg: _LegacyRandom, state: np.random.RandomState | None = None
) -> None:
self._arg = arg
self._state = np.random.RandomState(arg) if state is None else state
super().__init__(self._state._bit_generator)

@classmethod
def wrap_global(
cls,
arg: _LegacyRandom = None,
state: np.random.RandomState | None = None,
) -> Self:
"""Create a generator that wraps the global `RandomState` backing the legacy `np.random` functions."""
if arg is not None:
if isinstance(arg, np.random.RandomState):
np.random.set_state(arg.get_state(legacy=False))
return _FakeRandomGen(arg, state)
np.random.seed(arg)
return _FakeRandomGen(arg, np.random.RandomState(np.random.get_bit_generator()))

def __eq__(self, other: object) -> bool:
if not isinstance(other, _FakeRandomGen):
return False
return self._arg == other._arg

def __hash__(self) -> int:
return hash((type(self), self._arg))

@classmethod
def _delegate(cls) -> None:
names = dict(integers="randint")
for name, meth in np.random.Generator.__dict__.items():
if name.startswith("_") or not callable(meth):
continue
Expand All @@ -108,12 +126,66 @@ def wrapper(self: _FakeRandomGen, *args, **kwargs):

return wrapper

setattr(cls, name, mk_wrapper(name, meth))
setattr(cls, names.get(name, name), mk_wrapper(name, meth))


_FakeRandomGen._delegate()


def _if_legacy_apply_global(rng: np.random.Generator) -> np.random.Generator:
    """Re-seed the global legacy RNG if `rng` carries legacy semantics.

    When `rng` is a `_FakeRandomGen`, the global `np.random` state is
    re-initialized from the generator's original `_arg`, and a generator
    drawing from the same backing `RandomState` is returned.
    Any other generator is passed through untouched.
    """
    if isinstance(rng, _FakeRandomGen):
        # Rebuild via the alternate constructor so the module-level
        # `np.random` functions observe the legacy seeding side effect.
        return _FakeRandomGen.wrap_global(rng._arg, rng._state)
    return rng


def legacy_random_state(
    rng: SeedLike | RNGLike | None, *, always_state: bool = False
) -> _LegacyRandom:
    """Derive a legacy ``random_state`` argument from `rng`.

    For a `_FakeRandomGen`, the original `_arg` is handed back
    (or its backing `RandomState` when ``always_state`` is set).
    Anything else is normalized through `np.random.default_rng`, and a
    `RandomState` backed by a freshly spawned child bit generator is returned.
    """
    if isinstance(rng, _FakeRandomGen):
        if always_state:
            return rng._state
        return rng._arg
    gen = np.random.default_rng(rng)
    # Spawn a child stream so the returned RandomState is independent
    # of further draws from `gen`.
    (child,) = gen.bit_generator.spawn(1)
    return np.random.RandomState(child)


def accepts_legacy_random_state[**P, R](
random_state_default: _LegacyRandom,
) -> callable[[callable[P, R]], callable[P, R]]:
"""Make a function accept `random_state: _LegacyRandom` and pass it as `rng`.

If the decorated function is called with a `random_state` argument,
it’ll be wrapped in a :class:`_FakeRandomGen`.
Passing both ``rng`` and ``random_state`` at the same time is an error.
If neither is given, ``random_state_default`` is used.
"""

def decorator(func: callable[P, R]) -> callable[P, R]:
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
match "random_state" in kwargs, "rng" in kwargs:
case True, True:
msg = "Specify at most one of `rng` and `random_state`."
raise TypeError(msg)
case True, False:
kwargs["rng"] = _FakeRandomGen(kwargs.pop("random_state"))
case False, False:
kwargs["rng"] = _FakeRandomGen(random_state_default)
return func(*args, **kwargs)

return wrapper

return decorator


###################
# Random k-tuples #
###################
Expand Down
10 changes: 6 additions & 4 deletions src/scanpy/datasets/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from .._compat import deprecated, old_positionals
from .._settings import settings
from .._utils._doctests import doctest_internet, doctest_needs
from .._utils.random import accepts_legacy_random_state, legacy_random_state
from ..readwrite import read, read_h5ad, read_visium
from ._utils import check_datasetdir_exists

if TYPE_CHECKING:
from typing import Literal

from .._utils.random import _LegacyRandom
from .._utils.random import RNGLike, SeedLike

type VisiumSampleID = Literal[
"V1_Breast_Cancer_Block_A_Section_1",
Expand Down Expand Up @@ -57,13 +58,14 @@
@old_positionals(
"n_variables", "n_centers", "cluster_std", "n_observations", "random_state"
)
@accepts_legacy_random_state(0)
def blobs(
*,
n_variables: int = 11,
n_centers: int = 5,
cluster_std: float = 1.0,
n_observations: int = 640,
random_state: _LegacyRandom = 0,
rng: SeedLike | RNGLike | None = None,
) -> AnnData:
"""Gaussian Blobs.

Expand All @@ -78,7 +80,7 @@ def blobs(
n_observations
Number of observations. By default, this is the same observation number
as in :func:`scanpy.datasets.krumsiek11`.
random_state
rng
Determines random number generation for dataset creation.

Returns
Expand All @@ -101,7 +103,7 @@ def blobs(
n_features=n_variables,
centers=n_centers,
cluster_std=cluster_std,
random_state=random_state,
random_state=legacy_random_state(rng),
)
return AnnData(x, obs=dict(blobs=y.astype(str)))

Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/experimental/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
doc_pca_chunk = """\
n_comps
Number of principal components to compute in the PCA step.
random_state
Random seed for setting the initial states for the optimization in the PCA step.
rng
Random number generator for setting the initial states for the optimization in the PCA step.
kwargs_pca
Dictionary of further keyword arguments passed on to `scanpy.pp.pca()`.
"""
Expand Down
7 changes: 5 additions & 2 deletions src/scanpy/experimental/pp/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ... import logging as logg
from ..._compat import CSBase, warn
from ..._utils import _doc_params, _empty, check_nonnegative_integers, view_to_actual
from ..._utils.random import accepts_legacy_random_state
from ...experimental._docs import (
doc_adata,
doc_check_values,
Expand All @@ -27,6 +28,7 @@
from typing import Any

from ..._utils import Empty
from ..._utils.random import RNGLike, SeedLike


def _pearson_residuals(
Expand Down Expand Up @@ -160,13 +162,14 @@ def normalize_pearson_residuals(
check_values=doc_check_values,
inplace=doc_inplace,
)
@accepts_legacy_random_state(0)
def normalize_pearson_residuals_pca(
adata: AnnData,
*,
theta: float = 100,
clip: float | None = None,
n_comps: int | None = 50,
random_state: float = 0,
rng: SeedLike | RNGLike | None = None,
kwargs_pca: Mapping[str, Any] = MappingProxyType({}),
mask_var: np.ndarray | str | None | Empty = _empty,
use_highly_variable: bool | None = None,
Expand Down Expand Up @@ -233,7 +236,7 @@ def normalize_pearson_residuals_pca(
normalize_pearson_residuals(
adata_pca, theta=theta, clip=clip, check_values=check_values
)
pca(adata_pca, n_comps=n_comps, random_state=random_state, **kwargs_pca)
pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca)
n_comps = adata_pca.obsm["X_pca"].shape[1] # might be None

if inplace:
Expand Down
9 changes: 7 additions & 2 deletions src/scanpy/experimental/pp/_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@
)
from scanpy.preprocessing import pca

from ..._utils.random import accepts_legacy_random_state

if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Any

import pandas as pd
from anndata import AnnData

from ..._utils.random import RNGLike, SeedLike


@_doc_params(
adata=doc_adata,
Expand All @@ -33,6 +37,7 @@
check_values=doc_check_values,
inplace=doc_inplace,
)
@accepts_legacy_random_state(0)
def recipe_pearson_residuals( # noqa: PLR0913
adata: AnnData,
*,
Expand All @@ -42,7 +47,7 @@ def recipe_pearson_residuals( # noqa: PLR0913
batch_key: str | None = None,
chunksize: int = 1000,
n_comps: int | None = 50,
random_state: float | None = 0,
rng: SeedLike | RNGLike | None = None,
kwargs_pca: Mapping[str, Any] = MappingProxyType({}),
check_values: bool = True,
inplace: bool = True,
Expand Down Expand Up @@ -133,7 +138,7 @@ def recipe_pearson_residuals( # noqa: PLR0913
experimental.pp.normalize_pearson_residuals(
adata_pca, theta=theta, clip=clip, check_values=check_values
)
pca(adata_pca, n_comps=n_comps, random_state=random_state, **kwargs_pca)
pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca)

if inplace:
normalization_param = adata_pca.uns["pearson_residuals_normalization"]
Expand Down
Loading
Loading