bootstrap.py
from typing import Callable, Optional

import numpy as np


def bootstrap(
    data: np.ndarray,
    stat_fn: Callable[..., float],
    *,
    bootstrap: int = 1,
    seed: Optional[int] = None,
) -> tuple[float, np.ndarray]:
    """Summarize a data vector and estimate a 95% confidence interval.

    Input:
        data: np.ndarray
            One-dimensional array containing the data to be summarized.
        stat_fn: Callable[...]
            Function that summarizes the data. It must accept the data
            vector (or matrix) as a positional argument and an "axis"
            keyword argument that selects the axis along which the
            summary is computed. Most NumPy statistical functions
            (e.g. np.std, np.mean) can be passed without modification.
        bootstrap: optional, int (default: 1)
            Number of bootstrap resamples used to estimate the
            confidence interval. Values below 1 are treated as 1. With
            a single resample both interval bounds are equal, so use
            1000 or more resamples when the interval matters.
        seed: optional, int (default: None)
            Seed for the NumPy random number generator.

    Output:
        Summary statistic (float)
        Confidence interval (np.ndarray holding the 2.5th and 97.5th
        percentiles of the bootstrapped statistic)

    Examples:
        ```
        >> import numpy as np
        >> from bootstrap import bootstrap
        >> np.random.seed(123)
        >> x = np.random.randn(1000)
        >> bootstrap(x,
                     lambda data, axis=None: np.std(data, axis=axis, ddof=1),
                     bootstrap=1000)
        (1.001288306893338, array([0.96097997, 1.04474693]))
        ```

        Here stat_fn is a lambda because we want to pass `ddof=1` to
        `np.std`. The lambda forwards `axis` because the resampled
        matrix is summarized with `axis=1`. No `seed` is passed, so the
        interval bounds vary slightly between runs. If the default
        `ddof` is acceptable, `np.std` can be passed directly:

        ```
        >> bootstrap(x, np.std, bootstrap=1000)
        (1.0007875375162334, array([0.95658118, 1.04256825]))
        ```
    """
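    # Always draw at least one resample.
    if bootstrap < 1:
        bootstrap = 1

    rng = np.random.default_rng(seed)
    # Resample with replacement: one row per bootstrap resample.
    bootstrap_samples = rng.choice(data, size=(bootstrap, len(data)))
    # Evaluate the statistic on every resample (row-wise).
    bootstrap_vals = stat_fn(bootstrap_samples, axis=1)
    # Point estimate on the original data plus the 95% percentile interval.
    return stat_fn(data), np.percentile(bootstrap_vals, [2.5, 97.5])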
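
The function above hard-codes the 2.5th and 97.5th percentiles. As a minimal sketch of the same percentile-bootstrap steps with an adjustable confidence level, consider the block below. The helper name `percentile_ci` and its `n_resamples` and `level` parameters are hypothetical and not part of this module; only the resample, summarize, and percentile steps are taken from the code above.

```python
import numpy as np


def percentile_ci(data, stat_fn, n_resamples=1000, level=0.95, seed=None):
    """Hypothetical helper: percentile bootstrap interval for stat_fn."""
    data = np.asarray(data)
    rng = np.random.default_rng(seed)
    # Draw n_resamples rows, each a with-replacement resample of data.
    resamples = rng.choice(data, size=(n_resamples, data.size))
    # Apply the statistic along each row (stat_fn must accept `axis`).
    stats = stat_fn(resamples, axis=1)
    # Cut equal tails off both ends of the bootstrap distribution.
    tail = 100 * (1 - level) / 2
    return stat_fn(data), np.percentile(stats, [tail, 100 - tail])


# Illustrative data; passing a seed makes the interval reproducible.
x = np.random.default_rng(0).normal(size=500)
point, (low, high) = percentile_ci(x, np.mean, seed=1)
```

Passing `seed` fixes the resamples, so repeated calls return identical bounds; leaving it as `None` gives slightly different bounds on every run.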