Skip to content
1 change: 1 addition & 0 deletions changelog.d/951.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add worker-scoped setup and validation context contracts for local H5 builds.
88 changes: 64 additions & 24 deletions modal_app/local_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,34 @@ def _build_worker_bootstrap(
return bundle


def _build_worker_calibration_inputs(
*,
weights_path: Path,
geography_path: Path,
dataset_path: Path,
db_path: Path,
n_clones: int,
seed: int,
run_config_path: Path | None = None,
calibration_package_path: Path | None = None,
) -> Dict[str, object]:
"""Build the calibration input payload passed to H5 worker subprocesses."""

calibration_inputs: Dict[str, object] = {
"weights": str(weights_path),
"geography": str(geography_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": n_clones,
"seed": seed,
}
if run_config_path is not None and run_config_path.exists():
calibration_inputs["run_config"] = str(run_config_path)
if calibration_package_path is not None and calibration_package_path.exists():
calibration_inputs["calibration_package"] = str(calibration_package_path)
return calibration_inputs


@pipeline_node(
PipelineNode(
id="coordinate_work_partition",
Expand Down Expand Up @@ -590,6 +618,7 @@ def run_phase(
handle = build_areas_worker.spawn(
branch=branch,
run_id=run_id,
scope="regional",
work_items=chunk,
calibration_inputs=calibration_inputs,
validate=validate,
Expand Down Expand Up @@ -685,8 +714,9 @@ def run_phase(
def build_areas_worker(
branch: str,
run_id: str,
scope: str,
work_items: List[Dict],
calibration_inputs: Dict[str, str],
calibration_inputs: Dict[str, object],
validate: bool = True,
) -> Dict:
"""
Expand All @@ -708,27 +738,35 @@ def build_areas_worker(
"--work-items",
work_items_json,
"--weights-path",
calibration_inputs["weights"],
str(calibration_inputs["weights"]),
"--dataset-path",
calibration_inputs["dataset"],
str(calibration_inputs["dataset"]),
"--db-path",
calibration_inputs["database"],
str(calibration_inputs["database"]),
"--output-dir",
str(output_dir),
"--scope",
scope,
"--run-id",
run_id,
"--artifacts-dir",
str(Path("/pipeline/artifacts") / run_id),
]
if "geography" in calibration_inputs:
worker_cmd.extend(["--geography-path", calibration_inputs["geography"]])
worker_cmd.extend(["--geography-path", str(calibration_inputs["geography"])])
if "calibration_package" in calibration_inputs:
worker_cmd.extend(
[
"--calibration-package-path",
calibration_inputs["calibration_package"],
str(calibration_inputs["calibration_package"]),
]
)
if "n_clones" in calibration_inputs:
worker_cmd.extend(["--n-clones", str(calibration_inputs["n_clones"])])
if "seed" in calibration_inputs:
worker_cmd.extend(["--seed", str(calibration_inputs["seed"])])
if "run_config" in calibration_inputs:
worker_cmd.extend(["--run-config-path", str(calibration_inputs["run_config"])])
repo_root = Path("/root/policyengine-us-data")
cal_dir = repo_root / "policyengine_us_data" / "calibration"
worker_cmd.extend(
Expand Down Expand Up @@ -1085,16 +1123,16 @@ def coordinate_publish(
)
print("All required pipeline artifacts found on volume.")

calibration_inputs = {
"weights": str(weights_path),
"geography": str(geography_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": n_clones,
"seed": 42,
}
if calibration_package_path.exists():
calibration_inputs["calibration_package"] = str(calibration_package_path)
calibration_inputs = _build_worker_calibration_inputs(
weights_path=weights_path,
geography_path=geography_path,
dataset_path=dataset_path,
db_path=db_path,
n_clones=n_clones,
seed=42,
run_config_path=config_json_path,
calibration_package_path=calibration_package_path,
)
validate_artifacts(config_json_path, artifacts)

if validate:
Expand Down Expand Up @@ -1396,14 +1434,15 @@ def coordinate_national_publish(
)
print("All required national pipeline artifacts found.")

calibration_inputs = {
"weights": str(weights_path),
"geography": str(geography_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": n_clones,
"seed": 42,
}
calibration_inputs = _build_worker_calibration_inputs(
weights_path=weights_path,
geography_path=geography_path,
dataset_path=dataset_path,
db_path=db_path,
n_clones=n_clones,
seed=42,
run_config_path=config_json_path,
)
validate_artifacts(
config_json_path,
artifacts,
Expand Down Expand Up @@ -1444,6 +1483,7 @@ def coordinate_national_publish(
worker_result = build_areas_worker.remote(
branch=branch,
run_id=run_id,
scope="national",
work_items=work_items,
calibration_inputs=calibration_inputs,
validate=validate,
Expand Down
179 changes: 91 additions & 88 deletions modal_app/worker_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from pathlib import Path
from typing import Any

import numpy as np


def _validate_in_subprocess(
h5_path,
Expand Down Expand Up @@ -149,6 +147,27 @@ def parse_args(argv: list[str] | None = None):
parser.add_argument("--dataset-path", required=True)
parser.add_argument("--db-path", required=True)
parser.add_argument("--output-dir", required=True)
parser.add_argument(
"--scope",
choices=("regional", "national"),
required=True,
help="Worker bootstrap scope to use for this request batch",
)
parser.add_argument(
"--run-id",
default=None,
help="Pipeline run ID used for traceability and bootstrap lookup",
)
parser.add_argument(
"--artifacts-dir",
default=None,
help="Optional run-scoped pipeline artifacts directory containing bootstrap artifacts",
)
parser.add_argument(
"--run-config-path",
default=None,
help="Optional unified run configuration JSON used for traceability",
)
parser.add_argument(
"--geography-path",
default=None,
Expand Down Expand Up @@ -212,6 +231,35 @@ def _load_request_inputs_from_args(
return "work_items", tuple(json.loads(args.work_items))


def _build_publishing_inputs(*, args, run_id: str):
    """Build the traceability input bundle consumed by worker setup services.

    Translates the raw argparse namespace into a ``PublishingInputBundle``:
    required path arguments become ``Path`` objects, unset optional flags
    stay ``None``, and the run ID / clone count / seed are passed through
    for fingerprinting. ``version`` is left empty here; presumably it is
    filled in downstream — TODO confirm against the bundle's consumers.
    """

    from policyengine_us_data.build_outputs.fingerprinting import (
        PublishingInputBundle,
    )

    def _optional_path(value):
        # argparse leaves unsupplied optional flags as None; only wrap
        # provided string values in Path.
        return Path(value) if value is not None else None

    return PublishingInputBundle(
        weights_path=Path(args.weights_path),
        source_dataset_path=Path(args.dataset_path),
        # Falsy (empty-string) db paths are treated the same as absent ones.
        target_db_path=Path(args.db_path) if args.db_path else None,
        exact_geography_path=_optional_path(args.geography_path),
        calibration_package_path=_optional_path(args.calibration_package_path),
        run_config_path=_optional_path(args.run_config_path),
        run_id=run_id,
        version="",
        n_clones=args.n_clones,
        seed=args.seed,
    )


def _build_kwargs_from_request(request) -> dict[str, Any]:
"""Translate a typed request into `build_h5(...)` keyword arguments."""

Expand Down Expand Up @@ -299,10 +347,9 @@ def _resolve_request_input(
def main(argv: list[str] | None = None):
args = parse_args(argv)

weights_path = Path(args.weights_path)
dataset_path = Path(args.dataset_path)
db_path = Path(args.db_path)
output_dir = Path(args.output_dir)
run_id = args.run_id or output_dir.name or "local-worker"

from policyengine_us_data.utils.takeup import (
SIMPLE_TAKEUP_VARS,
Expand All @@ -315,100 +362,56 @@ def main(argv: list[str] | None = None):

from policyengine_us_data.calibration.publish_local_area import (
build_h5,
load_calibration_geography,
)
from policyengine_us_data.build_outputs.area_catalog import USAreaCatalog
from policyengine_us_data.build_outputs.requests import AreaBuildRequest
from policyengine_us_data.build_outputs.validation import ValidationPolicy
from policyengine_us_data.build_outputs.worker_session import WorkerSessionFactory

weights = np.load(weights_path)

from policyengine_us import Microsimulation

_sim = Microsimulation(dataset=str(dataset_path))
n_records = len(_sim.calculate("household_id", map_to="household").values)
del _sim

geography = load_calibration_geography(
weights_path=weights_path,
n_records=n_records,
n_clones=args.n_clones,
geography_path=(
Path(args.geography_path) if args.geography_path is not None else None
),
calibration_package_path=(
Path(args.calibration_package_path)
if args.calibration_package_path is not None
else None
),
)
print(
f"Loaded geography: "
f"{geography.n_clones} clones x "
f"{geography.n_records} records",
file=sys.stderr,
)
area_catalog = USAreaCatalog.default()
request_input_mode, request_inputs = _load_request_inputs_from_args(
args=args,
area_build_request_cls=AreaBuildRequest,
)

# ── Validation setup (once per worker) ──
validation_targets = None
training_mask_full = None
constraints_map = None
if not args.no_validate:
from sqlalchemy import create_engine
from policyengine_us_data.calibration.validate_staging import (
_query_all_active_targets,
_batch_stratum_constraints,
)
from policyengine_us_data.calibration.unified_calibration import (
load_target_config,
_match_rules,
)

engine = create_engine(f"sqlite:///{db_path}")
validation_targets = _query_all_active_targets(engine, args.period)
print(
f"Loaded {len(validation_targets)} validation targets",
file=sys.stderr,
)

# Apply exclude/include from validation config
if args.validation_config:
val_cfg = load_target_config(args.validation_config)
exc_rules = val_cfg.get("exclude", [])
if exc_rules:
exc_mask = _match_rules(validation_targets, exc_rules)
validation_targets = validation_targets[~exc_mask].reset_index(
drop=True
)
inc_rules = val_cfg.get("include", [])
if inc_rules:
inc_mask = _match_rules(validation_targets, inc_rules)
validation_targets = validation_targets[inc_mask].reset_index(drop=True)

# Compute training mask from training config
if args.target_config:
tr_cfg = load_target_config(args.target_config)
tr_inc = tr_cfg.get("include", [])
if tr_inc:
training_mask_full = np.asarray(
_match_rules(validation_targets, tr_inc),
dtype=bool,
)
else:
training_mask_full = np.ones(len(validation_targets), dtype=bool)
else:
training_mask_full = np.ones(len(validation_targets), dtype=bool)

# Batch-load constraints
stratum_ids = validation_targets["stratum_id"].unique().tolist()
constraints_map = _batch_stratum_constraints(engine, stratum_ids)
scope = args.scope
inputs = _build_publishing_inputs(args=args, run_id=run_id)

session = WorkerSessionFactory().create(
inputs=inputs,
scope=scope,
validation_policy=ValidationPolicy(enabled=not args.no_validate),
period=args.period,
target_config_path=Path(args.target_config) if args.target_config else None,
validation_config_path=(
Path(args.validation_config) if args.validation_config else None
),
artifacts_dir=Path(args.artifacts_dir) if args.artifacts_dir else None,
)
weights = session.weights.values
n_records = session.weights.n_records
geography = session.geography
validation_context = session.validation_context
validation_targets = (
validation_context.validation_targets
if validation_context is not None
else None
)
training_mask_full = (
validation_context.training_mask if validation_context is not None else None
)
constraints_map = (
validation_context.constraints_map if validation_context is not None else None
)
print(
"Worker session ready: "
f"scope={scope}, bootstrap={session.bootstrap_status}, "
f"{geography.n_clones} clones x {geography.n_records} records",
file=sys.stderr,
)
if validation_targets is not None:
print(
f"Validation ready: {len(validation_targets)} targets, "
f"{len(stratum_ids)} strata",
f"{len(constraints_map or {})} strata",
file=sys.stderr,
)

Expand Down Expand Up @@ -486,7 +489,7 @@ def main(argv: list[str] | None = None):
validation_targets=validation_targets,
training_mask_full=training_mask_full,
constraints_map=constraints_map,
db_path=str(db_path),
db_path=str(inputs.target_db_path),
period=args.period,
)
results["validation_rows"].extend(v_rows)
Expand Down
3 changes: 2 additions & 1 deletion policyengine_us_data/build_outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
seams rather than speculative placeholders. The current early slices support
H5 output request construction, exact calibration geography loading,
fingerprinting, clone-weight shape contracts, worker partitioning, source
dataset snapshot contracts, and introduced worker-bootstrap artifacts.
dataset snapshot contracts, introduced worker-bootstrap artifacts, and
worker-scoped session and validation context setup.
"""
Loading
Loading