1010import RATapi
1111import RATapi .controls
1212import RATapi .wrappers
13- from RATapi .rat_core import Checks , Control , Limits , NameStore , Priors , ProblemDefinition
13+ from RATapi .rat_core import Checks , Control , Limits , NameStore , ProblemDefinition
1414from RATapi .utils .enums import Calculations , Languages , LayerModels , TypeOptions
1515
16+ parameter_field = {
17+ "parameters" : "params" ,
18+ "bulk_in" : "bulkIns" ,
19+ "bulk_out" : "bulkOuts" ,
20+ "scalefactors" : "scalefactors" ,
21+ "domain_ratios" : "domainRatios" ,
22+ "background_parameters" : "backgroundParams" ,
23+ "resolution_parameters" : "resolutionParams" ,
24+ }
25+
1626
1727def get_python_handle (file_name : str , function_name : str , path : Union [str , pathlib .Path ] = "" ) -> Callable :
1828 """Get the function handle from a function defined in a python module located anywhere within the filesystem.
@@ -94,7 +104,7 @@ def __len__(self):
94104 return len (self .files )
95105
96106
97- def make_input (project : RATapi .Project , controls : RATapi .Controls ) -> tuple [ProblemDefinition , Limits , Priors , Control ]:
107+ def make_input (project : RATapi .Project , controls : RATapi .Controls ) -> tuple [ProblemDefinition , Limits , Control ]:
98108 """Constructs the inputs required for the compiled RAT code using the data defined in the input project and
99109 controls.
100110
@@ -111,65 +121,32 @@ def make_input(project: RATapi.Project, controls: RATapi.Controls) -> tuple[Prob
111121 The problem input used in the compiled RAT code.
112122 limits : RAT.rat_core.Limits
113123 A list of min/max values for each parameter defined in the project.
114- priors : RAT.rat_core.Priors
115- The priors defined for each parameter in the project.
116124 cpp_controls : RAT.rat_core.Control
117125 The controls object used in the compiled RAT code.
118126
119127 """
120- parameter_field = {
121- "parameters" : "params" ,
122- "bulk_in" : "bulkIns" ,
123- "bulk_out" : "bulkOuts" ,
124- "scalefactors" : "scalefactors" ,
125- "domain_ratios" : "domainRatios" ,
126- "background_parameters" : "backgroundParams" ,
127- "resolution_parameters" : "resolutionParams" ,
128- }
129-
130- prior_id = {"uniform" : 1 , "gaussian" : 2 , "jeffreys" : 3 }
131-
132- checks = Checks ()
133128 limits = Limits ()
134- priors = Priors ()
135129
136130 for class_list in RATapi .project .parameter_class_lists :
137- setattr (checks , parameter_field [class_list ], [int (element .fit ) for element in getattr (project , class_list )])
138131 setattr (
139132 limits ,
140133 parameter_field [class_list ],
141134 [[element .min , element .max ] for element in getattr (project , class_list )],
142135 )
143- setattr (
144- priors ,
145- parameter_field [class_list ],
146- [[element .name , element .prior_type , element .mu , element .sigma ] for element in getattr (project , class_list )],
147- )
148136
149- # Use dummy values for qzshifts
150- checks .qzshifts = []
137+ # Use dummy value for qzshifts
151138 limits .qzshifts = []
152- priors .qzshifts = []
153-
154- priors .priorNames = [
155- param .name for class_list in RATapi .project .parameter_class_lists for param in getattr (project , class_list )
156- ]
157- priors .priorValues = [
158- [prior_id [param .prior_type ], param .mu , param .sigma ]
159- for class_list in RATapi .project .parameter_class_lists
160- for param in getattr (project , class_list )
161- ]
162139
163140 if project .model == LayerModels .CustomXY :
164141 controls .calcSldDuringFit = True
165142
166- problem = make_problem (project , checks )
143+ problem = make_problem (project )
167144 cpp_controls = make_controls (controls )
168145
169- return problem , limits , priors , cpp_controls
146+ return problem , limits , cpp_controls
170147
171148
172- def make_problem (project : RATapi .Project , checks : Checks ) -> ProblemDefinition :
149+ def make_problem (project : RATapi .Project ) -> ProblemDefinition :
173150 """Constructs the problem input required for the compiled RAT code.
174151
175152 Parameters
@@ -184,6 +161,7 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
184161
185162 """
186163 hydrate_id = {"bulk in" : 1 , "bulk out" : 2 }
164+ prior_id = {"uniform" : 1 , "gaussian" : 2 , "jeffreys" : 3 }
187165
188166 # Set contrast parameters according to model type
189167 if project .model == LayerModels .StandardLayers :
@@ -384,21 +362,32 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
384362 for param in getattr (project , class_list )
385363 if not param .fit
386364 ]
365+ problem .priorNames = [
366+ param .name for class_list in RATapi .project .parameter_class_lists for param in getattr (project , class_list )
367+ ]
368+ problem .priorValues = [
369+ [prior_id [param .prior_type ], param .mu , param .sigma ]
370+ for class_list in RATapi .project .parameter_class_lists
371+ for param in getattr (project , class_list )
372+ ]
387373
388374 # Names
389375 problem .names = NameStore ()
390- problem .names .params = [param .name for param in project .parameters ]
391- problem .names .backgroundParams = [param .name for param in project .background_parameters ]
392- problem .names .scalefactors = [param .name for param in project .scalefactors ]
393- problem .names .qzshifts = [] # Placeholder for qzshifts
394- problem .names .bulkIns = [param .name for param in project .bulk_in ]
395- problem .names .bulkOuts = [param .name for param in project .bulk_out ]
396- problem .names .resolutionParams = [param .name for param in project .resolution_parameters ]
397- problem .names .domainRatios = [param .name for param in project .domain_ratios ]
376+ for class_list in RATapi .project .parameter_class_lists :
377+ setattr (problem .names , parameter_field [class_list ], [param .name for param in getattr (project , class_list )])
398378 problem .names .contrasts = [contrast .name for contrast in project .contrasts ]
399379
400380 # Checks
401- problem .checks = checks
381+ problem .checks = Checks ()
382+ for class_list in RATapi .project .parameter_class_lists :
383+ setattr (
384+ problem .checks , parameter_field [class_list ], [int (element .fit ) for element in getattr (project , class_list )]
385+ )
386+
387+ # Use dummy values for qz shifts
388+ problem .names .qzshifts = []
389+ problem .checks .qzshifts = []
390+
402391 check_indices (problem )
403392
404393 return problem
0 commit comments