Skip to content

Commit 6804d93

Browse files
Add files via upload
1 parent c6bfdf6 commit 6804d93

1 file changed

Lines changed: 50 additions & 10 deletions

File tree

  • discretesampling/base/algorithms

discretesampling/base/algorithms/smc.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
from discretesampling.base.algorithms.smc_components.normalisation import normalise
88
from discretesampling.base.algorithms.smc_components.effective_sample_size import ess
99
from discretesampling.base.algorithms.smc_components.resampling import systematic_resampling
10+
from discretesampling.base.algorithms.smc_components.knapsack_resampling import knapsack_resampling
11+
from discretesampling.base.algorithms.smc_components.minError_ImportanceResampling import min_error_continuous_state_resampling, min_error_importance_resampling
12+
from discretesampling.base.algorithms.smc_components.variational_resampling import kl
13+
from discretesampling.base.algorithms.smc_components.importance_resampling_version3 import importance_resampling_v3
14+
from discretesampling.base.algorithms.smc_components.residual_resampling import residual_resampling
15+
16+
1017

1118

1219
class DiscreteVariableSMC():
@@ -37,26 +44,57 @@ def __init__(self, variableType, target, initialProposal, proposal=None,
3744
self.initialProposal = initialProposal
3845
self.target = target
3946

40-
def sample(self, Tsmc, N, seed=0, verbose=True):
47+
def sample(self, Tsmc, N,a, resampling, seed=0, verbose=True):
48+
4149
loc_n = int(N/self.exec.P)
50+
4251
rank = self.exec.rank
4352
mvrs_rng = RNG(seed)
4453
rngs = [RNG(i + rank*loc_n + 1 + seed) for i in range(loc_n)] # RNG for each particle
4554

4655
initialParticles = [self.initialProposal.sample(rngs[i], self.target) for i in range(loc_n)]
4756
current_particles = initialParticles
48-
logWeights = np.array([self.target.eval(p) - self.initialProposal.eval(p, self.target) for p in initialParticles])
57+
58+
logWeights = np.array([self.target.eval(p)[0] - self.initialProposal.eval(p, self.target) for p in initialParticles])
4959

5060
display_progress_bar = verbose and rank == 0
5161
progress_bar = tqdm(total=Tsmc, desc="SMC sampling", disable=not display_progress_bar)
52-
5362
for t in range(Tsmc):
63+
tot_new_possibilities_for_predictions = []
5464
logWeights = normalise(logWeights, self.exec)
5565
neff = ess(logWeights, self.exec)
5666

5767
if math.log(neff) < math.log(N) - math.log(2):
58-
current_particles, logWeights = systematic_resampling(
59-
current_particles, logWeights, mvrs_rng, exec=self.exec)
68+
69+
70+
if (resampling == "systematic"):
71+
current_particles, logWeights = systematic_resampling(
72+
current_particles, logWeights, mvrs_rng, exec=self.exec)
73+
74+
elif (resampling == "knapsack"):
75+
current_particles, logWeights, _ = knapsack_resampling(
76+
current_particles, np.exp(logWeights), mvrs_rng)
77+
78+
elif (resampling == "min_error"):
79+
current_particles, logWeights, _ = min_error_continuous_state_resampling(
80+
current_particles, np.exp(logWeights), mvrs_rng, N)
81+
82+
elif (resampling == "variational"):
83+
new_ancestors, logWeights = kl(logWeights)
84+
current_particles = np.array(current_particles)[new_ancestors].tolist()
85+
86+
elif (resampling == "min_error_imp"):
87+
current_particles, logWeights= min_error_importance_resampling(
88+
current_particles, np.exp(logWeights), mvrs_rng, N)
89+
90+
elif (resampling == "CIR"):
91+
current_particles, logWeights= importance_resampling_v3(
92+
current_particles, np.exp(logWeights), mvrs_rng, N)
93+
94+
elif (resampling == "residual"):
95+
current_particles, logWeights= residual_resampling(
96+
current_particles, np.exp(logWeights), mvrs_rng, N)
97+
6098

6199
new_particles = copy.copy(current_particles)
62100

@@ -65,7 +103,7 @@ def sample(self, Tsmc, N, seed=0, verbose=True):
65103
# Sample new particles and calculate forward probabilities
66104
for i in range(loc_n):
67105
forward_proposal = self.proposal
68-
new_particles[i] = forward_proposal.sample(current_particles[i], rng=rngs[i])
106+
new_particles[i] = forward_proposal.sample(current_particles[i], a, rng=rngs[i])
69107
forward_logprob[i] = forward_proposal.eval(current_particles[i], new_particles[i])
70108

71109
if self.use_optimal_L:
@@ -79,13 +117,15 @@ def sample(self, Tsmc, N, seed=0, verbose=True):
79117
Lkernel = self.Lkernel
80118
reverse_logprob = Lkernel.eval(new_particles[i], current_particles[i])
81119

82-
current_target_logprob = self.target.eval(current_particles[i])
83-
new_target_logprob = self.target.eval(new_particles[i])
120+
current_target_logprob, current_possibilities_for_predictions = self.target.eval(current_particles[i])
121+
new_target_logprob, new_possibilities_for_predictions = self.target.eval(new_particles[i])
84122

85123
logWeights[i] += new_target_logprob - current_target_logprob + reverse_logprob - forward_logprob[i]
86-
124+
125+
tot_new_possibilities_for_predictions.append(new_possibilities_for_predictions)
126+
#if t<Tsmc:
87127
current_particles = new_particles
88128
progress_bar.update(1)
89129

90130
progress_bar.close()
91-
return current_particles
131+
return current_particles,tot_new_possibilities_for_predictions, logWeights

0 commit comments

Comments
 (0)