Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/sparse-validation-runtime.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use saved calibration diagnostics for the sparse enhanced CPS validation gate instead of rebuilding the full loss matrix.
73 changes: 13 additions & 60 deletions validation/stage_1/test_sparse_enhanced_cps.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
from policyengine_core.data import Dataset
from policyengine_core.reforms import Reform
from policyengine_us import Microsimulation
from policyengine_us_data.utils import (
ABSOLUTE_ERROR_SCALE_TARGETS,
build_loss_matrix,
print_reweighting_diagnostics,
)
from policyengine_us_data.utils import ABSOLUTE_ERROR_SCALE_TARGETS
from policyengine_us_data.storage import STORAGE_FOLDER


Expand Down Expand Up @@ -69,61 +65,18 @@ def test_sparse_poverty_rate_reasonable(sparse_sim):
# ── Reweighting and calibration checks ────────────────────────


@pytest.mark.filterwarnings("ignore:DataFrame is highly fragmented")
@pytest.mark.filterwarnings("ignore:The distutils package is deprecated")
@pytest.mark.filterwarnings(
"ignore:Series.__getitem__ treating keys as positions is deprecated"
)
@pytest.mark.filterwarnings(
"ignore:Setting an item of incompatible dtype is deprecated"
)
@pytest.mark.filterwarnings(
"ignore:Boolean Series key will be reindexed to match DataFrame index."
)
def test_sparse_ecps(sim):
data = sim.dataset.load_dataset()
optimised_weights = data["household_weight"]["2024"]

bad_targets = [
"nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household",
"nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household",
"nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
"nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
"nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household",
"nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household",
"nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
"nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
"state/RI/adjusted_gross_income/amount/-inf_1",
"nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household",
"nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household",
"nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
"nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
"nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household",
"nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household",
"nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
"nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
"state/RI/adjusted_gross_income/amount/-inf_1",
"nation/irs/exempt interest/count/AGI in -inf-inf/taxable/All",
]

loss_matrix, targets_array = build_loss_matrix(sim.dataset, 2024)
scaled_zero_target_mask = loss_matrix.columns.isin(
ABSOLUTE_ERROR_SCALE_TARGETS.keys()
)
zero_mask = np.isclose(targets_array, 0.0, atol=0.1) & (~scaled_zero_target_mask)
bad_mask = loss_matrix.columns.isin(bad_targets)
keep_mask_bool = ~(zero_mask | bad_mask)
keep_idx = np.where(keep_mask_bool)[0]
loss_matrix_clean = loss_matrix.iloc[:, keep_idx]
targets_array_clean = targets_array[keep_idx]
assert loss_matrix_clean.shape[1] == targets_array_clean.size

percent_within_10 = print_reweighting_diagnostics(
optimised_weights,
loss_matrix_clean,
targets_array_clean,
"Sparse Solutions",
)
def test_sparse_ecps():
    """Gate the sparse enhanced CPS on saved calibration diagnostics.

    Instead of rebuilding the full loss matrix, read the calibration log
    written during dataset construction, take the final training epoch,
    and require that more than 60% of targets land within a 10% tolerance.
    """
    # NOTE(review): relative path — assumes the test runs from the directory
    # where calibration_log.csv was written; confirm against the CI workdir.
    log = pd.read_csv("calibration_log.csv")

    last_epoch = log["epoch"].max()
    final = log.loc[log["epoch"] == last_epoch].copy()

    assert not final.empty, "No final-epoch calibration diagnostics found."

    # Default tolerance: 10% of each target's own magnitude.
    tol = final["target"].abs() * 0.10
    # Targets calibrated on an absolute-error scale use 10% of that scale
    # instead (their target value may be ~0, making a relative band useless).
    for name, scale in ABSOLUTE_ERROR_SCALE_TARGETS.items():
        tol.loc[final["target_name"] == name] = scale * 0.10

    within_10_pct = (final["abs_error"] <= tol).mean() * 100
    assert within_10_pct > 60.0


Expand Down
Loading