Skip to content

Commit 5355150

Browse files
authored
Adds model validator to ensure consistent renaming (#9)
* Adds model validator to ensure renamed models are updated throughout the project * Adds tests for renaming code * Uses project dicts to tidy up tests * Adds routine "get_all_matches" to "classList.py"
1 parent 1abd5ca commit 5355150

File tree

4 files changed

+349
-235
lines changed

4 files changed

+349
-235
lines changed

RAT/classlist.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,22 @@ def get_names(self) -> list[str]:
217217
"""
218218
return [getattr(model, self.name_field) for model in self.data if hasattr(model, self.name_field)]
219219

220+
def get_all_matches(self, value: Any) -> list[tuple]:
221+
"""Return a list of all (index, field) tuples where the value of the field is equal to the given value.
222+
223+
Parameters
224+
----------
225+
value : str
226+
The value we are searching for in the ClassList.
227+
228+
Returns
229+
-------
230+
: list [tuple]
231+
A list of (index, field) tuples matching the given value.
232+
"""
233+
return [(index, field) for index, element in enumerate(self.data) for field in vars(element)
234+
if getattr(element, field) == value]
235+
220236
def _validate_name_field(self, input_args: dict[str, Any]) -> None:
221237
"""Raise a ValueError if the name_field attribute is passed as an object parameter, and its value is already
222238
used within the ClassList.

RAT/project.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""The project module. Defines and stores all the input data required for reflectivity calculations in RAT."""
22

3+
import collections
4+
import contextlib
35
import copy
46
import functools
57
import numpy as np
@@ -46,7 +48,7 @@ class Geometries(StrEnum):
4648
'custom_files': 'CustomFile',
4749
'data': 'Data',
4850
'layers': 'Layer',
49-
'contrasts': 'Contrast'
51+
'contrasts': 'Contrast',
5052
}
5153

5254
values_defined_in = {'backgrounds.value_1': 'background_parameters',
@@ -70,6 +72,23 @@ class Geometries(StrEnum):
7072
'contrasts.resolution': 'resolutions',
7173
}
7274

75+
AllFields = collections.namedtuple('AllFields', ['attribute', 'fields'])
76+
model_names_used_in = {'background_parameters': AllFields('backgrounds', ['value_1', 'value_2', 'value_3', 'value_4',
77+
'value_5']),
78+
'resolution_parameters': AllFields('resolutions', ['value_1', 'value_2', 'value_3', 'value_4',
79+
'value_5']),
80+
'parameters': AllFields('layers', ['thickness', 'SLD', 'roughness']),
81+
'data': AllFields('contrasts', ['data']),
82+
'backgrounds': AllFields('contrasts', ['background']),
83+
'bulk_in': AllFields('contrasts', ['nba']),
84+
'bulk_out': AllFields('contrasts', ['nbs']),
85+
'scalefactors': AllFields('contrasts', ['scalefactor']),
86+
'resolutions': AllFields('contrasts', ['resolution']),
87+
}
88+
89+
class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters', 'backgrounds',
90+
'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers', 'contrasts']
91+
7392

7493
class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
7594
"""Defines the input data for a reflectivity calculation in RAT.
@@ -120,6 +139,8 @@ class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_typ
120139
layers: ClassList = ClassList()
121140
contrasts: ClassList = ClassList()
122141

142+
_all_names: dict
143+
123144
@field_validator('parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
124145
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
125146
'contrasts')
@@ -133,8 +154,8 @@ def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList:
133154
return value
134155

135156
def model_post_init(self, __context: Any) -> None:
136-
"""Initialises the class in the ClassLists for empty data fields, sets protected parameters, and wraps
137-
ClassList routines to control revalidation.
157+
"""Initialises the class in the ClassLists for empty data fields, sets protected parameters, gets names of all
158+
defined parameters and wraps ClassList routines to control revalidation.
138159
"""
139160
for field_name, model in model_in_classlist.items():
140161
field = getattr(self, field_name)
@@ -145,17 +166,33 @@ def model_post_init(self, __context: Any) -> None:
145166
fit=True, prior_type=RAT.models.Priors.Uniform, mu=0,
146167
sigma=np.inf))
147168

169+
self._all_names = self.get_all_names()
170+
148171
# Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the
149172
# model, handle errors and reset previous values if necessary.
150-
class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
151-
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
152-
'contrasts']
153173
methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend']
154174
for class_list in class_lists:
155175
attribute = getattr(self, class_list)
156176
for methodName in methods_to_wrap:
157177
setattr(attribute, methodName, self._classlist_wrapper(attribute, getattr(attribute, methodName)))
158178

179+
@model_validator(mode='after')
180+
def update_renamed_models(self) -> 'Project':
181+
"""When models defined in the ClassLists are renamed, we need to update that name elsewhere in the project."""
182+
for class_list in class_lists:
183+
old_names = self._all_names[class_list]
184+
new_names = getattr(self, class_list).get_names()
185+
if len(old_names) == len(new_names):
186+
name_diff = [(old, new) for (old, new) in zip(old_names, new_names) if old != new]
187+
for (old_name, new_name) in name_diff:
188+
with contextlib.suppress(KeyError):
189+
model_names_list = getattr(self, model_names_used_in[class_list].attribute)
190+
all_matches = model_names_list.get_all_matches(old_name)
191+
for (index, field) in all_matches:
192+
setattr(model_names_list[index], field, new_name)
193+
self._all_names = self.get_all_names()
194+
return self
195+
159196
@model_validator(mode='after')
160197
def cross_check_model_values(self) -> 'Project':
161198
"""Certain model fields should contain values defined elsewhere in the project."""
@@ -185,6 +222,10 @@ def __repr__(self):
185222
output += value.value + '\n\n'
186223
return output
187224

225+
def get_all_names(self):
226+
"""Record the names of all models defined in the project."""
227+
return {class_list: getattr(self, class_list).get_names() for class_list in class_lists}
228+
188229
def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None:
189230
"""Check the values of the given fields in the given model are in the supplied list of allowed values.
190231

tests/test_classlist.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,17 @@ def test_get_names(class_list: 'ClassList', expected_names: list[str]) -> None:
476476
assert class_list.get_names() == expected_names
477477

478478

479+
@pytest.mark.parametrize(["class_list", "expected_matches"], [
480+
(ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob')]), [(0, 'name')]),
481+
(ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob', id='Alice')]), [(0, 'name'), (1, 'id')]),
482+
(ClassList([InputAttributes(surname='Morgan'), InputAttributes(surname='Terwilliger')]), []),
483+
(ClassList(InputAttributes()), []),
484+
])
485+
def test_get_all_matches(class_list: 'ClassList', expected_matches: list[tuple]) -> None:
486+
"""We should get a list of (index, field) tuples matching the given value in the ClassList."""
487+
assert class_list.get_all_matches("Alice") == expected_matches
488+
489+
479490
@pytest.mark.parametrize("input_dict", [
480491
({'name': 'Eve'}),
481492
({'surname': 'Polastri'}),

0 commit comments

Comments
 (0)