Skip to content

Commit 220904d

Browse files
authored
Refactors Controls models into a single model (#43)
* Refactors controls models into a single model * Reviews test suite * Updates defaults for controls parameters * Updates examples * Updates str methods, switching from repr where appropriate * Rewrites model __str__ methods
1 parent 4ec08d2 commit 220904d

19 files changed

+692
-466
lines changed

RATpy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from RATpy import events, models
22
from RATpy.classlist import ClassList
3-
from RATpy.controls import set_controls
3+
from RATpy.controls import Controls
44
from RATpy.project import Project
55
from RATpy.run import run
66
from RATpy.utils import plotting
77

8-
__all__ = ["ClassList", "Project", "run", "set_controls", "models", "events", "plotting"]
8+
__all__ = ["models", "events", "ClassList", "Controls", "Project", "run", "plotting"]

RATpy/classlist.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Iterable, Sequence
99
from typing import Any, Union
1010

11+
import numpy as np
1112
import prettytable
1213

1314

@@ -52,19 +53,30 @@ def __init__(self, init_list: Union[Sequence[object], object] = None, name_field
5253

5354
super().__init__(init_list)
5455

55-
def __repr__(self):
56+
def __str__(self):
5657
try:
5758
[model.__dict__ for model in self.data]
5859
except AttributeError:
59-
output = repr(self.data)
60+
output = str(self.data)
6061
else:
6162
if any(model.__dict__ for model in self.data):
6263
table = prettytable.PrettyTable()
6364
table.field_names = ["index"] + [key.replace("_", " ") for key in self.data[0].__dict__]
64-
table.add_rows([[index] + list(model.__dict__.values()) for index, model in enumerate(self.data)])
65+
table.add_rows(
66+
[
67+
[index]
68+
+ list(
69+
f"{'Data array: ['+' x '.join(str(i) for i in v.shape) if v.size > 0 else '['}]"
70+
if isinstance(v, np.ndarray)
71+
else str(v)
72+
for v in model.__dict__.values()
73+
)
74+
for index, model in enumerate(self.data)
75+
]
76+
)
6577
output = table.get_string()
6678
else:
67-
output = repr(self.data)
79+
output = str(self.data)
6880
return output
6981

7082
def __setitem__(self, index: int, item: object) -> None:

RATpy/controls.py

Lines changed: 106 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,148 +1,143 @@
1-
from dataclasses import dataclass, field
2-
from typing import Literal, Union
1+
import warnings
32

43
import prettytable
5-
from pydantic import BaseModel, Field, ValidationError, field_validator
4+
from pydantic import (
5+
BaseModel,
6+
Field,
7+
ValidationError,
8+
ValidatorFunctionWrapHandler,
9+
field_validator,
10+
model_serializer,
11+
model_validator,
12+
)
613

714
from RATpy.utils.custom_errors import custom_pydantic_validation_error
815
from RATpy.utils.enums import BoundHandling, Display, Parallel, Procedures, Strategies
916

10-
11-
@dataclass(frozen=True)
12-
class Controls:
13-
"""The full set of controls parameters required for the compiled RAT code."""
17+
common_fields = ["procedure", "parallel", "calcSldDuringFit", "resampleParams", "display"]
18+
update_fields = ["updateFreq", "updatePlotFreq"]
19+
fields = {
20+
"calculate": common_fields,
21+
"simplex": [*common_fields, "xTolerance", "funcTolerance", "maxFuncEvals", "maxIterations", *update_fields],
22+
"de": [
23+
*common_fields,
24+
"populationSize",
25+
"fWeight",
26+
"crossoverProbability",
27+
"strategy",
28+
"targetValue",
29+
"numGenerations",
30+
*update_fields,
31+
],
32+
"ns": [*common_fields, "nLive", "nMCMC", "propScale", "nsTolerance"],
33+
"dream": [*common_fields, "nSamples", "nChains", "jumpProbability", "pUnitGamma", "boundHandling", "adaptPCR"],
34+
}
35+
36+
37+
class Controls(BaseModel, validate_assignment=True, extra="forbid"):
38+
"""The full set of controls parameters for all five procedures that are required for the compiled RAT code."""
1439

1540
# All Procedures
1641
procedure: Procedures = Procedures.Calculate
1742
parallel: Parallel = Parallel.Single
1843
calcSldDuringFit: bool = False
19-
resampleParams: list[float] = field(default_factory=list[0.9, 50.0])
20-
display: Display = Display.Iter
21-
# Simplex
22-
xTolerance: float = 1.0e-6
23-
funcTolerance: float = 1.0e-6
24-
maxFuncEvals: int = 10000
25-
maxIterations: int = 1000
26-
updateFreq: int = -1
27-
updatePlotFreq: int = 1
28-
# DE
29-
populationSize: int = 20
30-
fWeight: float = 0.5
31-
crossoverProbability: float = 0.8
32-
strategy: Strategies = Strategies.RandomWithPerVectorDither.value
33-
targetValue: float = 1.0
34-
numGenerations: int = 500
35-
# NS
36-
nLive: int = 150
37-
nMCMC: float = 0.0
38-
propScale: float = 0.1
39-
nsTolerance: float = 0.1
40-
# Dream
41-
nSamples: int = 20000
42-
nChains: int = 10
43-
jumpProbability: float = 0.5
44-
pUnitGamma: float = 0.2
45-
boundHandling: BoundHandling = BoundHandling.Reflect
46-
adaptPCR: bool = True
47-
48-
49-
class Calculate(BaseModel, validate_assignment=True, extra="forbid"):
50-
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""
51-
52-
procedure: Literal[Procedures.Calculate] = Procedures.Calculate
53-
parallel: Parallel = Parallel.Single
54-
calcSldDuringFit: bool = False
5544
resampleParams: list[float] = Field([0.9, 50], min_length=2, max_length=2)
5645
display: Display = Display.Iter
57-
58-
@field_validator("resampleParams")
59-
@classmethod
60-
def check_resample_params(cls, resampleParams):
61-
if not 0 < resampleParams[0] < 1:
62-
raise ValueError("resampleParams[0] must be between 0 and 1")
63-
if resampleParams[1] < 0:
64-
raise ValueError("resampleParams[1] must be greater than or equal to 0")
65-
return resampleParams
66-
67-
def __repr__(self) -> str:
68-
table = prettytable.PrettyTable()
69-
table.field_names = ["Property", "Value"]
70-
table.add_rows([[k, v] for k, v in self.__dict__.items()])
71-
return table.get_string()
72-
73-
74-
class Simplex(Calculate):
75-
"""Defines the additional fields for the simplex procedure."""
76-
77-
procedure: Literal[Procedures.Simplex] = Procedures.Simplex
46+
# Simplex
7847
xTolerance: float = Field(1.0e-6, gt=0.0)
7948
funcTolerance: float = Field(1.0e-6, gt=0.0)
8049
maxFuncEvals: int = Field(10000, gt=0)
8150
maxIterations: int = Field(1000, gt=0)
82-
updateFreq: int = -1
83-
updatePlotFreq: int = 1
84-
85-
86-
class DE(Calculate):
87-
"""Defines the additional fields for the Differential Evolution procedure."""
88-
89-
procedure: Literal[Procedures.DE] = Procedures.DE
51+
# Simplex and DE
52+
updateFreq: int = 1
53+
updatePlotFreq: int = 20
54+
# DE
9055
populationSize: int = Field(20, ge=1)
9156
fWeight: float = 0.5
9257
crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0)
9358
strategy: Strategies = Strategies.RandomWithPerVectorDither
9459
targetValue: float = Field(1.0, ge=1.0)
9560
numGenerations: int = Field(500, ge=1)
96-
97-
98-
class NS(Calculate):
99-
"""Defines the additional fields for the Nested Sampler procedure."""
100-
101-
procedure: Literal[Procedures.NS] = Procedures.NS
61+
# NS
10262
nLive: int = Field(150, ge=1)
10363
nMCMC: float = Field(0.0, ge=0.0)
10464
propScale: float = Field(0.1, gt=0.0, lt=1.0)
10565
nsTolerance: float = Field(0.1, ge=0.0)
106-
107-
108-
class Dream(Calculate):
109-
"""Defines the additional fields for the Dream procedure."""
110-
111-
procedure: Literal[Procedures.Dream] = Procedures.Dream
66+
# Dream
11267
nSamples: int = Field(20000, ge=0)
11368
nChains: int = Field(10, gt=0)
11469
jumpProbability: float = Field(0.5, gt=0.0, lt=1.0)
11570
pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0)
11671
boundHandling: BoundHandling = BoundHandling.Reflect
11772
adaptPCR: bool = True
11873

