-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdistributions.py
More file actions
71 lines (52 loc) · 1.97 KB
/
distributions.py
File metadata and controls
71 lines (52 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
class Distribution:
def sample(self, size: int | tuple) -> np.ndarray:
raise NotImplementedError
class Gaussian(Distribution):
    """Multivariate normal distribution sampled through a caller-supplied generator."""

    def __init__(self, mean: np.ndarray, cov: np.ndarray, rng: np.random.Generator):
        """Keep the distribution parameters for later sampling.

        Args:
            mean: Mean vector forwarded to ``rng.multivariate_normal``.
            cov: Covariance matrix forwarded to ``rng.multivariate_normal``.
            rng: Generator that supplies all randomness for :meth:`sample`.
        """
        self.mean, self.cov, self.rng = mean, cov, rng

    def sample(self, size: int) -> np.ndarray:
        """Return draws from N(mean, cov).

        Args:
            size: Forwarded unchanged as the ``size`` argument of
                ``multivariate_normal``.
        """
        params = (self.mean, self.cov)
        return self.rng.multivariate_normal(*params, size=size)
class EasyDistribution(Distribution):
    """Groups of three uniform points, each drawn from its own fixed interval.

    Within every group, the first point's entries are uniform on [0, 1), the
    second's on [1, 2) and the third's on [10, 11), so :meth:`sample` returns
    an array of shape ``(size[0], 3, dim)``.
    """

    def __init__(self, dim, rng=None) -> None:
        """Configure the distribution.

        Args:
            dim: Length of each point vector.
            rng: Optional NumPy ``Generator`` (or ``RandomState``) used for all
                draws. When omitted, the global ``np.random`` state is used at
                sampling time, so ``np.random.seed`` keeps controlling results
                exactly as before this parameter existed.
        """
        self.dim = dim
        self.rng = rng

    def sample(self, size: int | tuple) -> np.ndarray:
        """Draw ``size[0]`` groups of three uniform points.

        Args:
            size: Tuple whose first entry is the number of groups; any further
                entries are ignored.

        Returns:
            Float array of shape ``(size[0], 3, self.dim)``, also when
            ``size[0]`` is 0.

        Raises:
            NotImplementedError: If ``size`` is not a tuple.
        """
        if not isinstance(size, tuple):
            raise NotImplementedError("EasyDistribution.sample requires a tuple size")
        rng = np.random if self.rng is None else self.rng
        # Lower bound of each point's unit-width interval, shaped (3, 1) so it
        # broadcasts over the coordinate axis.
        lows = np.array([[0.0], [1.0], [10.0]])
        # A single broadcast call fills the output in C order (group, point,
        # coordinate), consuming the random stream in the same order as one
        # call per point would.
        return rng.uniform(lows, lows + 1.0, size=(size[0], 3, self.dim))
class ThreeGaussian(Distribution):
    """Groups of points drawn from a fixed three-component 2-D Gaussian mixture.

    Mixture weights are 0.3 / 0.4 / 0.3, component means are (0, 0), (1, 1)
    and (2, 0), and every component has an identity covariance.
    """

    # Number of mixture points drawn for each group returned by ``sample``.
    points_per_sample = 3

    def __init__(self, rng) -> None:
        """Set up the mixture parameters.

        Args:
            rng: NumPy random generator providing ``choice`` and
                ``multivariate_normal``; it supplies all randomness.
        """
        self.rng = rng
        self.weights = [0.3, 0.4, 0.3]
        self.means = np.array([[0, 0], [1, 1], [2, 0]])
        self.covs = [np.eye(2) for _ in range(3)]

    def sample(self, size) -> np.ndarray:
        """Draw ``size[0]`` groups of mixture points.

        Each point independently picks a component according to
        ``self.weights`` and is then drawn from that component's Gaussian.

        Args:
            size: Sequence whose first entry is the number of groups; any
                further entries are ignored.

        Returns:
            Float array of shape ``(size[0], points_per_sample, 2)``, also
            when ``size[0]`` is 0.
        """
        n_samples = size[0]
        n_components, dim = self.means.shape
        # Nested comprehension keeps the original draw order (group by group,
        # point by point), so seeded results are unchanged.
        samples = [
            [self._draw_point(n_components) for _ in range(self.points_per_sample)]
            for _ in range(n_samples)
        ]
        # The reshape gives an empty batch the same rank as a non-empty one
        # (np.array([]) alone would have shape (0,)).
        return np.array(samples, dtype=float).reshape(
            n_samples, self.points_per_sample, dim
        )

    def _draw_point(self, n_components):
        """Pick a mixture component, then draw one point from it."""
        component = self.rng.choice(n_components, p=self.weights)
        return self.rng.multivariate_normal(
            self.means[component], self.covs[component]
        )