Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/disallow-uprating-formula-variables.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Disallow variable definitions that combine formula, adds/subtracts, and uprating computation modes.
36 changes: 22 additions & 14 deletions policyengine_core/commons/formulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,28 @@ def uprated(by: str = None, start_year: int = 2015) -> Callable:
"""

def uprater(variable: Type[Variable]) -> type:
if hasattr(variable, f"formula_{start_year}"):
return variable

formula = variable.formula if hasattr(variable, "formula") else None
formula_names = [
name for name in variable.__dict__ if name.startswith("formula")
]
if formula_names:
raise ValueError(
f'Variable "{variable.__name__}" uses @uprated and has a formula. '
"Uprating is only supported for input variables; formulas "
"should handle their own time behavior explicitly."
)
if "adds" in variable.__dict__ or "subtracts" in variable.__dict__:
raise ValueError(
f'Variable "{variable.__name__}" uses @uprated and has '
"adds/subtracts. Uprating is only supported for input "
"variables without formula, adds/subtracts, or uprating "
"metadata."
)
if "uprating" in variable.__dict__:
raise ValueError(
f'Variable "{variable.__name__}" uses @uprated and has '
"uprating. Uprating is only supported for input variables "
"without formula, adds/subtracts, or uprating metadata."
)

variable.metadata = {
"uprating": by,
Expand All @@ -368,16 +386,6 @@ def formula_start_year(entity, period, parameters):
last_year_parameter = getattr(last_year_parameter, name)
uprating = current_parameter / last_year_parameter
old = entity(variable.__name__, period.last_year)
# Use numpy.all on the element-wise equality with 0; Python's
# ``all(old)`` checks truthiness of each element, so a single
# non-zero value makes the guard ``False`` even when every
# other value is zero — which defeated the "no values were
# inputted" short-circuit and caused uprating to run on top
# of a formula fall-back output (bug M1).
if (formula is not None) and np.all(old == 0):
# If no values have been inputted, don't uprate and
# instead use the previous formula on the current period.
return formula(entity, period, parameters)
return uprating * old

formula_start_year.__name__ = f"formula_{start_year}"
Expand Down
37 changes: 37 additions & 0 deletions policyengine_core/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def __init__(self, baseline_variable=None):
)
self.formulas = self.set_formulas(formulas_attr)

self.check_computation_modes()

if unexpected_attrs:
raise ValueError(
'Unexpected attributes in definition of variable "{}": {!r}'.format(
Expand All @@ -329,6 +331,41 @@ def __init__(self, baseline_variable=None):

# ----- Setters used to build the variable ----- #

@property
def uprating(self):
return getattr(self, "_uprating", None)

@uprating.setter
def uprating(self, value):
old_value = getattr(self, "_uprating", None)
self._uprating = value
if hasattr(self, "formulas"):
try:
self.check_computation_modes()
except ValueError:
self._uprating = old_value
raise

def get_computation_modes(self):
computation_modes = []
if self.formulas:
computation_modes.append("formula")
if self.adds is not None or self.subtracts is not None:
computation_modes.append("adds/subtracts")
if self.uprating is not None:
computation_modes.append("uprating")
return computation_modes

def check_computation_modes(self):
computation_modes = self.get_computation_modes()
if len(computation_modes) > 1:
raise ValueError(
f'Variable "{self.name}" mixes computation modes: '
f"{' and '.join(computation_modes)}. Variables must use at "
"most one of formula, adds/subtracts, or uprating; plain "
"input or constant variables should use none."
)

def set(
self,
attributes,
Expand Down
21 changes: 0 additions & 21 deletions tests/core/test_medium_fixes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
"""Regression tests for a batch of surgical Medium-severity fixes.

* M1 — ``@uprated`` short-circuit used Python ``all()`` instead of
``numpy.all(old == 0)``. Previously the guard returned ``True`` only
when the first element was zero (truthiness) rather than "no values
have been inputted".
* M8 — ``SimulationBuilder`` multi-axis ``linspace`` branch divided by
``axis_count - 1``, crashing on single-point axes.
* M10 — ``Dataset.download`` parsed ``release://org/repo/tag/file`` with
Expand All @@ -17,7 +13,6 @@

import datetime

import numpy as np
import pytest

from policyengine_core.variables.config import VALUE_TYPES
Expand All @@ -42,19 +37,3 @@ def test_single_point_axis_does_not_divide_by_zero(persons):
# After the fix, a single-point axis produces the ``axis["min"]`` value.
builder.expand_axes()
assert builder.get_input("salary", "2018-11") == pytest.approx([500])


def test_all_numpy_guard_triggers_on_all_zero_old():
"""Bug M1: ``np.all(old == 0)`` must be used, not Python ``all(old)``.

Python ``all([1, 0, 0])`` == True (because 1 is truthy), so the guard
would NOT fire. ``np.all([1, 0, 0] == 0)`` == False, which correctly
says "not all zero".
"""
old = np.array([1, 0, 0])
# Python truthy-check semantics: ``all([1, 0, 0])`` -> False because
# 0 is falsy. For the reversed test case with all zeros:
all_zero = np.array([0, 0, 0])
# The fix uses ``np.all(old == 0)`` which is True iff every element is 0.
assert np.all(all_zero == 0)
assert not np.all(old == 0)
196 changes: 180 additions & 16 deletions tests/core/variables/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import policyengine_core.country_template as country_template
import policyengine_core.country_template.situation_examples
from policyengine_core.country_template.entities import Person
from policyengine_core.model_api import Variable
from policyengine_core.model_api import Variable, uprated
from policyengine_core.periods import ETERNITY, MONTH
from policyengine_core.simulations import SimulationBuilder
from policyengine_core.tools import assert_near
Expand Down Expand Up @@ -555,34 +555,156 @@ def formula():


