Skip to content

Commit c9b74c0

Browse files
authored
Add Models and Project for RAT API (#6)
* Adds "models.py" with initial draft of pydantic models for API classes * Adds validators for pydantic models * Fixes parameter names in background model * Adds new model ProtectedParameter * Adds "project.py" with initial draft of the high level "Project" class * Adds "model_post_init" routine for the "project" model * Adds "__repr__" routine for the "project" model * Adds code to work with updated ClassList * Moves validators to cross-check project fields from "models.py" to "project.py" * Replaces annotated validators with single field validator in "project.py" * Add contrasts to cross-checking model validator in "project.py" * Changes data model to accept numpy array * Adds docs and modifies the Project class's "model_post_init" to ensure all ClassLists have the correct "_class_handle" * Adds tests "test_models.py" * Removes unused routine "get_all_names" from "project.py" * Adds tests "test_project.py" * Tidies up model and project classes and tests * Adds code to fix enums for all python versions * Adds code to stop "test_repr" in "test_project.py" writing to console * Addresses review comments
1 parent 0b3565f commit c9b74c0

File tree

7 files changed

+723
-16
lines changed

7 files changed

+723
-16
lines changed

RAT/classlist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def _check_classes(self, input_list: Iterable[object]) -> None:
253253
Raised if the input list defines objects of different types.
254254
"""
255255
if not (all(isinstance(element, self._class_handle) for element in input_list)):
256-
raise ValueError(f"Input list contains elements of type other than '{self._class_handle}'")
256+
raise ValueError(f"Input list contains elements of type other than '{self._class_handle.__name__}'")
257257

258258
def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object, str]:
259259
"""Return the object with the given value of the name_field attribute in the ClassList.

RAT/models.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
"""The models module. Contains the pydantic models used by RAT to store project parameters."""
2+
3+
import numpy as np
4+
from pydantic import BaseModel, Field, FieldValidationInfo, field_validator, model_validator
5+
6+
try:
7+
from enum import StrEnum
8+
except ImportError:
9+
from strenum import StrEnum
10+
11+
12+
def int_sequence():
13+
"""Iterate through integers for use as model counters."""
14+
num = 1
15+
while True:
16+
yield str(num)
17+
num += 1
18+
19+
20+
# Create a counter for each model
21+
background_number = int_sequence()
22+
contrast_number = int_sequence()
23+
custom_file_number = int_sequence()
24+
data_number = int_sequence()
25+
domain_contrast_number = int_sequence()
26+
layer_number = int_sequence()
27+
parameter_number = int_sequence()
28+
resolution_number = int_sequence()
29+
30+
31+
class Hydration(StrEnum):
32+
None_ = 'none'
33+
BulkIn = 'bulk in'
34+
BulkOut = 'bulk out'
35+
Oil = 'oil'
36+
37+
38+
class Languages(StrEnum):
39+
Python = 'python'
40+
Matlab = 'matlab'
41+
42+
43+
class Priors(StrEnum):
44+
Uniform = 'uniform'
45+
Gaussian = 'gaussian'
46+
Jeffreys = 'jeffreys'
47+
48+
49+
class Types(StrEnum):
50+
Constant = 'constant'
51+
Data = 'data'
52+
Function = 'function'
53+
54+
55+
class Background(BaseModel, validate_assignment=True, extra='forbid'):
56+
"""Defines the Backgrounds in RAT."""
57+
name: str = Field(default_factory=lambda: 'New Background ' + next(background_number))
58+
type: Types = Types.Constant
59+
value_1: str = ''
60+
value_2: str = ''
61+
value_3: str = ''
62+
value_4: str = ''
63+
value_5: str = ''
64+
65+
66+
class Contrast(BaseModel, validate_assignment=True, extra='forbid'):
67+
"""Groups together all of the components of the model."""
68+
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number))
69+
data: str = ''
70+
background: str = ''
71+
nba: str = ''
72+
nbs: str = ''
73+
scalefactor: str = ''
74+
resolution: str = ''
75+
resample: bool = False
76+
model: list[str] = [] # But how many strings? How to deal with this?
77+
78+
79+
class CustomFile(BaseModel, validate_assignment=True, extra='forbid'):
80+
"""Defines the files containing functions to run when using custom models."""
81+
name: str = Field(default_factory=lambda: 'New Custom File ' + next(custom_file_number))
82+
filename: str = ''
83+
language: Languages = Languages.Python
84+
path: str = 'pwd' # Should later expand to find current file path
85+
86+
87+
class Data(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
88+
"""Defines the dataset required for each contrast."""
89+
name: str = Field(default_factory=lambda: 'New Data ' + next(data_number))
90+
data: np.ndarray[float] = np.empty([0, 3])
91+
data_range: list[float] = []
92+
simulation_range: list[float] = [0.005, 0.7]
93+
94+
@field_validator('data')
95+
@classmethod
96+
def check_data_dimension(cls, data: np.ndarray[float]) -> np.ndarray[float]:
97+
"""The data must be a two-dimensional array containing at least three columns."""
98+
try:
99+
data.shape[1]
100+
except IndexError:
101+
raise ValueError('"data" must have at least two dimensions')
102+
else:
103+
if data.shape[1] < 3:
104+
raise ValueError('"data" must have at least three columns')
105+
return data
106+
107+
@field_validator('data_range', 'simulation_range')
108+
@classmethod
109+
def check_list_elements(cls, limits: list[float], info: FieldValidationInfo) -> list[float]:
110+
"""The data range and simulation range must contain exactly two parameters."""
111+
if len(limits) != 2:
112+
raise ValueError(f'{info.field_name} must contain exactly two values')
113+
return limits
114+
115+
# Also need model validators for data range compared to data etc -- need more details.
116+
117+
118+
class DomainContrast(BaseModel, validate_assignment=True, extra='forbid'):
119+
"""Groups together the layers required for each domain."""
120+
name: str = Field(default_factory=lambda: 'New Domain Contrast ' + next(domain_contrast_number))
121+
model: list[str] = []
122+
123+
124+
class Layer(BaseModel, validate_assignment=True, extra='forbid'):
125+
"""Combines parameters into defined layers."""
126+
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number))
127+
thickness: str = ''
128+
SLD: str = ''
129+
roughness: str = ''
130+
hydration: str = ''
131+
hydrate_with: Hydration = Hydration.BulkOut
132+
133+
134+
class Parameter(BaseModel, validate_assignment=True, extra='forbid'):
135+
"""Defines parameters needed to specify the model"""
136+
name: str = Field(default_factory=lambda: 'New Parameter ' + next(parameter_number))
137+
min: float = 0.0
138+
value: float = 0.0
139+
max: float = 0.0
140+
fit: bool = False
141+
prior_type: Priors = Priors.Uniform
142+
mu: float = 0.0
143+
sigma: float = np.inf
144+
145+
@model_validator(mode='after')
146+
def check_value_in_range(self) -> 'Parameter':
147+
"""The value of a parameter must lie within its defined bounds."""
148+
if self.value < self.min or self.value > self.max:
149+
raise ValueError(f'value {self.value} is not within the defined range: {self.min} <= value <= {self.max}')
150+
return self
151+
152+
153+
class ProtectedParameter(Parameter, validate_assignment=True, extra='forbid'):
154+
"""A Parameter with a fixed name."""
155+
name: str = Field(frozen=True)
156+
157+
158+
class Resolution(BaseModel, validate_assignment=True, extra='forbid'):
159+
"""Defines Resolutions in RAT."""
160+
name: str = Field(default_factory=lambda: 'New Resolution ' + next(resolution_number))
161+
type: Types = Types.Constant
162+
value_1: str = ''
163+
value_2: str = ''
164+
value_3: str = ''
165+
value_4: str = ''
166+
value_5: str = ''

