Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 141 additions & 27 deletions malariagen_data/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,20 @@
from functools import wraps
from inspect import getcallargs
from textwrap import dedent, fill
from typing import IO, Dict, Hashable, List, Mapping, Optional, Tuple, Union, Callable
from typing import (
IO,
Dict,
Hashable,
List,
Mapping,
Optional,
Tuple,
Union,
Callable,
Any,
NoReturn,
Sequence,
)
from urllib.parse import unquote_plus
from numpy.testing import assert_allclose, assert_array_equal

Expand Down Expand Up @@ -874,53 +887,114 @@ def _simple_xarray_concat(
# )


def _da_concat(arrays: Sequence[da.Array], **kwargs: Any) -> da.Array:
    """Concatenate dask arrays, skipping the concatenation entirely when
    only a single array is given.

    Parameters
    ----------
    arrays : sequence of dask.array.Array
        The arrays to concatenate.
    **kwargs : Any
        Extra keyword arguments forwarded to `dask.array.concatenate`.

    Returns
    -------
    dask.array.Array
        The concatenated array (or the sole input array, unchanged).

    Notes
    -----
    Assumes the arrays share compatible shapes along all axes other than
    the concatenation axis.
    """
    # Avoid growing the task graph when there is nothing to concatenate.
    if len(arrays) > 1:
        return da.concatenate(arrays, **kwargs)
    return arrays[0]


def _value_error(
name,
value,
expectation,
):
name: str,
value: object,
expectation: str,
) -> NoReturn:
"""Raise a ValueError with a standardized error message.

Parameters
----------
name : str
The name of the parameter with the bad value.
value : object
The invalid value that was provided.
expectation : str
Description of what was expected.

Returns
-------
None
This function does not return; it always raises an exception.

Raises
------
ValueError
Always raised with the formatted message.
"""
message = f"Bad value for parameter {name}; expected {expectation}, found {value!r}"
raise ValueError(message)


def _hash_params(params):
"""Helper function to hash function parameters."""
def _hash_params(params: Any) -> Tuple[str, str]:
"""Helper function to hash function parameters.

Parameters
----------
params : Any
The parameters to hash. Must be JSON-serializable.

Returns
-------
tuple of (str, str)
A tuple containing the MD5 hex digest and the JSON string representation.

Notes
-----
Uses `sort_keys=True` and `indent=4` so the same semantic object yields
a stable JSON string and hash (assuming deterministic JSON serialization).
"""
s = json.dumps(params, sort_keys=True, indent=4)
h = hashlib.md5(s.encode()).hexdigest()
return h, s


def _jitter(a, fraction, random_state=np.random):
"""Jitter data by adding uniform noise scaled by the data range.
def _jitter(
a: Union[np.ndarray, pd.Series],
fraction: float,
random_state=np.random,
) -> Union[np.ndarray, pd.Series]:
"""Jitter data in `a` using the fraction `fraction`.

Parameters
----------
a : array-like
Input data to jitter. Can be a numpy array or pandas Series.
a : ndarray or pandas.Series
The input array or Series to add jitter to.
fraction : float
Controls the amplitude of the jitter relative to the data range.
The fractional amount of the data range to use as the maximum noise scale.
random_state : numpy.random.Generator or module, optional
Random number generator to use. Accepts a ``numpy.random.Generator``
(from ``np.random.default_rng()``) or the ``numpy.random`` module.
Defaults to ``np.random`` (global RNG) for backward compatibility.

Returns
-------
array-like
Jittered copy of the input data with the same shape and type.
ndarray or pandas.Series
The input data with added uniform noise.

Notes
-----
The noise scale is calculated as `(a.max() - a.min()) * fraction`.
If the range is 0 (all values are identical), the function returns `a` unchanged
(because adding uniform(-0, 0) adds 0).
Prefer passing a local ``np.random.default_rng(seed=...)`` to avoid
mutating global RNG state and to ensure reproducibility.

"""
r = a.max() - a.min()
return a + fraction * random_state.uniform(-r, r, a.shape)
Expand Down Expand Up @@ -992,31 +1066,36 @@ def set_level(self, level):
self._handler.setLevel(level)


def _jackknife_ci(stat_data, jack_stat, confidence_level):
def _jackknife_ci(
stat_data: float,
jack_stat: Any,
confidence_level: float,
) -> Tuple[float, float, float, float, float, float]:
"""Compute a confidence interval from jackknife resampling.

Parameters
----------
stat_data : scalar
stat_data : float
Value of the statistic computed on all data.
jack_stat : ndarray
jack_stat : array-like
Values of the statistic computed for each jackknife resample.
Can be a 1D ndarray or a list of floats.
confidence_level : float
Desired confidence level (e.g., 0.95).
Desired confidence level (e.g., 0.95). Must be in the interval (0, 1).

Returns
-------
estimate
Bias-corrected "jackknifed estimate".
bias
estimate : float
Bias-corrected jackknifed estimate.
bias : float
Jackknife bias.
std_err
std_err : float
Standard error.
ci_err
ci_err : float
Size of the confidence interval.
ci_low
ci_low : float
Lower limit of confidence interval.
ci_upp
ci_upp : float
Upper limit of confidence interval.

Notes
Expand Down Expand Up @@ -1232,6 +1311,24 @@ def check_types_wrapper(*args, **kwargs):

@numba.njit
def _true_runs(a):
"""Find contiguous runs of True values in a boolean array.

Parameters
----------
a : ndarray
A 1D array-like object. Truthy values are treated as `True`.

Returns
-------
starts : ndarray
A 1D array of int64 containing the inclusive start indices of each run.
stops : ndarray
A 1D array of int64 containing the exclusive stop indices of each run.

Notes
-----
The returned `starts` and `stops` define half-open intervals `[start, stop)`.
"""
in_run = False
starts = []
stops = []
Expand All @@ -1250,6 +1347,23 @@ def _true_runs(a):

@numba.njit(parallel=True)
def _pdist_abs_hamming(X):
"""Calculate the pairwise absolute Hamming distance between rows.

Parameters
----------
X : ndarray
A 2D array of shape `(n_obs, n_ftr)` containing the observations.

Returns
-------
out : ndarray
A 2D array of shape `(n_obs, n_obs)` and dtype `int32`.
`out[i, j]` is the number of positions where `X[i, k] != X[j, k]`.

Notes
-----
Performance: `O(n_obs^2 * n_ftr)`. Uses Numba parallelization.
"""
n_obs = X.shape[0]
n_ftr = X.shape[1]
out = np.zeros((n_obs, n_obs), dtype=np.int32)
Expand Down
Loading