Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/program-specs-metadata.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Replaced the hard-coded program-statistics dict in `economic_impact_analysis` (US) with a structured `ProgramSpec` list and a `resolve_program_specs` helper that derives each program's entity from the model's variable metadata. Unknown variables now produce a single `ValueError` listing all problems at once with fuzzy-match suggestions, fixing the silent entity-drift class of bug tracked in #326.
27 changes: 10 additions & 17 deletions src/policyengine/tax_benefit_models/us/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
calculate_us_poverty_rates,
)

from .programs import US_PROGRAM_SPECS, resolve_program_specs


class PolicyReformAnalysis(BaseModel):
"""Complete policy reform analysis result."""
Expand Down Expand Up @@ -71,28 +73,19 @@ def economic_impact_analysis(
income_variable="household_net_income",
)

programs = {
"income_tax": {"entity": "tax_unit", "is_tax": True},
"payroll_tax": {"entity": "person", "is_tax": True},
"state_income_tax": {"entity": "tax_unit", "is_tax": True},
"snap": {"entity": "spm_unit", "is_tax": False},
"tanf": {"entity": "spm_unit", "is_tax": False},
"ssi": {"entity": "person", "is_tax": False},
"social_security": {"entity": "person", "is_tax": False},
"medicare": {"entity": "person", "is_tax": False},
"medicaid": {"entity": "person", "is_tax": False},
"eitc": {"entity": "tax_unit", "is_tax": False},
"ctc": {"entity": "tax_unit", "is_tax": False},
}
resolved_programs = resolve_program_specs(
US_PROGRAM_SPECS,
baseline_simulation.tax_benefit_model_version,
)

program_statistics = []
for program_name, program_info in programs.items():
for program in resolved_programs:
stats = ProgramStatistics(
baseline_simulation=baseline_simulation,
reform_simulation=reform_simulation,
program_name=program_name,
entity=program_info["entity"],
is_tax=program_info["is_tax"],
program_name=program.name,
entity=program.entity,
is_tax=program.is_tax,
)
stats.run()
program_statistics.append(stats)
Expand Down
117 changes: 117 additions & 0 deletions src/policyengine/tax_benefit_models/us/programs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Program-statistics specifications for the US model.

The program list used by ``economic_impact_analysis`` was previously a
hard-coded ``dict`` literal that duplicated each program's ``entity``
inline. That design is fragile — when ``policyengine-us`` moves a
variable between entities, or renames it, the hard-coded entity falls
out of sync silently and only fails at simulation time deep inside an
``Aggregate`` lookup.

This module replaces that pattern with:

* A structured :class:`ProgramSpec` declaring just the variable name
and whether the program is a tax (entity is *not* duplicated here).
* :func:`resolve_program_specs`, which validates every spec against the
model up front and derives ``entity`` from each variable's own
metadata. Unknown variables produce a single :class:`ValueError`
listing all problems at once, with fuzzy-match suggestions.

This is a step toward the durable design tracked in #326; it does not
yet derive the program list itself from model metadata (e.g. by
scanning for variables tagged as ``program``), but it removes the
entity-drift class of bug entirely.
"""

from __future__ import annotations

import difflib
from dataclasses import dataclass

from policyengine.core import TaxBenefitModelVersion


@dataclass(frozen=True)
class ProgramSpec:
    """One row of the program-statistics table, as declared by hand.

    Only the variable name and the tax/benefit flag are stored; the
    entity the variable lives on is intentionally absent and is looked
    up from the model when the spec is resolved.

    Attributes:
        name: Name of the model variable; the same string is reused as
            the program's display label.
        is_tax: ``True`` when the program is a tax, ``False`` when it is
            a benefit.
    """

    # Immutable (frozen) so instances are hashable and safe to share in
    # a module-level list.
    name: str
    is_tax: bool


@dataclass(frozen=True)
class ResolvedProgram:
    """A validated program entry: a :class:`ProgramSpec` plus its entity.

    Attributes:
        name: Name of the model variable the program refers to.
        entity: Entity of that variable, as reported by the model's
            variable metadata at resolve-time.
        is_tax: ``True`` when the program is a tax, ``False`` when it is
            a benefit.
    """

    name: str
    entity: str
    is_tax: bool


# Programs reported for the US model, in display order. Each name has to
# be a variable that `policyengine-us` really exposes. The entity is not
# listed: it is read from the variable's metadata when the specs are
# resolved, so it cannot drift out of sync with the model.
US_PROGRAM_SPECS: list[ProgramSpec] = [
    ProgramSpec(name=variable_name, is_tax=taxed)
    for variable_name, taxed in (
        # Taxes
        ("income_tax", True),
        ("employee_payroll_tax", True),
        ("household_state_income_tax", True),
        # Benefits
        ("snap", False),
        ("tanf", False),
        ("ssi", False),
        ("social_security", False),
        ("medicare_cost", False),
        ("medicaid", False),
        ("eitc", False),
        ("ctc", False),
    )
]


def resolve_program_specs(
    specs: list[ProgramSpec],
    model_version: TaxBenefitModelVersion,
) -> list[ResolvedProgram]:
    """Validate every spec against the model and derive its entity.

    Each spec's ``name`` is looked up in ``model_version.variables_by_name``.
    Specs that are found become :class:`ResolvedProgram` entries whose
    ``entity`` is copied from the matching variable. Specs that are not
    found are collected, and a single :class:`ValueError` is raised after
    the whole list has been checked, so the caller sees every problem at
    once instead of fixing them one at a time. Each problem line includes
    up to three ``difflib`` suggestions for likely typos / renames.

    Args:
        specs: Program declarations to validate, in output order.
        model_version: Model whose ``variables_by_name`` mapping is used
            for the lookup.

    Returns:
        One :class:`ResolvedProgram` per spec, in the same order as
        ``specs``. An empty ``specs`` list yields an empty list.

    Raises:
        ValueError: if any spec references a variable not present in
            ``model_version``.
    """
    known = model_version.variables_by_name
    errors: list[str] = []
    resolved: list[ResolvedProgram] = []

    for spec in specs:
        variable = known.get(spec.name)
        if variable is None:
            suggestions = difflib.get_close_matches(spec.name, known.keys(), n=3)
            msg = f"{spec.name!r}: variable not found in model"
            if suggestions:
                msg += f" (did you mean: {', '.join(suggestions)}?)"
            errors.append(msg)
            # Keep scanning so every unknown name ends up in the report.
            continue
        resolved.append(
            ResolvedProgram(
                name=spec.name,
                entity=variable.entity,
                is_tax=spec.is_tax,
            )
        )

    if errors:
        joined = "\n - ".join(errors)
        # Agree the noun with the count so a single failure does not read
        # "1 unknown variables".
        noun = "variable" if len(errors) == 1 else "variables"
        raise ValueError(
            f"Invalid program-statistics configuration ({len(errors)} "
            f"unknown {noun}):\n - {joined}"
        )

    return resolved
89 changes: 89 additions & 0 deletions tests/test_program_specs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Tests for the structured program-spec mechanism (issue #326)."""

