-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathexample_bernoulli_bandit.py
More file actions
49 lines (39 loc) · 1.47 KB
/
example_bernoulli_bandit.py
File metadata and controls
49 lines (39 loc) · 1.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Parallel-Bayesian-Optimization-Thompson-Sampling is free software:
# you can redistribute it and/or modify it under the terms of the MIT
# License. You should have received a copy of the MIT License along with
# Parallel-Bayesian-Optimization-Thompson-Sampling.
# If not, see <https://opensource.org/licenses/MIT>.
#
import sys

import numpy as np

# The search path must be extended *before* the project import below:
# appending to sys.path after the import has already run cannot help it
# resolve. NOTE(review): presumably `thompson_sampling` lives one directory
# above the working directory this example is launched from - confirm.
sys.path.append("..")

from thompson_sampling import ThompsonSampling, BernoulliThompsonStrategy  # noqa: E402
# Example Bernoulli Bandit
# ------------
# In this example we see how to use Thompson Sampling
# to solve the Bernoulli Bandit problem
def fitness(arm: int):
    """Pull one arm of a four-armed Bernoulli bandit.

    Each arm has a fixed success probability (0.1, 0.3, 0.9 and 0.4 for
    arms 0 to 3). A single uniform sample from ``np.random.rand`` decides
    the outcome of the pull.

    Args:
        arm: Index of the arm to pull; used to index the probability table.

    Returns:
        A ``(reward, arm)`` tuple where ``reward`` is 1 when the uniform
        sample falls below the arm's success probability and 0 otherwise,
        and ``arm`` is the index that was passed in.
    """
    success_probabilities = (0.1, 0.3, 0.9, 0.4)
    # Draw first, then look up the threshold, so exactly one random number
    # is consumed per call.
    sample = np.random.rand()
    reward = int(sample < success_probabilities[arm])
    return reward, arm
def on_update(epoch, strategy, last_epoch):
    """Print progress for the current epoch and report whether to finish.

    A progress line with ``strategy.mean`` is printed whenever ``epoch`` is
    a multiple of 10. The search is considered finished when any entry of
    ``strategy.mean`` exceeds 0.9 or when ``last_epoch`` is truthy; in that
    case a final summary (mean and ``strategy.selections``) is printed.

    Args:
        epoch: Current epoch number.
        strategy: Object exposing ``mean`` (compared element-wise against
            0.9) and ``selections``.
        last_epoch: Truthy when this is the final epoch.

    Returns:
        ``True`` when the final summary was printed, ``False`` otherwise.
        NOTE(review): presumably the caller stops the search on ``True`` -
        confirm against ``ThompsonSampling``.
    """
    periodic_report_due = epoch % 10 == 0
    if periodic_report_due:
        print("Epoch {} - Parameters mean {}".format(epoch, strategy.mean))

    finished = np.any(strategy.mean > 0.9) or last_epoch
    if not finished:
        return False

    # Final summary: emitted once the threshold is crossed or epochs run out.
    print("Search is over!")
    print("Epoch {} - Parameters mean {}".format(epoch, strategy.mean))
    print("Epoch {} - Number of times an arm as selected {}".format(epoch, strategy.selections))
    return True
if __name__ == "__main__":
    # Wire the bandit's reward function and the progress callback into the
    # Thompson Sampling driver, then run it for at most 100 epochs on a
    # single processor.
    # NOTE(review): BernoulliThompsonStrategy(4, 1, 1) presumably means four
    # arms with a Beta(1, 1) prior - confirm against the strategy class.
    ts = ThompsonSampling(
        thompson_strategy=BernoulliThompsonStrategy(4, 1, 1),
        epochs=100,
        fitness_function=fitness,
        callbacks={"on_update": on_update},
        num_processors=1,
    )
    # Keep whatever the run returns under the same name the script used for
    # the initial strategy.
    strategy = ts.run()