Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/943.changed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update the production data-build runtime to policyengine-us 1.690.7 and harden Modal pipeline resume behavior.
40 changes: 27 additions & 13 deletions modal_app/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _calibration_package_parameters(
) -> dict:
"""Return manifest parameters that affect package construction."""
effective_parallel = bool(chunked_matrix and parallel_matrix)
return {
params = {
"workers": workers if not chunked_matrix else None,
"n_clones": n_clones,
"target_config": target_config,
Expand All @@ -153,6 +153,7 @@ def _calibration_package_parameters(
"parallel_matrix": effective_parallel,
"num_matrix_workers": num_matrix_workers if effective_parallel else None,
}
return {key: value for key, value in params.items() if value is not None}


def get_pinned_sha(branch: str) -> str:
Expand Down Expand Up @@ -1092,6 +1093,15 @@ def run_pipeline(
expected_input_identities=package_inputs,
expected_parameters=package_parameters,
)
if not package_reuse.reusable:
previous = package_reuse.manifest
print(f" Package reuse invalidated: {package_reuse.reason}")
if previous is not None:
print(f" prior status: {previous.status}")
print(f" prior parameters: {previous.parameters}")
print(f" expected parameters: {package_parameters}")
print(f" prior inputs: {previous.input_identities}")
print(f" expected inputs: {package_inputs}")
if package_reuse.reusable:
_mark_step_reused(
meta,
Expand Down Expand Up @@ -1502,9 +1512,8 @@ def run_pipeline(
vol=pipeline_volume,
)

pipeline_volume.reload()

# Now wait for H5 builds to finish
# Now wait for H5 builds to finish. Do not reload the shared
# volume until the child jobs release SQLite handles.
print(" Waiting for regional H5 build...")
regional_h5_result = regional_h5_handle.get()
regional_msg = (
Expand All @@ -1514,6 +1523,20 @@ def run_pipeline(
)
print(f" Regional H5: {regional_msg}")

national_h5_result = None
if national_h5_handle is not None:
print(" Waiting for national H5 build...")
national_h5_result = national_h5_handle.get()
national_msg = (
national_h5_result.get("message", national_h5_result)
if isinstance(national_h5_result, dict)
else national_h5_result
)
print(f" National H5: {national_msg}")

pipeline_volume.reload()
staging_volume.reload()

if isinstance(regional_h5_result, dict) and regional_h5_result.get(
"fingerprint"
):
Expand Down Expand Up @@ -1542,16 +1565,7 @@ def run_pipeline(
)
active_step_manifest = national_h5_manifest

national_h5_result = None
if national_h5_handle is not None:
print(" Waiting for national H5 build...")
national_h5_result = national_h5_handle.get()
national_msg = (
national_h5_result.get("message", national_h5_result)
if isinstance(national_h5_result, dict)
else national_h5_result
)
print(f" National H5: {national_msg}")
if isinstance(national_h5_result, dict) and national_h5_result.get(
"fingerprint"
):
Expand Down
83 changes: 82 additions & 1 deletion modal_app/step_manifests/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from __future__ import annotations

import os
import hashlib
import json
import sqlite3
from dataclasses import asdict, dataclass, field, fields
from pathlib import Path
from typing import Optional
Expand Down Expand Up @@ -100,13 +103,91 @@ def artifacts_dir(run_id: str) -> Path:
return Path(artifacts_dir_for_run(run_id))


def _quote_sql_identifier(identifier: str) -> str:
return '"' + identifier.replace('"', '""') + '"'


def _canonical_sqlite_value(value):
if isinstance(value, bytes):
return {"__bytes__": value.hex()}
return value


def _canonical_sqlite_sha256(path: Path) -> str:
    """Hash logical SQLite contents instead of mutable file metadata.

    The digest covers the user-visible schema (everything in ``sqlite_master``
    except internal ``sqlite_*`` objects) plus every table row in a canonical
    order, so two databases with identical logical content hash the same even
    when their on-disk bytes differ (page layout, freelist, WAL state, etc.).

    Args:
        path: Filesystem path to the SQLite database; opened read-only.

    Returns:
        Hex-encoded SHA-256 of the canonicalized schema and row stream.
    """
    digest = hashlib.sha256()

    def update(payload) -> None:
        # Canonical JSON (sorted keys, no whitespace) keeps the byte stream
        # stable across Python versions and dict construction order.
        digest.update(
            json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
        )
        digest.update(b"\n")

    # NOTE: sqlite3's context manager only commits/rolls back the transaction;
    # it does NOT close the connection, which would leak a read handle on the
    # shared volume for every manifest check. Close explicitly instead.
    conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
    try:
        conn.row_factory = sqlite3.Row
        schema_rows = conn.execute(
            """
            SELECT type, name, tbl_name, sql
            FROM sqlite_master
            WHERE name NOT LIKE 'sqlite_%'
            ORDER BY type, name
            """
        ).fetchall()
        update(
            {
                "schema": [
                    {
                        "type": row["type"],
                        "name": row["name"],
                        "tbl_name": row["tbl_name"],
                        "sql": row["sql"],
                    }
                    for row in schema_rows
                ]
            }
        )

        table_names = [row["name"] for row in schema_rows if row["type"] == "table"]
        for table_name in table_names:
            columns = [
                row["name"]
                for row in conn.execute(
                    f"PRAGMA table_info({_quote_sql_identifier(table_name)})"
                )
            ]
            quoted_columns = [_quote_sql_identifier(column) for column in columns]
            select_columns = ", ".join(quoted_columns)
            # Ordering by every column makes the row stream deterministic;
            # duplicate rows sort as equals but also hash identically.
            order_columns = ", ".join(quoted_columns)
            for row in conn.execute(
                f"""
                SELECT {select_columns}
                FROM {_quote_sql_identifier(table_name)}
                ORDER BY {order_columns}
                """
            ):
                update(
                    {
                        "table": table_name,
                        "row": [
                            _canonical_sqlite_value(row[column]) for column in columns
                        ],
                    }
                )
    finally:
        conn.close()
    return digest.hexdigest()


def artifact_identity(path: str | Path) -> dict:
    """Describe an artifact at *path* for step-manifest comparison.

    Regular artifacts are identified by path, byte size, and file SHA-256.
    SQLite databases (``.db`` suffix) are instead identified by a hash of
    their logical contents (``identity_kind: sqlite_content``), since their
    raw bytes can change between builds with identical logical data.
    """
    artifact = ArtifactReference.from_path(path)
    resolved = Path(path)
    if resolved.suffix == ".db":
        return {
            "path": artifact.path,
            "sha256": _canonical_sqlite_sha256(resolved),
            "identity_kind": "sqlite_content",
        }
    return {
        "path": artifact.path,
        "size_bytes": artifact.size_bytes,
        "sha256": artifact.sha256,
    }


def artifact_identities(paths: dict[str, str | Path]) -> dict:
Expand Down
14 changes: 11 additions & 3 deletions policyengine_us_data/calibration/calibration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,12 @@ def create_target_groups(
pairs = sorted(
level_df[["domain_variable", "variable"]]
.drop_duplicates()
.itertuples(index=False, name=None)
.itertuples(index=False, name=None),
key=lambda pair: (
pair[0] is not None,
"" if pair[0] is None else str(pair[0]),
str(pair[1]),
),
)
else:
pairs = [(None, v) for v in sorted(level_df["variable"].unique())]
Expand All @@ -353,8 +358,11 @@ def create_target_groups(
var_mask = (
(targets_df["variable"] == var_name) & level_mask & ~processed_mask
)
if has_domain and domain_var is not None:
var_mask &= targets_df["domain_variable"] == domain_var
if has_domain:
if domain_var is None:
var_mask &= targets_df["domain_variable"].isna()
else:
var_mask &= targets_df["domain_variable"] == domain_var

if not var_mask.any():
continue
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ classifiers = [
"Programming Language :: Python :: 3.14",
]
dependencies = [
"policyengine-us>=1.690.6",
"policyengine-us>=1.690.7",
# policyengine-core 3.26.1 is the current 3.26.x runtime and includes the fix for
# PolicyEngine/policyengine-core#482 (user-set ETERNITY inputs lost
# after _invalidate_all_caches) and is required by policyengine-us 1.682.1+.
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/calibration/test_drop_target_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ def sample_data():


class TestDropTargetGroups:
def test_target_groups_separate_null_and_string_domains(self):
targets_df = pd.DataFrame(
{
"variable": ["person_count", "person_count", "snap"],
"domain_variable": [None, "age", "snap"],
"geographic_id": ["US", "US", "US"],
"value": [1000, 200, 300],
}
)

target_groups, group_info = create_target_groups(targets_df)

assert len(set(target_groups)) == 3
assert target_groups[0] != target_groups[1]
assert any("Person Count" in info for info in group_info)
assert any("AGE Person Count" in info for info in group_info)

def test_drops_matching_group(self, sample_data):
targets_df, X, target_groups, group_info = sample_data
n_before = len(targets_df)
Expand Down
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion validation/stage_1/test_xw_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ def test_xw_matches_stacked_sim():
hierarchical_domains=["snap"],
rerandomize_takeup=True,
county_level=False,
workers=2,
# Keep this validation serial. In Modal's Python 3.14 runtime the
# short-lived ProcessPool path can leave the test waiting on futures
# after the worker processes have exited.
workers=1,
)

takeup_filter = [spec["variable"] for spec in SIMPLE_TAKEUP_VARS]
Expand Down