RAT/project.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""The project module. Defines and stores all the input data required for reflectivity calculations in RAT."""
2+
3+
import numpy as np
4+
from pydantic import BaseModel, FieldValidationInfo, field_validator, model_validator
5+
from typing import Any
6+
7+
from RAT.classlist import ClassList
8+
import RAT.models
9+
10+
try:
11+
from enum import StrEnum
12+
except ImportError:
13+
from strenum import StrEnum
14+
15+
16+
class CalcTypes(StrEnum):
17+
NonPolarised = 'non polarised'
18+
Domains = 'domains'
19+
OilWater = 'oil water'
20+
21+
22+
class ModelTypes(StrEnum):
23+
CustomLayers = 'custom layers'
24+
CustomXY = 'custom xy'
25+
StandardLayers = 'standard layers'
26+
27+
28+
class Geometries(StrEnum):
29+
AirSubstrate = 'air/substrate'
30+
SubstrateLiquid = 'substrate/liquid'
31+
32+
33+
# Map project fields to pydantic models
34+
model_in_classlist = {'parameters': 'Parameter',
35+
'bulk_in': 'Parameter',
36+
'bulk_out': 'Parameter',
37+
'qz_shifts': 'Parameter',
38+
'scalefactors': 'Parameter',
39+
'background_parameters': 'Parameter',
40+
'resolution_parameters': 'Parameter',
41+
'backgrounds': 'Background',
42+
'resolutions': 'Resolution',
43+
'custom_files': 'CustomFile',
44+
'data': 'Data',
45+
'layers': 'Layer',
46+
'contrasts': 'Contrast'
47+
}
48+
49+
50+
class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
51+
"""Defines the input data for a reflectivity calculation in RAT.
52+
53+
This class combines the data defined in each of the pydantic models included in "models.py" into the full set of
54+
inputs required for a reflectivity calculation.
55+
"""
56+
name: str = ''
57+
calc_type: CalcTypes = CalcTypes.NonPolarised
58+
model: ModelTypes = ModelTypes.StandardLayers
59+
geometry: Geometries = Geometries.AirSubstrate
60+
absorption: bool = False
61+
62+
parameters: ClassList = ClassList()
63+
64+
bulk_in: ClassList = ClassList(RAT.models.Parameter(name='SLD Air', min=0, value=0, max=0, fit=False,
65+
prior_type=RAT.models.Priors.Uniform, mu=0, sigma=np.inf))
66+
67+
bulk_out: ClassList = ClassList(RAT.models.Parameter(name='SLD D2O', min=6.2e-6, value=6.35e-6, max=6.35e-6,
68+
fit=False, prior_type=RAT.models.Priors.Uniform, mu=0,
69+
sigma=np.inf))
70+
71+
qz_shifts: ClassList = ClassList(RAT.models.Parameter(name='Qz shift 1', min=-1e-4, value=0, max=1e-4, fit=False,
72+
prior_type=RAT.models.Priors.Uniform, mu=0, sigma=np.inf))
73+
74+
scalefactors: ClassList = ClassList(RAT.models.Parameter(name='Scalefactor 1', min=0.02, value=0.23, max=0.25,
75+
fit=False, prior_type=RAT.models.Priors.Uniform, mu=0,
76+
sigma=np.inf))
77+
78+
background_parameters: ClassList = ClassList(RAT.models.Parameter(name='Background Param 1', min=1e-7, value=1e-6,
79+
max=1e-5, fit=False,
80+
prior_type=RAT.models.Priors.Uniform, mu=0,
81+
sigma=np.inf))
82+
83+
backgrounds: ClassList = ClassList(RAT.models.Background(name='Background 1', type=RAT.models.Types.Constant.value,
84+
value_1='Background Param 1'))
85+
86+
resolution_parameters: ClassList = ClassList(RAT.models.Parameter(name='Resolution Param 1', min=0.01, value=0.03,
87+
max=0.05, fit=False,
88+
prior_type=RAT.models.Priors.Uniform, mu=0,
89+
sigma=np.inf))
90+
91+
resolutions: ClassList = ClassList(RAT.models.Resolution(name='Resolution 1', type=RAT.models.Types.Constant.value,
92+
value_1='Resolution Param 1'))
93+
94+
custom_files: ClassList = ClassList()
95+
data: ClassList = ClassList(RAT.models.Data(name='Simulation'))
96+
layers: ClassList = ClassList()
97+
contrasts: ClassList = ClassList()
98+
99+
@field_validator('parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
100+
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
101+
'contrasts')
102+
@classmethod
103+
def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList:
104+
"""Each of the data fields should be a ClassList of the appropriate model."""
105+
model_name = model_in_classlist[info.field_name]
106+
model = getattr(RAT.models, model_name)
107+
assert all(isinstance(element, model) for element in value), \
108+
f'"{info.field_name}" ClassList contains objects other than "{model_name}"'
109+
return value
110+
111+
def model_post_init(self, __context: Any) -> None:
112+
"""Initialises the class in the ClassLists for empty data fields, and sets protected parameters."""
113+
for field_name, model in model_in_classlist.items():
114+
field = getattr(self, field_name)
115+
if not hasattr(field, "_class_handle"):
116+
setattr(field, "_class_handle", getattr(RAT.models, model))
117+
118+
self.parameters.insert(0, RAT.models.ProtectedParameter(name='Substrate Roughness', min=1, value=3, max=5,
119+
fit=True, prior_type=RAT.models.Priors.Uniform, mu=0,
120+
sigma=np.inf))
121+
122+
@model_validator(mode='after')
123+
def cross_check_model_values(self) -> 'Project':
124+
"""Certain model fields should contain values defined elsewhere in the project."""
125+
value_fields = ['value_1', 'value_2', 'value_3', 'value_4', 'value_5']
126+
self.check_allowed_values('backgrounds', value_fields, self.background_parameters.get_names())
127+
self.check_allowed_values('resolutions', value_fields, self.resolution_parameters.get_names())
128+
self.check_allowed_values('layers', ['thickness', 'SLD', 'roughness'], self.parameters.get_names())
129+
130+
self.check_allowed_values('contrasts', ['data'], self.data.get_names())
131+
self.check_allowed_values('contrasts', ['background'], self.backgrounds.get_names())
132+
self.check_allowed_values('contrasts', ['nba'], self.bulk_in.get_names())
133+
self.check_allowed_values('contrasts', ['nbs'], self.bulk_out.get_names())
134+
self.check_allowed_values('contrasts', ['scalefactor'], self.scalefactors.get_names())
135+
self.check_allowed_values('contrasts', ['resolution'], self.resolutions.get_names())
136+
return self
137+
138+
def __repr__(self):
139+
output = ''
140+
for key, value in self.__dict__.items():
141+
if value:
142+
output += f'{key.replace("_", " ").title() + ": " :-<100}\n\n'
143+
try:
144+
value.value # For enums
145+
except AttributeError:
146+
output += repr(value) + '\n\n'
147+
else:
148+
output += value.value + '\n\n'
149+
return output
150+
151+
def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None:
152+
"""Check the values of the given fields in the given model are in the supplied list of allowed values.
153+
154+
Parameters
155+
----------
156+
attribute : str
157+
The attribute of Project being validated.
158+
field_list : list [str]
159+
The fields of the attribute to be checked for valid values.
160+
allowed_values : list [str]
161+
The list of allowed values for the fields given in field_list.
162+
163+
Raises
164+
------
165+
ValueError
166+
Raised if any field in field_list has a value not specified in allowed_values.
167+
"""
168+
class_list = getattr(self, attribute)
169+
for model in class_list:
170+
for field in field_list:
171+
value = getattr(model, field)
172+
if value and value not in allowed_values:
173+
setattr(model, field, '')
174+
raise ValueError(f'The parameter "{value}" has not been defined in the list of allowed values.')

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
numpy >= 1.20
12
pydantic >= 2.0.3
23
pytest >= 7.4.0
34
pytest-cov >= 4.1.0
5+
StrEnum >= 0.4.15
46
tabulate >= 0.9.0

0 commit comments

Comments
 (0)