Skip to content

Commit 1abd5ca

Browse files
authored
Adds wrapper code for ClassList routines in Project (#8)
* Adds code to wrap ClassLists and force revalidation * Adds auxiliary routines to "classlist.py" to allow iadd, set and del operations to be wrapped in "project.py" * Adds "_setattr" auxiliary routine to each model in "models.py" * Wraps the "_setattr" auxiliary routine for each model that requires cross-check validation * Adds module "CustomErrors.py", including routine "formatted_pydantic_error" to RAT.utils. * Removed code to wrap "__setattr__" in models.py * Adds tests for wrapping code. * Tidies up the "Project.check_allowed_values()" routine * Adds tests requiring testing equality of Data models.
1 parent c9b74c0 commit 1abd5ca

File tree

6 files changed

+613
-45
lines changed

6 files changed

+613
-45
lines changed

RAT/classlist.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,28 @@ def __repr__(self):
6363

6464
def __setitem__(self, index: int, set_dict: dict[str, Any]) -> None:
6565
"""Assign the values of an existing object's attributes using a dictionary."""
66+
self._setitem(index, set_dict)
67+
68+
def _setitem(self, index: int, set_dict: dict[str, Any]) -> None:
69+
"""Auxiliary routine of "__setitem__" used to enable wrapping."""
6670
self._validate_name_field(set_dict)
6771
for key, value in set_dict.items():
6872
setattr(self.data[index], key, value)
6973

74+
def __delitem__(self, index: int) -> None:
75+
"""Delete an object from the list by index."""
76+
self._delitem(index)
77+
78+
def _delitem(self, index: int) -> None:
79+
"""Auxiliary routine of "__delitem__" used to enable wrapping."""
80+
del self.data[index]
81+
7082
def __iadd__(self, other: Sequence[object]) -> 'ClassList':
7183
"""Define in-place addition using the "+=" operator."""
84+
return self._iadd(other)
85+
86+
def _iadd(self, other: Sequence[object]) -> 'ClassList':
87+
"""Auxiliary routine of "__iadd__" used to enable wrapping."""
7288
if not hasattr(self, '_class_handle'):
7389
self._class_handle = type(other[0])
7490
self._check_classes(self + other)

RAT/project.py

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""The project module. Defines and stores all the input data required for reflectivity calculations in RAT."""
22

3+
import copy
4+
import functools
35
import numpy as np
4-
from pydantic import BaseModel, FieldValidationInfo, field_validator, model_validator
5-
from typing import Any
6+
from pydantic import BaseModel, FieldValidationInfo, field_validator, model_validator, ValidationError
7+
from typing import Any, Callable
68

79
from RAT.classlist import ClassList
810
import RAT.models
11+
from RAT.utils.custom_errors import formatted_pydantic_error
912

1013
try:
1114
from enum import StrEnum
@@ -46,6 +49,27 @@ class Geometries(StrEnum):
4649
'contrasts': 'Contrast'
4750
}
4851

52+
values_defined_in = {'backgrounds.value_1': 'background_parameters',
53+
'backgrounds.value_2': 'background_parameters',
54+
'backgrounds.value_3': 'background_parameters',
55+
'backgrounds.value_4': 'background_parameters',
56+
'backgrounds.value_5': 'background_parameters',
57+
'resolutions.value_1': 'resolution_parameters',
58+
'resolutions.value_2': 'resolution_parameters',
59+
'resolutions.value_3': 'resolution_parameters',
60+
'resolutions.value_4': 'resolution_parameters',
61+
'resolutions.value_5': 'resolution_parameters',
62+
'layers.thickness': 'parameters',
63+
'layers.SLD': 'parameters',
64+
'layers.roughness': 'parameters',
65+
'contrasts.data': 'data',
66+
'contrasts.background': 'backgrounds',
67+
'contrasts.nba': 'bulk_in',
68+
'contrasts.nbs': 'bulk_out',
69+
'contrasts.scalefactor': 'scalefactors',
70+
'contrasts.resolution': 'resolutions',
71+
}
72+
4973

5074
class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
5175
"""Defines the input data for a reflectivity calculation in RAT.
@@ -109,7 +133,9 @@ def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList:
109133
return value
110134

111135
def model_post_init(self, __context: Any) -> None:
112-
"""Initialises the class in the ClassLists for empty data fields, and sets protected parameters."""
136+
"""Initialises the class in the ClassLists for empty data fields, sets protected parameters, and wraps
137+
ClassList routines to control revalidation.
138+
"""
113139
for field_name, model in model_in_classlist.items():
114140
field = getattr(self, field_name)
115141
if not hasattr(field, "_class_handle"):
@@ -119,6 +145,17 @@ def model_post_init(self, __context: Any) -> None:
119145
fit=True, prior_type=RAT.models.Priors.Uniform, mu=0,
120146
sigma=np.inf))
121147

