11"""The project module. Defines and stores all the input data required for reflectivity calculations in RAT."""
22
3+ import copy
4+ import functools
35import numpy as np
4- from pydantic import BaseModel , FieldValidationInfo , field_validator , model_validator
5- from typing import Any
6+ from pydantic import BaseModel , FieldValidationInfo , field_validator , model_validator , ValidationError
7+ from typing import Any , Callable
68
79from RAT .classlist import ClassList
810import RAT .models
11+ from RAT .utils .custom_errors import formatted_pydantic_error
912
1013try :
1114 from enum import StrEnum
@@ -46,6 +49,27 @@ class Geometries(StrEnum):
4649 'contrasts' : 'Contrast'
4750 }
4851
52+ values_defined_in = {'backgrounds.value_1' : 'background_parameters' ,
53+ 'backgrounds.value_2' : 'background_parameters' ,
54+ 'backgrounds.value_3' : 'background_parameters' ,
55+ 'backgrounds.value_4' : 'background_parameters' ,
56+ 'backgrounds.value_5' : 'background_parameters' ,
57+ 'resolutions.value_1' : 'resolution_parameters' ,
58+ 'resolutions.value_2' : 'resolution_parameters' ,
59+ 'resolutions.value_3' : 'resolution_parameters' ,
60+ 'resolutions.value_4' : 'resolution_parameters' ,
61+ 'resolutions.value_5' : 'resolution_parameters' ,
62+ 'layers.thickness' : 'parameters' ,
63+ 'layers.SLD' : 'parameters' ,
64+ 'layers.roughness' : 'parameters' ,
65+ 'contrasts.data' : 'data' ,
66+ 'contrasts.background' : 'backgrounds' ,
67+ 'contrasts.nba' : 'bulk_in' ,
68+ 'contrasts.nbs' : 'bulk_out' ,
69+ 'contrasts.scalefactor' : 'scalefactors' ,
70+ 'contrasts.resolution' : 'resolutions' ,
71+ }
72+
4973
5074class Project (BaseModel , validate_assignment = True , extra = 'forbid' , arbitrary_types_allowed = True ):
5175 """Defines the input data for a reflectivity calculation in RAT.
@@ -109,7 +133,9 @@ def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList:
109133 return value
110134
111135 def model_post_init (self , __context : Any ) -> None :
112- """Initialises the class in the ClassLists for empty data fields, and sets protected parameters."""
136+ """Initialises the class in the ClassLists for empty data fields, sets protected parameters, and wraps
137+ ClassList routines to control revalidation.
138+ """
113139 for field_name , model in model_in_classlist .items ():
114140 field = getattr (self , field_name )
115141 if not hasattr (field , "_class_handle" ):
@@ -119,6 +145,17 @@ def model_post_init(self, __context: Any) -> None:
119145 fit = True , prior_type = RAT .models .Priors .Uniform , mu = 0 ,
120146 sigma = np .inf ))
121147
148+ # Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the
149+ # model, handle errors and reset previous values if necessary.
150+ class_lists = ['parameters' , 'bulk_in' , 'bulk_out' , 'qz_shifts' , 'scalefactors' , 'background_parameters' ,
151+ 'backgrounds' , 'resolution_parameters' , 'resolutions' , 'custom_files' , 'data' , 'layers' ,
152+ 'contrasts' ]
153+ methods_to_wrap = ['_setitem' , '_delitem' , '_iadd' , 'append' , 'insert' , 'pop' , 'remove' , 'clear' , 'extend' ]
154+ for class_list in class_lists :
155+ attribute = getattr (self , class_list )
156+ for methodName in methods_to_wrap :
157+ setattr (attribute , methodName , self ._classlist_wrapper (attribute , getattr (attribute , methodName )))
158+
122159 @model_validator (mode = 'after' )
123160 def cross_check_model_values (self ) -> 'Project' :
124161 """Certain model fields should contain values defined elsewhere in the project."""
@@ -170,5 +207,43 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va
170207 for field in field_list :
171208 value = getattr (model , field )
172209 if value and value not in allowed_values :
173- setattr (model , field , '' )
174- raise ValueError (f'The parameter "{ value } " has not been defined in the list of allowed values.' )
210+ raise ValueError (f'The value "{ value } " in the "{ field } " field of "{ attribute } " must be defined in '
211+ f'"{ values_defined_in [attribute + "." + field ]} ".' )
212+
213+ def _classlist_wrapper (self , class_list : 'ClassList' , func : Callable ):
214+ """Defines the function used to wrap around ClassList routines to force revalidation.
215+
216+ Parameters
217+ ----------
218+ class_list : ClassList
219+ The ClassList defined in the "Project" model that is being modified.
220+ func : Callable
221+ The routine being wrapped.
222+
223+ Returns
224+ -------
225+ wrapped_func : Callable
226+ The wrapped routine.
227+ """
228+ @functools .wraps (func )
229+ def wrapped_func (* args , ** kwargs ):
230+ """Run the given function and then revalidate the "Project" model. If any exception is raised, restore
231+ the previous state of the given ClassList and report details of the exception.
232+ """
233+ previous_state = copy .deepcopy (getattr (class_list , 'data' ))
234+ return_value = None
235+ try :
236+ return_value = func (* args , ** kwargs )
237+ Project .model_validate (self )
238+ except ValidationError as e :
239+ setattr (class_list , 'data' , previous_state )
240+ error_string = formatted_pydantic_error (e )
241+ # Use ANSI escape sequences to print error text in red
242+ print ('\033 [31m' + error_string + '\033 [0m' )
243+ except (TypeError , ValueError ):
244+ setattr (class_list , 'data' , previous_state )
245+ raise
246+ finally :
247+ del previous_state
248+ return return_value
249+ return wrapped_func
0 commit comments