Skip to content

Commit ecfd2ff

Browse files
feat: add flag to disable gathering of particles in smc
1 parent 69955a8 commit ecfd2ff

File tree

1 file changed

+5
-3
lines changed
  • discretesampling/base/algorithms

1 file changed

+5
-3
lines changed

discretesampling/base/algorithms/smc.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, variableType, target, initialProposal,
3030
self.initialProposal = initialProposal
3131
self.target = target
3232

33-
def sample(self, Tsmc, N, seed=0, verbose=True):
33+
def sample(self, Tsmc, N, seed=0, gather_results=True, verbose=True):
3434
loc_n = int(N/self.exec.P)
3535
rank = self.exec.rank
3636
mvrs_rng = RNG(seed)
@@ -80,8 +80,10 @@ def sample(self, Tsmc, N, seed=0, verbose=True):
8080
current_particles = new_particles
8181
progress_bar.update(1)
8282

83-
current_particles = self.exec.gather_all(current_particles)
84-
logWeights = self.exec.gather(logWeights, [N])
83+
if gather_results:
84+
current_particles = self.exec.gather_all(current_particles)
85+
logWeights = self.exec.gather(logWeights, [N])
86+
8587
results = SMCOutput(current_particles, logWeights)
8688
progress_bar.close()
8789
return results

0 commit comments

Comments
 (0)