148+
# Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the
149+
# model, handle errors and reset previous values if necessary.
150+
class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
151+
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
152+
'contrasts']
153+
methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend']
154+
for class_list in class_lists:
155+
attribute = getattr(self, class_list)
156+
for methodName in methods_to_wrap:
157+
setattr(attribute, methodName, self._classlist_wrapper(attribute, getattr(attribute, methodName)))
158+
122159
@model_validator(mode='after')
123160
def cross_check_model_values(self) -> 'Project':
124161
"""Certain model fields should contain values defined elsewhere in the project."""
@@ -170,5 +207,43 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va
170207
for field in field_list:
171208
value = getattr(model, field)
172209
if value and value not in allowed_values:
173-
setattr(model, field, '')
174-
raise ValueError(f'The parameter "{value}" has not been defined in the list of allowed values.')
210+
raise ValueError(f'The value "{value}" in the "{field}" field of "{attribute}" must be defined in '
211+
f'"{values_defined_in[attribute + "." + field]}".')
212+
213+
def _classlist_wrapper(self, class_list: 'ClassList', func: Callable):
214+
"""Defines the function used to wrap around ClassList routines to force revalidation.
215+
216+
Parameters
217+
----------
218+
class_list : ClassList
219+
The ClassList defined in the "Project" model that is being modified.
220+
func : Callable
221+
The routine being wrapped.
222+
223+
Returns
224+
-------
225+
wrapped_func : Callable
226+
The wrapped routine.
227+
"""
228+
@functools.wraps(func)
229+
def wrapped_func(*args, **kwargs):
230+
"""Run the given function and then revalidate the "Project" model. If any exception is raised, restore
231+
the previous state of the given ClassList and report details of the exception.
232+
"""
233+
previous_state = copy.deepcopy(getattr(class_list, 'data'))
234+
return_value = None
235+
try:
236+
return_value = func(*args, **kwargs)
237+
Project.model_validate(self)
238+
except ValidationError as e:
239+
setattr(class_list, 'data', previous_state)
240+
error_string = formatted_pydantic_error(e)
241+
# Use ANSI escape sequences to print error text in red
242+
print('\033[31m' + error_string + '\033[0m')
243+
except (TypeError, ValueError):
244+
setattr(class_list, 'data', previous_state)
245+
raise
246+
finally:
247+
del previous_state
248+
return return_value
249+
return wrapped_func

RAT/utils/custom_errors.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Defines routines for custom error handling in RAT."""
2+
3+
from pydantic import ValidationError
4+
5+
6+
def formatted_pydantic_error(error: ValidationError) -> str:
7+
"""Write a custom string format for pydantic validation errors.
8+
9+
Parameters
10+
----------
11+
error : pydantic.ValidationError
12+
A ValidationError produced by a pydantic model
13+
14+
Returns
15+
-------
16+
error_str : str
17+
A string giving details of the ValidationError in a custom format.
18+
"""
19+
num_errors = error.error_count()
20+
error_str = f'{num_errors} validation error{"s"[:num_errors!=1]} for {error.title}'
21+
for this_error in error.errors():
22+
error_str += '\n'
23+
if this_error['loc']:
24+
error_str += ' '.join(this_error['loc']) + '\n'
25+
error_str += ' ' + this_error['msg']
26+
return error_str

tests/test_classlist.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,20 @@ def test_setitem_same_name_field(two_name_class_list: 'ClassList', new_values: d
142142
two_name_class_list[0] = new_values
143143

144144

145+
def test_delitem(two_name_class_list: 'ClassList', one_name_class_list: 'ClassList') -> None:
146+
"""We should be able to delete elements from a ClassList with the del operator."""
147+
class_list = two_name_class_list
148+
del class_list[1]
149+
assert class_list == one_name_class_list
150+
151+
152+
def test_delitem_not_present(two_name_class_list: 'ClassList') -> None:
153+
"""If we use the del operator to delete an index out of range, we should raise an IndexError."""
154+
class_list = two_name_class_list
155+
with pytest.raises(IndexError, match=re.escape("list assignment index out of range")):
156+
del class_list[2]
157+
158+
145159
@pytest.mark.parametrize("added_list", [
146160
(ClassList(InputAttributes(name='Eve'))),
147161
([InputAttributes(name='Eve')]),

tests/test_custom_errors.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Test the utils.custom_errors module."""
2+
3+
from pydantic import create_model, ValidationError
4+
import pytest
5+
6+
import RAT.utils.custom_errors
7+
8+
9+
def test_formatted_pydantic_error() -> None:
10+
"""When a pytest ValidationError is raised we should be able to take it and construct a formatted string."""
11+
12+
# Create a custom pydantic model for the test
13+
TestModel = create_model('TestModel', int_field=(int, 1), str_field=(str, 'a'))
14+
15+
with pytest.raises(ValidationError) as exc_info:
16+
TestModel(int_field='string', str_field=5)
17+
18+
error_str = RAT.utils.custom_errors.formatted_pydantic_error(exc_info.value)
19+
assert error_str == ('2 validation errors for TestModel\nint_field\n Input should be a valid integer, unable to '
20+
'parse string as an integer\nstr_field\n Input should be a valid string')

0 commit comments

Comments
 (0)