11"""The project module. Defines and stores all the input data required for reflectivity calculations in RAT."""
22
3+ import collections
4+ import contextlib
35import copy
46import functools
57import numpy as np
@@ -46,7 +48,7 @@ class Geometries(StrEnum):
4648 'custom_files' : 'CustomFile' ,
4749 'data' : 'Data' ,
4850 'layers' : 'Layer' ,
49- 'contrasts' : 'Contrast'
51+ 'contrasts' : 'Contrast' ,
5052 }
5153
5254values_defined_in = {'backgrounds.value_1' : 'background_parameters' ,
@@ -70,6 +72,23 @@ class Geometries(StrEnum):
7072 'contrasts.resolution' : 'resolutions' ,
7173 }
7274
75+ AllFields = collections .namedtuple ('AllFields' , ['attribute' , 'fields' ])
76+ model_names_used_in = {'background_parameters' : AllFields ('backgrounds' , ['value_1' , 'value_2' , 'value_3' , 'value_4' ,
77+ 'value_5' ]),
78+ 'resolution_parameters' : AllFields ('resolutions' , ['value_1' , 'value_2' , 'value_3' , 'value_4' ,
79+ 'value_5' ]),
80+ 'parameters' : AllFields ('layers' , ['thickness' , 'SLD' , 'roughness' ]),
81+ 'data' : AllFields ('contrasts' , ['data' ]),
82+ 'backgrounds' : AllFields ('contrasts' , ['background' ]),
83+ 'bulk_in' : AllFields ('contrasts' , ['nba' ]),
84+ 'bulk_out' : AllFields ('contrasts' , ['nbs' ]),
85+ 'scalefactors' : AllFields ('contrasts' , ['scalefactor' ]),
86+ 'resolutions' : AllFields ('contrasts' , ['resolution' ]),
87+ }
88+
89+ class_lists = ['parameters' , 'bulk_in' , 'bulk_out' , 'qz_shifts' , 'scalefactors' , 'background_parameters' , 'backgrounds' ,
90+ 'resolution_parameters' , 'resolutions' , 'custom_files' , 'data' , 'layers' , 'contrasts' ]
91+
7392
7493class Project (BaseModel , validate_assignment = True , extra = 'forbid' , arbitrary_types_allowed = True ):
7594 """Defines the input data for a reflectivity calculation in RAT.
@@ -120,6 +139,8 @@ class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_typ
120139 layers : ClassList = ClassList ()
121140 contrasts : ClassList = ClassList ()
122141
142+ _all_names : dict
143+
123144 @field_validator ('parameters' , 'bulk_in' , 'bulk_out' , 'qz_shifts' , 'scalefactors' , 'background_parameters' ,
124145 'backgrounds' , 'resolution_parameters' , 'resolutions' , 'custom_files' , 'data' , 'layers' ,
125146 'contrasts' )
@@ -133,8 +154,8 @@ def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList:
133154 return value
134155
135156 def model_post_init (self , __context : Any ) -> None :
136- """Initialises the class in the ClassLists for empty data fields, sets protected parameters, and wraps
137- ClassList routines to control revalidation.
157+ """Initialises the class in the ClassLists for empty data fields, sets protected parameters, gets names of all
158+ defined parameters and wraps ClassList routines to control revalidation.
138159 """
139160 for field_name , model in model_in_classlist .items ():
140161 field = getattr (self , field_name )
@@ -145,17 +166,33 @@ def model_post_init(self, __context: Any) -> None:
145166 fit = True , prior_type = RAT .models .Priors .Uniform , mu = 0 ,
146167 sigma = np .inf ))
147168
169+ self ._all_names = self .get_all_names ()
170+
148171 # Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the
149172 # model, handle errors and reset previous values if necessary.
150- class_lists = ['parameters' , 'bulk_in' , 'bulk_out' , 'qz_shifts' , 'scalefactors' , 'background_parameters' ,
151- 'backgrounds' , 'resolution_parameters' , 'resolutions' , 'custom_files' , 'data' , 'layers' ,
152- 'contrasts' ]
153173 methods_to_wrap = ['_setitem' , '_delitem' , '_iadd' , 'append' , 'insert' , 'pop' , 'remove' , 'clear' , 'extend' ]
154174 for class_list in class_lists :
155175 attribute = getattr (self , class_list )
156176 for methodName in methods_to_wrap :
157177 setattr (attribute , methodName , self ._classlist_wrapper (attribute , getattr (attribute , methodName )))
158178
179+ @model_validator (mode = 'after' )
180+ def update_renamed_models (self ) -> 'Project' :
181+ """When models defined in the ClassLists are renamed, we need to update that name elsewhere in the project."""
182+ for class_list in class_lists :
183+ old_names = self ._all_names [class_list ]
184+ new_names = getattr (self , class_list ).get_names ()
185+ if len (old_names ) == len (new_names ):
186+ name_diff = [(old , new ) for (old , new ) in zip (old_names , new_names ) if old != new ]
187+ for (old_name , new_name ) in name_diff :
188+ with contextlib .suppress (KeyError ):
189+ model_names_list = getattr (self , model_names_used_in [class_list ].attribute )
190+ all_matches = model_names_list .get_all_matches (old_name )
191+ for (index , field ) in all_matches :
192+ setattr (model_names_list [index ], field , new_name )
193+ self ._all_names = self .get_all_names ()
194+ return self
195+
159196 @model_validator (mode = 'after' )
160197 def cross_check_model_values (self ) -> 'Project' :
161198 """Certain model fields should contain values defined elsewhere in the project."""
@@ -185,6 +222,10 @@ def __repr__(self):
185222 output += value .value + '\n \n '
186223 return output
187224
225+ def get_all_names (self ):
226+ """Record the names of all models defined in the project."""
227+ return {class_list : getattr (self , class_list ).get_names () for class_list in class_lists }
228+
188229 def check_allowed_values (self , attribute : str , field_list : list [str ], allowed_values : list [str ]) -> None :
189230 """Check the values of the given fields in the given model are in the supplied list of allowed values.
190231
0 commit comments