Skip to content

Commit bbfebd4

Browse files
authored
Adds additional validators and "write_script" routine (#15)
* Adds code to remove layers for non-standard layers model types * Adds code to ensure protected parameters cannot be removed * Adds code to restore default domain ratios when switching calc_type * Adds validators to "Data" model * Adds "write_script" routine to "project.py" * Addresses review comments * Remove specific wrap data tests from "test_project", instead using new Data __eq__ method * Adds test for Data.__eq__ to improve test coverage * Adds fixture to setup and teardown temp directory
1 parent c580435 commit bbfebd4

File tree

5 files changed

+396
-147
lines changed

5 files changed

+396
-147
lines changed

RAT/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from RAT.classlist import ClassList
2+
from RAT.project import Project

RAT/models.py

Lines changed: 88 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
5+
from typing import Any
56

67
try:
78
from enum import StrEnum
@@ -43,7 +44,6 @@ class Languages(StrEnum):
4344
class Priors(StrEnum):
4445
Uniform = 'uniform'
4546
Gaussian = 'gaussian'
46-
Jeffreys = 'jeffreys'
4747

4848

4949
class Types(StrEnum):
@@ -52,7 +52,19 @@ class Types(StrEnum):
5252
Function = 'function'
5353

5454

55-
class Background(BaseModel, validate_assignment=True, extra='forbid'):
55+
class RATModel(BaseModel):
56+
"""A BaseModel where enums are represented by their value."""
57+
def __repr__(self):
58+
fields_repr = (', '.join(repr(v) if a is None else
59+
f'{a}={v.value!r}' if isinstance(v, StrEnum) else
60+
f'{a}={v!r}'
61+
for a, v in self.__repr_args__()
62+
)
63+
)
64+
return f'{self.__repr_name__()}({fields_repr})'
65+
66+
67+
class Background(RATModel, validate_assignment=True, extra='forbid'):
5668
"""Defines the Backgrounds in RAT."""
5769
name: str = Field(default_factory=lambda: 'New Background ' + next(background_number), min_length=1)
5870
type: Types = Types.Constant
@@ -63,7 +75,7 @@ class Background(BaseModel, validate_assignment=True, extra='forbid'):
6375
value_5: str = ''
6476

6577

66-
class Contrast(BaseModel, validate_assignment=True, extra='forbid'):
78+
class Contrast(RATModel, validate_assignment=True, extra='forbid'):
6779
"""Groups together all of the components of the model."""
6880
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
6981
data: str = ''
@@ -76,7 +88,7 @@ class Contrast(BaseModel, validate_assignment=True, extra='forbid'):
7688
model: list[str] = []
7789

7890

79-
class ContrastWithRatio(BaseModel, validate_assignment=True, extra='forbid'):
91+
class ContrastWithRatio(RATModel, validate_assignment=True, extra='forbid'):
8092
"""Groups together all of the components of the model including domain terms."""
8193
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
8294
data: str = ''
@@ -90,20 +102,20 @@ class ContrastWithRatio(BaseModel, validate_assignment=True, extra='forbid'):
90102
model: list[str] = []
91103

92104

93-
class CustomFile(BaseModel, validate_assignment=True, extra='forbid'):
105+
class CustomFile(RATModel, validate_assignment=True, extra='forbid'):
94106
"""Defines the files containing functions to run when using custom models."""
95107
name: str = Field(default_factory=lambda: 'New Custom File ' + next(custom_file_number), min_length=1)
96108
filename: str = ''
97109
language: Languages = Languages.Python
98110
path: str = 'pwd' # Should later expand to find current file path
99111

100112

101-
class Data(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
113+
class Data(RATModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
102114
"""Defines the dataset required for each contrast."""
103115
name: str = Field(default_factory=lambda: 'New Data ' + next(data_number), min_length=1)
104-
data: np.ndarray[float] = np.empty([0, 3])
105-
data_range: list[float] = []
106-
simulation_range: list[float] = [0.005, 0.7]
116+
data: np.ndarray[np.float64] = np.empty([0, 3])
117+
data_range: list[float] = Field(default=[], min_length=2, max_length=2)
118+
simulation_range: list[float] = Field(default=[], min_length=2, max_length=2)
107119

108120
@field_validator('data')
109121
@classmethod
@@ -120,22 +132,79 @@ def check_data_dimension(cls, data: np.ndarray[float]) -> np.ndarray[float]:
120132

121133
@field_validator('data_range', 'simulation_range')
122134
@classmethod
123-
def check_list_elements(cls, limits: list[float], info: ValidationInfo) -> list[float]:
124-
"""The data range and simulation range must contain exactly two parameters."""
125-
if len(limits) != 2:
126-
raise ValueError(f'{info.field_name} must contain exactly two values')
135+
def check_min_max(cls, limits: list[float], info: ValidationInfo) -> list[float]:
136+
"""The data range and simulation range maximum must be greater than the minimum."""
137+
if limits[0] > limits[1]:
138+
raise ValueError(f'{info.field_name} "min" value is greater than the "max" value')
127139
return limits
128140

129-
# Also need model validators for data range compared to data etc -- need more details.
141+
def model_post_init(self, __context: Any) -> None:
142+
"""If the "data_range" and "simulation_range" fields are not set, but "data" is supplied, the ranges should be
143+
set to the min and max values of the first column (assumed to be q) of the supplied data.
144+
"""
145+
if len(self.data[:, 0]) > 0:
146+
data_min = np.min(self.data[:, 0])
147+
data_max = np.max(self.data[:, 0])
148+
for field in ["data_range", "simulation_range"]:
149+
if field not in self.model_fields_set:
150+
getattr(self, field).extend([data_min, data_max])
151+
152+
@model_validator(mode='after')
153+
def check_ranges(self) -> 'Data':
154+
"""The limits of the "data_range" field must lie within the range of the supplied data, whilst the limits
155+
of the "simulation_range" field must lie outside of the range of the supplied data.
156+
"""
157+
if len(self.data[:, 0]) > 0:
158+
data_min = np.min(self.data[:, 0])
159+
data_max = np.max(self.data[:, 0])
160+
if "data_range" in self.model_fields_set and (self.data_range[0] < data_min or
161+
self.data_range[1] > data_max):
162+
raise ValueError(f'The data_range value of: {self.data_range} must lie within the min/max values of '
163+
f'the data: [{data_min}, {data_max}]')
164+
if "simulation_range" in self.model_fields_set and (self.simulation_range[0] > data_min or
165+
self.simulation_range[1] < data_max):
166+
raise ValueError(f'The simulation_range value of: {self.simulation_range} must lie outside of the '
167+
f'min/max values of the data: [{data_min}, {data_max}]')
168+
return self
169+
170+
def __eq__(self, other: Any) -> bool:
171+
if isinstance(other, BaseModel):
172+
# When comparing instances of generic types for equality, as long as all field values are equal,
173+
# only require their generic origin types to be equal, rather than exact type equality.
174+
# This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
175+
self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
176+
other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__
177+
178+
return (
179+
self_type == other_type
180+
and self.name == other.name
181+
and (self.data == other.data).all()
182+
and self.data_range == other.data_range
183+
and self.simulation_range == other.simulation_range
184+
and self.__pydantic_private__ == other.__pydantic_private__
185+
and self.__pydantic_extra__ == other.__pydantic_extra__
186+
)
187+
else:
188+
return NotImplemented # delegate to the other item in the comparison
189+
190+
def __repr__(self):
191+
"""Only include the name if the data is empty."""
192+
fields_repr = (f"name={self.name!r}" if self.data.size == 0 else
193+
", ".join(repr(v) if a is None else
194+
f"{a}={v!r}"
195+
for a, v in self.__repr_args__()
196+
)
197+
)
198+
return f'{self.__repr_name__()}({fields_repr})'
130199

131200

132-
class DomainContrast(BaseModel, validate_assignment=True, extra='forbid'):
201+
class DomainContrast(RATModel, validate_assignment=True, extra='forbid'):
133202
"""Groups together the layers required for each domain."""
134203
name: str = Field(default_factory=lambda: 'New Domain Contrast ' + next(domain_contrast_number), min_length=1)
135204
model: list[str] = []
136205

137206

138-
class Layer(BaseModel, validate_assignment=True, extra='forbid', populate_by_name=True):
207+
class Layer(RATModel, validate_assignment=True, extra='forbid', populate_by_name=True):
139208
"""Combines parameters into defined layers."""
140209
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number), min_length=1)
141210
thickness: str = ''
@@ -145,7 +214,7 @@ class Layer(BaseModel, validate_assignment=True, extra='forbid', populate_by_nam
145214
hydrate_with: Hydration = Hydration.BulkOut
146215

147216

148-
class AbsorptionLayer(BaseModel, validate_assignment=True, extra='forbid', populate_by_name=True):
217+
class AbsorptionLayer(RATModel, validate_assignment=True, extra='forbid', populate_by_name=True):
149218
"""Combines parameters into defined layers including absorption terms."""
150219
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number), min_length=1)
151220
thickness: str = ''
@@ -156,7 +225,7 @@ class AbsorptionLayer(BaseModel, validate_assignment=True, extra='forbid', popul
156225
hydrate_with: Hydration = Hydration.BulkOut
157226

158227

159-
class Parameter(BaseModel, validate_assignment=True, extra='forbid'):
228+
class Parameter(RATModel, validate_assignment=True, extra='forbid'):
160229
"""Defines parameters needed to specify the model."""
161230
name: str = Field(default_factory=lambda: 'New Parameter ' + next(parameter_number), min_length=1)
162231
min: float = 0.0
@@ -180,7 +249,7 @@ class ProtectedParameter(Parameter, validate_assignment=True, extra='forbid'):
180249
name: str = Field(frozen=True, min_length=1)
181250

182251

183-
class Resolution(BaseModel, validate_assignment=True, extra='forbid'):
252+
class Resolution(RATModel, validate_assignment=True, extra='forbid'):
184253
"""Defines Resolutions in RAT."""
185254
name: str = Field(default_factory=lambda: 'New Resolution ' + next(resolution_number), min_length=1)
186255
type: Types = Types.Constant

0 commit comments

Comments
 (0)