74+
@model_validator(mode="wrap")
75+
def warn_setting_incorrect_properties(self, handler: ValidatorFunctionWrapHandler) -> "Controls":
76+
"""Raise a warning if the user sets fields that apply to other procedures."""
77+
model_input = self
78+
try:
79+
input_dict = model_input.__dict__
80+
except AttributeError:
81+
input_dict = model_input
82+
83+
validated_self = None
84+
try:
85+
validated_self = handler(self)
86+
except ValidationError as exc:
87+
procedure = input_dict.get("procedure", Procedures.Calculate)
88+
custom_error_msgs = {
89+
"extra_forbidden": f'Extra inputs are not permitted. The fields for the "{procedure}"'
90+
f' controls procedure are:\n '
91+
f'{", ".join(fields.get("procedure", []))}\n',
92+
}
93+
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
94+
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None
95+
96+
if isinstance(model_input, validated_self.__class__):
97+
# This is for changing fields in a defined model
98+
changed_fields = [key for key in input_dict if input_dict[key] != validated_self.__dict__[key]]
99+
elif isinstance(model_input, dict):
100+
# This is for a newly-defined model
101+
changed_fields = input_dict.keys()
102+
else:
103+
raise ValueError('The input to the "Controls" model is invalid.')
104+
105+
new_procedure = validated_self.procedure
106+
allowed_fields = fields[new_procedure]
107+
for field in changed_fields:
108+
if field not in allowed_fields:
109+
incorrect_procedures = [key for (key, value) in fields.items() if field in value]
110+
warnings.warn(
111+
f'\nThe current controls procedure is "{new_procedure}", but the property'
112+
f' "{field}" applies instead to the {", ".join(incorrect_procedures)} procedure.\n\n'
113+
f' The fields for the "{new_procedure}" controls procedure are:\n'
114+
f' {", ".join(fields[new_procedure])}\n',
115+
stacklevel=2,
116+
)
117+
118+
return validated_self
119+
120+
@field_validator("resampleParams")
121+
@classmethod
122+
def check_resample_params(cls, values: list[float]) -> list[float]:
123+
"""Make sure each of the two values of resampleParams satisfy their conditions."""
124+
if not 0 < values[0] < 1:
125+
raise ValueError("resampleParams[0] must be between 0 and 1")
126+
if values[1] < 0:
127+
raise ValueError("resampleParams[1] must be greater than or equal to 0")
128+
return values
119129

