44import os
55import pathlib
66from typing import Callable , Union
7+
78import numpy as np
9+
810import RATapi
911import RATapi .controls
1012import RATapi .wrappers
11- from RATapi .rat_core import NameStore , Checks , Control , Limits , Priors , ProblemDefinition
13+ from RATapi .rat_core import Checks , Control , Limits , NameStore , Priors , ProblemDefinition
1214from RATapi .utils .enums import Calculations , Languages , LayerModels , TypeOptions
1315
1416
@@ -49,7 +51,7 @@ class FileHandles:
4951 def __init__ (self , files = None ):
5052 self .index = 0
5153 self .files = [] if files is None else [file .dict () for file in files ]
52-
54+
5355 def __iter__ (self ):
5456 self .index = 0
5557 return self
@@ -89,9 +91,7 @@ def __next__(self):
8991 raise StopIteration
9092
9193
92- def make_input (
93- project : RATapi .Project , controls : RATapi .Controls
94- ) -> tuple [ProblemDefinition , Limits , Priors , Control ]:
94+ def make_input (project : RATapi .Project , controls : RATapi .Controls ) -> tuple [ProblemDefinition , Limits , Priors , Control ]:
9595 """Constructs the inputs required for the compiled RAT code using the data defined in the input project and
9696 controls.
9797
@@ -192,8 +192,6 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
192192
193193 """
194194 hydrate_id = {"bulk in" : 1 , "bulk out" : 2 }
195- action_id = {"add" : 1 , "subtract" : 2 }
196-
197195
198196 # Set contrast parameters according to model type
199197 if project .model == LayerModels .StandardLayers :
@@ -208,7 +206,7 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
208206 ]
209207 else :
210208 contrast_models = [[]] * len (project .contrasts )
211-
209+
212210 # Set contrast parameters according to model type
213211 if project .model == LayerModels .StandardLayers :
214212 contrast_custom_files = [float ("NaN" )] * len (project .contrasts )
@@ -230,7 +228,7 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
230228 layer_params .append (hydrate_id [layer .hydrate_with ])
231229
232230 layer_details .append (layer_params )
233-
231+
234232 for contrast in project .contrasts :
235233 background = project .backgrounds [contrast .background ]
236234 contrast_background_types .append (background .type )
@@ -251,10 +249,10 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
251249 contrast_resolution_params .append (- 1 )
252250 else :
253251 contrast_resolution_params .append (project .resolution_parameters .index (resolution .value_1 , True ))
254-
252+
255253 data_index = project .data .index (contrast .data )
256254 data = project .data [data_index ].data
257- all_data .append (np .column_stack ((data , np .zeros ((data .shape [0 ], 6 - data .shape [1 ])))))
255+ all_data .append (np .column_stack ((data , np .zeros ((data .shape [0 ], 6 - data .shape [1 ])))))
258256 data_range = project .data [data_index ].data_range
259257 simulation_range = project .data [data_index ].simulation_range
260258
@@ -267,7 +265,7 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
267265 simulation_limits .append (simulation_range )
268266 else :
269267 simulation_limits .append ([0.0 , 0.0 ])
270-
268+
271269 problem = ProblemDefinition ()
272270
273271 problem .TF = project .calculation
@@ -283,7 +281,7 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
283281 problem .repeatLayers = [[0 , 1 ]] * len (project .contrasts ) # This is marked as "to do" in RAT
284282 problem .contrastBackgroundParams = contrast_background_params
285283 problem .contrastBackgroundTypes = contrast_background_types
286- problem .contrastBackgroundActions = [contrast .background_action for contrast in project .contrasts ]
284+ problem .contrastBackgroundActions = [contrast .background_action for contrast in project .contrasts ]
287285 problem .contrastQzshifts = [1 ] * len (project .contrasts ) # This is marked as "to do" in RAT
288286 problem .contrastScalefactors = [
289287 project .scalefactors .index (contrast .scalefactor , True ) for contrast in project .contrasts
@@ -304,15 +302,15 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
304302 problem .customFiles = FileHandles (project .custom_files )
305303 problem .modelType = project .model
306304 problem .contrastCustomFiles = contrast_custom_files
307-
305+
308306 problem .contrastDomainRatios = [
309307 project .domain_ratios .index (contrast .domain_ratio , True ) if hasattr (contrast , "domain_ratio" ) else 0
310308 for contrast in project .contrasts
311309 ]
312310
313311 problem .domainRatios = [param .value for param in project .domain_ratios ]
314312 problem .numberOfDomainContrasts = len (project .domain_contrasts )
315-
313+
316314 domain_contrast_models = [
317315 [project .layers .index (layer , True ) for layer in domain_contrast .model ]
318316 for domain_contrast in project .domain_contrasts
@@ -345,7 +343,7 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
345343 for param in getattr (project , class_list )
346344 if not param .fit
347345 ]
348-
346+
349347 # Names
350348 problem .names = NameStore ()
351349 problem .names .params = [param .name for param in project .parameters ]
@@ -357,7 +355,7 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
357355 problem .names .resolutionParams = [param .name for param in project .resolution_parameters ]
358356 problem .names .domainRatios = [param .name for param in project .domain_ratios ]
359357 problem .names .contrasts = [contrast .name for contrast in project .contrasts ]
360-
358+
361359 # Checks
362360 problem .checks = checks
363361 check_indices (problem )
0 commit comments