Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/fast-geography-collisions.changed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Speed up cloned household geography assignment for large local-area calibration builds.
33 changes: 16 additions & 17 deletions policyengine_us_data/calibration/clone_and_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,36 +296,35 @@ def _sample(size, mask_slice=None, fixed_slice=None):
# Clone 0: unrestricted draw
indices[:n_records] = _sample(n_records, extreme_mask, fixed_states)

assigned_cds = np.empty((n_clones, n_records), dtype=object)
assigned_cds[0] = cds[indices[:n_records]]
_, cd_codes = np.unique(cds, return_inverse=True)
cd_codes = cd_codes.astype(np.int32, copy=False)
record_positions = np.arange(n_records)
used_cd_by_record = np.zeros((n_records, cd_codes.max() + 1), dtype=bool)
used_cd_by_record[record_positions, cd_codes[indices[:n_records]]] = True

for clone_idx in range(1, n_clones):
start = clone_idx * n_records
clone_indices = _sample(n_records, extreme_mask, fixed_states)
clone_cds = cds[clone_indices]

collisions = np.zeros(n_records, dtype=bool)
for prev in range(clone_idx):
collisions |= clone_cds == assigned_cds[prev]
clone_cd_codes = cd_codes[clone_indices]
collisions = used_cd_by_record[record_positions, clone_cd_codes]

for _ in range(50):
n_bad = collisions.sum()
n_bad = int(collisions.sum())
if n_bad == 0:
break
bad_mask = collisions
if extreme_mask is not None and agi_probs is not None:
replacement = _sample(n_records, extreme_mask, fixed_states)
clone_indices[bad_mask] = replacement[bad_mask]
fixed_bad = fixed_states[bad_mask] if fixed_states is not None else None
replacement = _sample(n_bad, extreme_mask[bad_mask], fixed_bad)
else:
replacement = _sample(n_records, fixed_slice=fixed_states)
clone_indices[collisions] = replacement[collisions]
clone_cds = cds[clone_indices]
collisions = np.zeros(n_records, dtype=bool)
for prev in range(clone_idx):
collisions |= clone_cds == assigned_cds[prev]
fixed_bad = fixed_states[bad_mask] if fixed_states is not None else None
replacement = _sample(n_bad, fixed_slice=fixed_bad)
clone_indices[bad_mask] = replacement
clone_cd_codes = cd_codes[clone_indices]
collisions = used_cd_by_record[record_positions, clone_cd_codes]

indices[start : start + n_records] = clone_indices
assigned_cds[clone_idx] = clone_cds
used_cd_by_record[record_positions, clone_cd_codes] = True

assigned_blocks = blocks[indices]
return GeographyAssignment(
Expand Down
Loading