120-
def set_controls(
121-
procedure: Procedures = Procedures.Calculate,
122-
**properties,
123-
) -> Union[Calculate, Simplex, DE, NS, Dream]:
124-
"""Returns the appropriate controls model given the specified procedure."""
125-
controls = {
126-
Procedures.Calculate: Calculate,
127-
Procedures.Simplex: Simplex,
128-
Procedures.DE: DE,
129-
Procedures.NS: NS,
130-
Procedures.Dream: Dream,
131-
}
130+
@model_serializer
131+
def serialize(self):
132+
"""Filter fields so only those applying to the chosen procedure are serialized."""
133+
return {model_field: getattr(self, model_field) for model_field in fields[self.procedure]}
132134

133-
try:
134-
model = controls[procedure](**properties)
135-
except KeyError:
136-
members = list(Procedures.__members__.values())
137-
allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}'
138-
raise ValueError(f"The controls procedure must be one of: {allowed_values}") from None
139-
except ValidationError as exc:
140-
custom_error_msgs = {
141-
"extra_forbidden": f'Extra inputs are not permitted. The fields for the {procedure}'
142-
f' controls procedure are:\n '
143-
f'{", ".join(controls[procedure].model_fields.keys())}\n',
144-
}
145-
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
146-
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None
135+
def __repr__(self) -> str:
136+
fields_repr = ", ".join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.model_dump().items())
137+
return f"{self.__repr_name__()}({fields_repr})"
147138

148-
return model
139+
def __str__(self) -> str:
140+
table = prettytable.PrettyTable()
141+
table.field_names = ["Property", "Value"]
142+
table.add_rows([[k, v] for k, v in self.model_dump().items()])
143+
return table.get_string()

RATpy/examples/absorption/absorption.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@
141141
)
142142

143143
# Now make a controls block
144-
controls = RAT.set_controls(parallel="contrasts", resampleParams=[0.9, 150.0])
144+
controls = RAT.Controls(parallel="contrasts", resampleParams=[0.9, 150.0])
145145

146146
# Run the code and plot the results
147147
problem, results = RAT.run(problem, controls)

RATpy/examples/domains/domains_custom_XY.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
model=["Domain Layer"],
6464
)
6565

66-
controls = RAT.set_controls()
66+
controls = RAT.Controls()
6767
problem, results = RAT.run(problem, controls)
6868

6969
RAT.plotting.plot_ref_sld(problem, results, True)

RATpy/examples/domains/domains_custom_layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
model=["Alloy domains"],
4141
)
4242

43-
controls = RAT.set_controls()
43+
controls = RAT.Controls()
4444

4545
problem, results = RAT.run(problem, controls)
4646
RAT.plotting.plot_ref_sld(problem, results, True)

RATpy/examples/domains/domains_standard_layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969

7070

7171
# Now we can run our simulation as usual, and plot the results
72-
controls = RAT.set_controls()
72+
controls = RAT.Controls()
7373

7474
problem, results = RAT.run(problem, controls)
7575
RAT.plotting.plot_ref_sld(problem, results, True)

RATpy/examples/languages/run_custom_file_languages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
path = pathlib.Path(__file__).parent.resolve()
1111

1212
project = setup_problem.make_example_problem()
13-
controls = RAT.set_controls()
13+
controls = RAT.Controls()
1414

1515
# Python
1616
start = time.time()

RATpy/examples/non_polarised/DSPC_custom_XY.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@
134134
model=["DSPC Model"],
135135
)
136136

137-
controls = RAT.set_controls()
137+
controls = RAT.Controls()
138138

139139
problem, results = RAT.run(problem, controls)
140140
RAT.plotting.plot_ref_sld(problem, results, True)

RATpy/examples/non_polarised/DSPC_custom_layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@
115115
model=["DSPC Model"],
116116
)
117117

118-
controls = RAT.set_controls()
118+
controls = RAT.Controls()
119119

120120
problem, results = RAT.run(problem, controls)
121121
RAT.plotting.plot_ref_sld(problem, results, True)

0 commit comments

Comments
 (0)