77from discretesampling .base .algorithms .smc_components .normalisation import normalise
88from discretesampling .base .algorithms .smc_components .effective_sample_size import ess
99from discretesampling .base .algorithms .smc_components .resampling import systematic_resampling
10+ from discretesampling .base .algorithms .smc_components .knapsack_resampling import knapsack_resampling
11+ from discretesampling .base .algorithms .smc_components .minError_ImportanceResampling import min_error_continuous_state_resampling , min_error_importance_resampling
12+ from discretesampling .base .algorithms .smc_components .variational_resampling import kl
13+ from discretesampling .base .algorithms .smc_components .importance_resampling_version3 import importance_resampling_v3
14+ from discretesampling .base .algorithms .smc_components .residual_resampling import residual_resampling
15+
16+
1017
1118
1219class DiscreteVariableSMC ():
@@ -37,26 +44,57 @@ def __init__(self, variableType, target, initialProposal, proposal=None,
3744 self .initialProposal = initialProposal
3845 self .target = target
3946
40- def sample (self , Tsmc , N , seed = 0 , verbose = True ):
47+ def sample (self , Tsmc , N ,a , resampling , seed = 0 , verbose = True ):
48+
4149 loc_n = int (N / self .exec .P )
50+
4251 rank = self .exec .rank
4352 mvrs_rng = RNG (seed )
4453 rngs = [RNG (i + rank * loc_n + 1 + seed ) for i in range (loc_n )] # RNG for each particle
4554
4655 initialParticles = [self .initialProposal .sample (rngs [i ], self .target ) for i in range (loc_n )]
4756 current_particles = initialParticles
48- logWeights = np .array ([self .target .eval (p ) - self .initialProposal .eval (p , self .target ) for p in initialParticles ])
57+
58+ logWeights = np .array ([self .target .eval (p )[0 ] - self .initialProposal .eval (p , self .target ) for p in initialParticles ])
4959
5060 display_progress_bar = verbose and rank == 0
5161 progress_bar = tqdm (total = Tsmc , desc = "SMC sampling" , disable = not display_progress_bar )
52-
5362 for t in range (Tsmc ):
63+ tot_new_possibilities_for_predictions = []
5464 logWeights = normalise (logWeights , self .exec )
5565 neff = ess (logWeights , self .exec )
5666
5767 if math .log (neff ) < math .log (N ) - math .log (2 ):
58- current_particles , logWeights = systematic_resampling (
59- current_particles , logWeights , mvrs_rng , exec = self .exec )
68+
69+
70+ if (resampling == "systematic" ):
71+ current_particles , logWeights = systematic_resampling (
72+ current_particles , logWeights , mvrs_rng , exec = self .exec )
73+
74+ elif (resampling == "knapsack" ):
75+ current_particles , logWeights , _ = knapsack_resampling (
76+ current_particles , np .exp (logWeights ), mvrs_rng )
77+
78+ elif (resampling == "min_error" ):
79+ current_particles , logWeights , _ = min_error_continuous_state_resampling (
80+ current_particles , np .exp (logWeights ), mvrs_rng , N )
81+
82+ elif (resampling == "variational" ):
83+ new_ancestors , logWeights = kl (logWeights )
84+ current_particles = np .array (current_particles )[new_ancestors ].tolist ()
85+
86+ elif (resampling == "min_error_imp" ):
87+ current_particles , logWeights = min_error_importance_resampling (
88+ current_particles , np .exp (logWeights ), mvrs_rng , N )
89+
90+ elif (resampling == "CIR" ):
91+ current_particles , logWeights = importance_resampling_v3 (
92+ current_particles , np .exp (logWeights ), mvrs_rng , N )
93+
94+ elif (resampling == "residual" ):
95+ current_particles , logWeights = residual_resampling (
96+ current_particles , np .exp (logWeights ), mvrs_rng , N )
97+
6098
6199 new_particles = copy .copy (current_particles )
62100
@@ -65,7 +103,7 @@ def sample(self, Tsmc, N, seed=0, verbose=True):
65103 # Sample new particles and calculate forward probabilities
66104 for i in range (loc_n ):
67105 forward_proposal = self .proposal
68- new_particles [i ] = forward_proposal .sample (current_particles [i ], rng = rngs [i ])
106+ new_particles [i ] = forward_proposal .sample (current_particles [i ], a , rng = rngs [i ])
69107 forward_logprob [i ] = forward_proposal .eval (current_particles [i ], new_particles [i ])
70108
71109 if self .use_optimal_L :
@@ -79,13 +117,15 @@ def sample(self, Tsmc, N, seed=0, verbose=True):
79117 Lkernel = self .Lkernel
80118 reverse_logprob = Lkernel .eval (new_particles [i ], current_particles [i ])
81119
82- current_target_logprob = self .target .eval (current_particles [i ])
83- new_target_logprob = self .target .eval (new_particles [i ])
120+ current_target_logprob , current_possibilities_for_predictions = self .target .eval (current_particles [i ])
121+ new_target_logprob , new_possibilities_for_predictions = self .target .eval (new_particles [i ])
84122
85123 logWeights [i ] += new_target_logprob - current_target_logprob + reverse_logprob - forward_logprob [i ]
86-
124+
125+ tot_new_possibilities_for_predictions .append (new_possibilities_for_predictions )
126+ #if t<Tsmc:
87127 current_particles = new_particles
88128 progress_bar .update (1 )
89129
90130 progress_bar .close ()
91- return current_particles
131+ return current_particles , tot_new_possibilities_for_predictions , logWeights
0 commit comments