Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/derive-us-program-entities.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Derive US program-statistics entity from variable metadata instead of duplicating it in the program list.
34 changes: 20 additions & 14 deletions src/policyengine/tax_benefit_models/us/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,23 @@
)
from policyengine.utils.errors import format_conditional_error_detail

US_PROGRAMS = {
"income_tax": {"entity": "tax_unit", "is_tax": True},
"employee_payroll_tax": {"entity": "tax_unit", "is_tax": True},
"state_income_tax": {"entity": "tax_unit", "is_tax": True},
"snap": {"entity": "spm_unit", "is_tax": False},
"tanf": {"entity": "spm_unit", "is_tax": False},
"ssi": {"entity": "person", "is_tax": False},
"social_security": {"entity": "person", "is_tax": False},
"medicare_cost": {"entity": "person", "is_tax": False},
"medicaid": {"entity": "person", "is_tax": False},
"eitc": {"entity": "tax_unit", "is_tax": False},
"ctc": {"entity": "tax_unit", "is_tax": False},
# Map of US program-statistics variable name -> program metadata. The
# entity for each program is derived from the variable's own metadata
# at runtime (see ``_validate_program_statistics_config`` and
# ``economic_impact_analysis``), so this list cannot silently drift
# when policyengine-us moves a variable between entities.
US_PROGRAMS: dict[str, dict] = {
"income_tax": {"is_tax": True},
"employee_payroll_tax": {"is_tax": True},
"state_income_tax": {"is_tax": True},
"snap": {"is_tax": False},
"tanf": {"is_tax": False},
"ssi": {"is_tax": False},
"social_security": {"is_tax": False},
"medicare_cost": {"is_tax": False},
"medicaid": {"is_tax": False},
"eitc": {"is_tax": False},
"ctc": {"is_tax": False},
}


Expand Down Expand Up @@ -95,7 +100,7 @@ def _validate_program_statistics_config(
missing_outputs: set[tuple[str, str]] = set()

simulations = (baseline_simulation, reform_simulation)
for program_name, program_info in US_PROGRAMS.items():
for program_name in US_PROGRAMS:
for simulation in simulations:
model_version = simulation.tax_benefit_model_version
try:
Expand Down Expand Up @@ -153,13 +158,14 @@ def economic_impact_analysis(
income_variable="household_net_income",
)

model_version = baseline_simulation.tax_benefit_model_version
program_statistics = []
for program_name, program_info in US_PROGRAMS.items():
stats = ProgramStatistics(
baseline_simulation=baseline_simulation,
reform_simulation=reform_simulation,
program_name=program_name,
entity=program_info["entity"],
entity=model_version.get_variable(program_name).entity,
is_tax=program_info["is_tax"],
)
stats.run()
Expand Down
10 changes: 9 additions & 1 deletion tests/test_us_program_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,14 @@ def test_us_program_statistics_config_runs_against_mocked_outputs(tmp_path):

_validate_program_statistics_config(baseline, reform)

model_version = baseline.tax_benefit_model_version
results = {}
for program_name, program_info in US_PROGRAMS.items():
stats = ProgramStatistics(
baseline_simulation=baseline,
reform_simulation=reform,
program_name=program_name,
entity=program_info["entity"],
entity=model_version.get_variable(program_name).entity,
is_tax=program_info["is_tax"],
)
stats.run()
Expand Down Expand Up @@ -144,3 +145,10 @@ def test_us_program_statistics_config_fails_before_simulation_run(
_validate_program_statistics_config(baseline, reform)

assert "medicare_cost" in str(exc_info.value)


def test_us_programs_entities_match_model_metadata():
for program_name in US_PROGRAMS:
assert program_name in us_latest.variables_by_name, (
f"{program_name} is not defined in the US model"
)
Loading