import pytest

from policyengine.core import TaxBenefitModelVersion, Variable
from policyengine.tax_benefit_models.us.programs import (
US_PROGRAM_SPECS,
ProgramSpec,
ResolvedProgram,
resolve_program_specs,
)


def _build_model(variables: dict[str, str]) -> TaxBenefitModelVersion:
    """Return a stub model holding one variable per ``name -> entity`` pair."""
    stub = TaxBenefitModelVersion(id="stub")
    for var_name, var_entity in variables.items():
        stub_variable = Variable(
            id=f"stub:{var_name}",
            name=var_name,
            entity=var_entity,
            tax_benefit_model_version=stub,
        )
        stub.add_variable(stub_variable)
    return stub


def test_resolve_derives_entity_from_metadata():
    """Resolved entities come from the model's variables, not the specs."""
    stub = _build_model({"income_tax": "tax_unit", "snap": "spm_unit"})
    declared = [
        ProgramSpec(name="income_tax", is_tax=True),
        ProgramSpec(name="snap", is_tax=False),
    ]
    expected = [
        ResolvedProgram(name="income_tax", entity="tax_unit", is_tax=True),
        ResolvedProgram(name="snap", entity="spm_unit", is_tax=False),
    ]

    outcome = resolve_program_specs(declared, stub)

    assert outcome == expected


def test_resolve_collects_all_unknowns_in_one_error():
    """Every unknown variable is reported in a single ValueError."""
    stub = _build_model({"income_tax": "tax_unit"})
    known_spec = ProgramSpec(name="income_tax", is_tax=True)
    unknown_specs = [
        ProgramSpec(name="payroll_tax", is_tax=True),  # unknown
        ProgramSpec(name="medicare", is_tax=False),  # unknown
    ]

    with pytest.raises(ValueError, match="2 unknown variables") as caught:
        resolve_program_specs([known_spec, *unknown_specs], stub)

    report = str(caught.value)
    assert "'payroll_tax'" in report
    assert "'medicare'" in report


def test_resolve_includes_fuzzy_match_suggestions():
    """The error text suggests close matches for each unknown name."""
    available = {
        "employee_payroll_tax": "tax_unit",
        "medicare_cost": "person",
    }
    stub = _build_model(available)
    misspelled = [
        ProgramSpec(name="payroll_tax", is_tax=True),
        ProgramSpec(name="medicare", is_tax=False),
    ]

    with pytest.raises(ValueError) as caught:
        resolve_program_specs(misspelled, stub)

    report = str(caught.value)
    assert "employee_payroll_tax" in report
    assert "medicare_cost" in report


def test_us_program_specs_has_no_duplicates():
    """Each variable name appears at most once in the US program list."""
    distinct_names = {spec.name for spec in US_PROGRAM_SPECS}
    assert len(US_PROGRAM_SPECS) == len(distinct_names), (
        "US_PROGRAM_SPECS must not contain duplicate variable names"
    )
Loading