Skip to content

Commit aaa9bac

Browse files
feat: add flag to disable gathering of particles in smc
1 parent ff61ee4 commit aaa9bac

File tree

1 file changed

+4
-3
lines changed
  • discretesampling/base/algorithms

1 file changed

+4
-3
lines changed

discretesampling/base/algorithms/smc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self, variableType, target, initialProposal,
2929
self.initialProposal = initialProposal
3030
self.target = target
3131

32-
def sample(self, Tsmc, N, seed=0):
32+
def sample(self, Tsmc, N, seed=0, gather_results=True):
3333
loc_n = int(N/self.exec.P)
3434
rank = self.exec.rank
3535
mvrs_rng = RNG(seed)
@@ -75,8 +75,9 @@ def sample(self, Tsmc, N, seed=0):
7575

7676
current_particles = new_particles
7777

78-
current_particles = self.exec.gather_all(current_particles)
79-
logWeights = self.exec.gather(logWeights, [N])
78+
if gather_results:
79+
current_particles = self.exec.gather_all(current_particles)
80+
logWeights = self.exec.gather(logWeights, [N])
8081
results = SMCOutput(current_particles, logWeights)
8182

8283
return results

0 commit comments

Comments
 (0)