1 change: 1 addition & 0 deletions changelog.d/964.fixed
@@ -0,0 +1 @@
Remove temporary CPS take-up source anchors from persisted H5 outputs and add an enhanced-CPS-only data build path.
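Note: the enhanced-CPS-only path mentioned above is driven by the new skip_stage_5 parameter threaded through modal_app/data_build.py below. A minimal sketch of the intended call, assuming the earlier parameters of build_datasets (elided behind a fold in this diff) keep their defaults:

# skip_enhanced_cps=True -> calibration-style run: enhanced_cps.py and
#                           small_enhanced_cps.py never run.
# skip_stage_5=True      -> enhanced-CPS-only run: enhanced_cps_2024.h5 is
#                           built, then Phase 5 (create_source_imputed_cps.py,
#                           small_enhanced_cps.py) is skipped.
build_datasets(
    skip_enhanced_cps=False,
    skip_stage_5=True,
)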
76 changes: 50 additions & 26 deletions modal_app/data_build.py
@@ -117,6 +117,7 @@ def snapshot(self) -> dict[str, int]:
# enhanced_cps.py produces both the dataset and calibration log
"policyengine_us_data/datasets/cps/enhanced_cps.py": [
"policyengine_us_data/storage/enhanced_cps_2024.h5",
"policyengine_us_data/storage/enhanced_cps_2024.clone_diagnostics.json",
"calibration_log.csv",
],
"policyengine_us_data/calibration/create_stratified_cps.py": (
@@ -317,12 +318,15 @@ def validate_and_maybe_upload_datasets(
upload: bool,
skip_enhanced_cps: bool,
env: dict,
require_small_enhanced_cps: bool = True,
stage_only: bool = False,
run_id: str = "",
) -> None:
validation_args = ["--validate-only"]
if skip_enhanced_cps:
validation_args.append("--no-require-enhanced-cps")
elif not require_small_enhanced_cps:
validation_args.append("--no-require-small-enhanced-cps")

print("=== Validating built datasets ===")
run_script(
@@ -335,6 +339,8 @@
upload_args = []
if skip_enhanced_cps:
upload_args.append("--no-require-enhanced-cps")
elif not require_small_enhanced_cps:
upload_args.append("--no-require-small-enhanced-cps")
if stage_only:
upload_args.append("--stage-only")
if run_id:
@@ -504,6 +510,7 @@ def write_dataset_build_contract(
upload_requested: bool,
stage_only: bool,
skip_enhanced_cps: bool,
skip_stage_5: bool = False,
) -> StageContract:
"""Write the Stage 1 semantic handoff contract next to copied artifacts."""
contract = build_dataset_build_output_contract(
@@ -518,6 +525,7 @@
upload_requested=upload_requested,
stage_only=stage_only,
skip_enhanced_cps=skip_enhanced_cps,
skip_stage_5=skip_stage_5,
)
write_contract(
contract,
@@ -559,6 +567,7 @@ def build_datasets(
clear_checkpoints: bool = False,
skip_tests: bool = False,
skip_enhanced_cps: bool = False,
skip_stage_5: bool = False,
stage_only: bool = False,
run_id: str = "",
):
@@ -572,6 +581,8 @@
skip_tests: Skip running the test suite (useful for calibration runs).
skip_enhanced_cps: Skip enhanced_cps.py and small_enhanced_cps.py
(useful for calibration runs that only need source_imputed H5).
skip_stage_5: Skip source-imputed CPS and small enhanced CPS after
enhanced_cps_2024.h5 is built.
stage_only: Upload to HF staging only, without promoting a release.
"""
setup_gcp_credentials()
@@ -645,6 +656,12 @@ def build_datasets(

if sequential:
for script, output in SCRIPT_OUTPUTS.items():
if skip_stage_5 and script in (
"policyengine_us_data/calibration/create_source_imputed_cps.py",
"policyengine_us_data/datasets/cps/small_enhanced_cps.py",
):
print(f"Skipping {script} (--skip-stage-5)")
continue
if skip_enhanced_cps and script in (
"policyengine_us_data/datasets/cps/enhanced_cps.py",
"policyengine_us_data/datasets/cps/small_enhanced_cps.py",
@@ -761,33 +778,21 @@ def build_datasets(
# GROUP 4: After Phase 4 - run in parallel
# create_source_imputed_cps needs stratified_cps
# small_enhanced_cps needs enhanced_cps
-        print(
-            "=== Phase 5: Building source imputed CPS "
-            "and small enhanced CPS (parallel) ==="
-        )
-        phase5_futures = []
-        with ThreadPoolExecutor(max_workers=2) as executor:
-            phase5_futures.append(
-                executor.submit(
-                    run_script_with_checkpoint,
-                    "policyengine_us_data/calibration/create_source_imputed_cps.py",
-                    SCRIPT_OUTPUTS[
-                        "policyengine_us_data/calibration/create_source_imputed_cps.py"
-                    ],
-                    branch,
-                    checkpoint_volume,
-                    env=env,
-                    log_file=log_file,
-                    checkpoint_stats=checkpoint_stats,
-                )
-            )
-            if not skip_enhanced_cps:
-                phase5_futures.append(
-                    executor.submit(
-                        run_script_with_checkpoint,
-                        "policyengine_us_data/datasets/cps/small_enhanced_cps.py",
-                        SCRIPT_OUTPUTS[
-                            "policyengine_us_data/datasets/cps/small_enhanced_cps.py"
-                        ],
-                        branch,
-                        checkpoint_volume,
-                        env=env,
-                        log_file=log_file,
-                        checkpoint_stats=checkpoint_stats,
-                    )
-                )
-            else:
-                print("Skipping small_enhanced_cps.py (--skip-enhanced-cps)")
-            for future in as_completed(phase5_futures):
-                future.result()
+        if skip_stage_5:
+            print("Skipping Phase 5 (--skip-stage-5)")
+        else:
+            print(
+                "=== Phase 5: Building source imputed CPS "
+                "and small enhanced CPS (parallel) ==="
+            )
+            phase5_futures = []
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                phase5_futures.append(
+                    executor.submit(
+                        run_script_with_checkpoint,
+                        "policyengine_us_data/calibration/create_source_imputed_cps.py",
+                        SCRIPT_OUTPUTS[
+                            "policyengine_us_data/calibration/create_source_imputed_cps.py"
+                        ],
+                        branch,
+                        checkpoint_volume,
+                        env=env,
+                        log_file=log_file,
+                        checkpoint_stats=checkpoint_stats,
+                    )
+                )
+                if not skip_enhanced_cps:
+                    phase5_futures.append(
+                        executor.submit(
+                            run_script_with_checkpoint,
+                            "policyengine_us_data/datasets/cps/small_enhanced_cps.py",
+                            SCRIPT_OUTPUTS[
+                                "policyengine_us_data/datasets/cps/small_enhanced_cps.py"
+                            ],
+                            branch,
+                            checkpoint_volume,
+                            env=env,
+                            log_file=log_file,
+                            checkpoint_stats=checkpoint_stats,
+                        )
+                    )
+                else:
+                    print("Skipping small_enhanced_cps.py (--skip-enhanced-cps)")
+                for future in as_completed(phase5_futures):
+                    future.result()

# Checkpoint the build log so it survives preemption
log_file.flush()
@@ -857,6 +877,7 @@ def build_datasets(
upload_requested=upload,
stage_only=stage_only,
skip_enhanced_cps=skip_enhanced_cps,
skip_stage_5=skip_stage_5,
)
pipeline_volume.commit()
print("Pipeline artifacts committed to shared volume")
@@ -871,6 +892,7 @@
validate_and_maybe_upload_datasets(
upload=upload,
skip_enhanced_cps=skip_enhanced_cps,
require_small_enhanced_cps=not skip_stage_5,
env=env,
stage_only=stage_only,
run_id=run_id,
@@ -890,6 +912,7 @@ def main(
clear_checkpoints: bool = False,
skip_tests: bool = False,
skip_enhanced_cps: bool = False,
skip_stage_5: bool = False,
stage_only: bool = False,
run_id: str = "",
):
@@ -905,6 +928,7 @@
clear_checkpoints=clear_checkpoints,
skip_tests=skip_tests,
skip_enhanced_cps=skip_enhanced_cps,
skip_stage_5=skip_stage_5,
stage_only=stage_only,
run_id=run_id,
)
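To summarize the plumbing above: skip_stage_5 at the entrypoint becomes require_small_enhanced_cps=not skip_stage_5 in validate_and_maybe_upload_datasets, which maps to the --no-require-small-enhanced-cps flag for both the validation and upload passes. A condensed, self-contained sketch of that mapping (the helper name is illustrative, not part of this PR):

def _validation_args(
    skip_enhanced_cps: bool, require_small_enhanced_cps: bool
) -> list[str]:
    args = ["--validate-only"]
    if skip_enhanced_cps:
        # No enhanced CPS was built, so neither requirement can hold.
        args.append("--no-require-enhanced-cps")
    elif not require_small_enhanced_cps:
        # Enhanced CPS exists, but Phase 5 (and the small variant) was skipped.
        args.append("--no-require-small-enhanced-cps")
    return args

# skip_stage_5=True implies require_small_enhanced_cps=False:
assert _validation_args(False, False) == [
    "--validate-only",
    "--no-require-small-enhanced-cps",
]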
16 changes: 15 additions & 1 deletion policyengine_us_data/datasets/cps/cps.py
@@ -450,6 +450,16 @@ def add_rent(self, cps: h5py.File, person: DataFrame, household: DataFrame):
cps["real_estate_taxes"][mask] = imputed_values["real_estate_taxes"]


TEMPORARY_TAKEUP_SOURCE_ANCHORS = ("snap_reported", "ssi_reported")


def _drop_persisted_dataset_variables(file_path, variable_names):
with h5py.File(file_path, "a") as dataset_file:
for variable_name in variable_names:
if variable_name in dataset_file:
del dataset_file[variable_name]


@pipeline_node(
PipelineNode(
id="add_takeup",
@@ -636,10 +646,14 @@ def add_takeup(self):
data["age"],
)

-for source_anchor in ("snap_reported", "ssi_reported"):
+for source_anchor in TEMPORARY_TAKEUP_SOURCE_ANCHORS:
data.pop(source_anchor, None)

self.save_dataset(data)
_drop_persisted_dataset_variables(
self.file_path,
TEMPORARY_TAKEUP_SOURCE_ANCHORS,
)


def add_marketplace_plan_benchmark_ratio(self):
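The pop/delete pairing above matters because add_takeup removes the anchors from the in-memory dict before save_dataset runs, while _drop_persisted_dataset_variables scrubs any copies already written to disk (assuming save_dataset does not rewrite the file from scratch). A standalone demonstration of the h5py deletion pattern, with an illustrative file name:

import h5py
import numpy as np

with h5py.File("demo.h5", "w") as f:
    f["snap_reported"] = np.zeros(3)  # temporary anchor, to be scrubbed
    f["snap"] = np.ones(3)            # real output, must survive

with h5py.File("demo.h5", "a") as f:  # append mode edits the file in place
    for name in ("snap_reported", "ssi_reported"):
        if name in f:  # ssi_reported is absent here; the guard avoids a KeyError
            del f[name]
    assert "snap_reported" not in f and "snap" in f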
14 changes: 14 additions & 0 deletions policyengine_us_data/stage_contracts/dataset_build.py
@@ -31,6 +31,7 @@ class _Stage1ArtifactSpec:
required_for_stage_2: bool = False
yearless_alias: bool = False
skip_when_enhanced_cps_skipped: bool = False
skip_when_stage_5_skipped: bool = False


_STAGE_1_ARTIFACTS: tuple[_Stage1ArtifactSpec, ...] = (
@@ -84,6 +85,7 @@ class _Stage1ArtifactSpec:
period=2024,
substage_id="1d_enhanced_cps_reweighting",
skip_when_enhanced_cps_skipped=True,
skip_when_stage_5_skipped=True,
),
_Stage1ArtifactSpec(
filename="stratified_extended_cps_2024.h5",
@@ -99,6 +101,7 @@
period=2024,
substage_id="1f_source_imputation",
required_for_stage_2=True,
skip_when_stage_5_skipped=True,
),
_Stage1ArtifactSpec(
filename="source_imputed_stratified_extended_cps.h5",
Expand All @@ -108,6 +111,7 @@ class _Stage1ArtifactSpec:
substage_id="1f_source_imputation",
required_for_stage_2=True,
yearless_alias=True,
skip_when_stage_5_skipped=True,
),
_Stage1ArtifactSpec(
filename="policy_data.db",
Expand Down Expand Up @@ -154,19 +158,22 @@ def build_dataset_build_output_contract(
upload_requested: bool = False,
stage_only: bool = False,
skip_enhanced_cps: bool = False,
skip_stage_5: bool = False,
) -> StageContract:
"""Build the Stage 1 handoff contract from copied pipeline artifacts."""

artifacts_dir = Path(artifacts_dir)
parameters = {
"period": 2024,
"skip_enhanced_cps": skip_enhanced_cps,
"skip_stage_5": skip_stage_5,
"stage_only": stage_only,
"upload_requested": upload_requested,
}
outputs = _stage_1_outputs(
artifacts_dir=artifacts_dir,
skip_enhanced_cps=skip_enhanced_cps,
skip_stage_5=skip_stage_5,
)
execution = _execution_record(
checkpoint_stats=checkpoint_stats,
@@ -193,6 +200,7 @@
substages=_stage_1_substages(
outputs=outputs,
skip_enhanced_cps=skip_enhanced_cps,
skip_stage_5=skip_stage_5,
),
execution=execution,
metadata={
@@ -207,12 +215,15 @@ def _stage_1_outputs(
*,
artifacts_dir: Path,
skip_enhanced_cps: bool,
skip_stage_5: bool,
) -> tuple[ArtifactRef, ...]:
outputs: list[ArtifactRef] = []
missing_required: list[str] = []
for spec in _STAGE_1_ARTIFACTS:
if skip_enhanced_cps and spec.skip_when_enhanced_cps_skipped:
continue
if skip_stage_5 and spec.skip_when_stage_5_skipped:
continue
artifact_path = artifacts_dir / spec.filename
if not artifact_path.exists():
if spec.required:
@@ -276,6 +287,7 @@ def _stage_1_substages(
*,
outputs: tuple[ArtifactRef, ...],
skip_enhanced_cps: bool,
skip_stage_5: bool,
) -> tuple[SubstageRecord, ...]:
output_by_substage: dict[str, list[ArtifactRef]] = {
substage_id: [] for substage_id in _SUBSTAGE_IDS
@@ -290,6 +302,8 @@
status = "completed"
if substage_id == "1d_enhanced_cps_reweighting" and skip_enhanced_cps:
status = "skipped"
if substage_id == "1f_source_imputation" and skip_stage_5:
status = "skipped"
reuse_mode = "checkpointable"
if substage_id in {
"1a_raw_data_download",
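A minimal reproduction of the contract-side filtering added above, with a simplified stand-in for _Stage1ArtifactSpec (filenames here are illustrative; the real list lives in _STAGE_1_ARTIFACTS):

from dataclasses import dataclass

@dataclass(frozen=True)
class Spec:
    filename: str
    skip_when_enhanced_cps_skipped: bool = False
    skip_when_stage_5_skipped: bool = False

specs = [
    Spec("enhanced_cps_2024.h5"),
    Spec(
        "small_enhanced_cps_2024.h5",
        skip_when_enhanced_cps_skipped=True,
        skip_when_stage_5_skipped=True,
    ),
    Spec("source_imputed_stratified_extended_cps.h5", skip_when_stage_5_skipped=True),
]

def expected_outputs(specs, *, skip_enhanced_cps=False, skip_stage_5=False):
    # Mirrors the two `continue` guards in _stage_1_outputs.
    return [
        s.filename
        for s in specs
        if not (skip_enhanced_cps and s.skip_when_enhanced_cps_skipped)
        and not (skip_stage_5 and s.skip_when_stage_5_skipped)
    ]

# With skip_stage_5, only the full enhanced CPS is still expected:
assert expected_outputs(specs, skip_stage_5=True) == ["enhanced_cps_2024.h5"]

The same flags flip the matching substage records to status "skipped" in _stage_1_substages, so the handoff contract records those artifacts as intentionally skipped rather than missing.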