def test_one_formula_one_add():
check_error_at_add_variable(
tax_benefit_system,
variable__one_formula_one_add,
'Variable "{name}" has a formula and an add or subtract'.format(
name="variable__one_formula_one_add"
),
)
with raises(
ValueError,
match='Variable "variable__one_formula_one_add" mixes computation modes: formula and adds/subtracts',
):
tax_benefit_system.add_variable(variable__one_formula_one_add)


class variable__one_formula_one_subtract(Variable):
value_type = int
entity = Person
definition_period = MONTH
label = "Variable with one formula and one subtract."
adds = ["pass"]
subtracts = ["pass"]

def formula():
pass


def test_one_formula_one_subtract():
check_error_at_add_variable(
tax_benefit_system,
variable__one_formula_one_subtract,
'Variable "{name}" has a formula and an add or subtract'.format(
name="variable__one_formula_one_subtract"
),
)
with raises(
ValueError,
match='Variable "variable__one_formula_one_subtract" mixes computation modes: formula and adds/subtracts',
):
tax_benefit_system.add_variable(variable__one_formula_one_subtract)


class variable__one_formula_one_uprating(Variable):
value_type = int
entity = Person
definition_period = MONTH
label = "Variable with one formula and one uprating."
uprating = "uprating.index"

def formula():
pass


def test_one_formula_one_uprating():
with raises(
ValueError,
match='Variable "variable__one_formula_one_uprating" mixes computation modes: formula and uprating',
):
tax_benefit_system.add_variable(variable__one_formula_one_uprating)


class variable__one_add_one_uprating(Variable):
value_type = int
entity = Person
definition_period = MONTH
label = "Variable with one add and one uprating."
adds = ["pass"]
uprating = "uprating.index"


def test_one_add_one_uprating():
with raises(
ValueError,
match='Variable "variable__one_add_one_uprating" mixes computation modes: adds/subtracts and uprating',
):
tax_benefit_system.add_variable(variable__one_add_one_uprating)


class variable__one_subtract_one_uprating(Variable):
value_type = int
entity = Person
definition_period = MONTH
label = "Variable with one subtract and one uprating."
subtracts = ["pass"]
uprating = "uprating.index"


def test_one_subtract_one_uprating():
with raises(
ValueError,
match='Variable "variable__one_subtract_one_uprating" mixes computation modes: adds/subtracts and uprating',
):
tax_benefit_system.add_variable(variable__one_subtract_one_uprating)


def test_uprated_decorator_rejects_existing_formula():
with raises(
ValueError,
match='Variable "variable__uprated_decorator_one_formula" uses @uprated and has a formula',
):

@uprated("uprating.index")
class variable__uprated_decorator_one_formula(Variable):
value_type = int
entity = Person
definition_period = MONTH
label = "Variable with @uprated and one formula."

def formula():
pass


def test_uprated_decorator_rejects_existing_adds():
with raises(
ValueError,
match='Variable "variable__uprated_decorator_one_add" uses @uprated and has adds/subtracts',
):

@uprated("uprating.index")
class variable__uprated_decorator_one_add(Variable):
value_type = int
entity = Person
definition_period = MONTH
label = "Variable with @uprated and one add."
adds = ["pass"]


def test_uprated_decorator_rejects_existing_subtracts():
with raises(
ValueError,
match='Variable "variable__uprated_decorator_one_subtract" uses @uprated and has adds/subtracts',
):

@uprated("uprating.index")
class variable__uprated_decorator_one_subtract(Variable):
value_type = int
entity = Person
definition_period = MONTH
label = "Variable with @uprated and one subtract."
subtracts = ["pass"]


def test_uprated_decorator_rejects_existing_uprating():
with raises(
ValueError,
match='Variable "variable__uprated_decorator_one_uprating" uses @uprated and has uprating',
):

@uprated("uprating.index")
class variable__uprated_decorator_one_uprating(Variable):
value_type = int
entity = Person
definition_period = MONTH
label = "Variable with @uprated and one uprating."
uprating = "uprating.index"


def test_uprated_decorator_allows_input_variable():
@uprated("uprating.index")
class variable__uprated_decorator_input(Variable):
value_type = int
entity = Person
definition_period = MONTH
label = "Input variable with @uprated."

assert hasattr(variable__uprated_decorator_input, "formula_2015")


class variable__one_formula(Variable):
Expand Down Expand Up @@ -629,6 +751,48 @@ def test_one_subtract():
assert len(variable.subtracts)


class variable__one_add_one_subtract(Variable):
value_type = int
entity = Person
definition_period = MONTH
label = "Variable with one add and one subtract."
adds = ["pass"]
subtracts = ["pass"]


def test_one_add_one_subtract():
tax_benefit_system.add_variable(variable__one_add_one_subtract)
variable = tax_benefit_system.variables["variable__one_add_one_subtract"]
assert len(variable.adds)
assert len(variable.subtracts)


def test_runtime_uprating_assignment_rejects_existing_adds():
variable = variable__one_add()

with raises(
ValueError,
match='Variable "variable__one_add" mixes computation modes: adds/subtracts and uprating',
):
variable.uprating = "uprating.index"
assert variable.uprating is None


class variable__runtime_uprating_input(Variable):
value_type = int
entity = Person
definition_period = MONTH
label = "Input variable with runtime uprating assignment."


def test_runtime_uprating_assignment_allows_input_variable():
variable = variable__runtime_uprating_input()

variable.uprating = "uprating.index"

assert variable.uprating == "uprating.index"


class variable__no_label(Variable):
value_type = int
entity = Person
Expand Down
Loading