88
99import RATapi
1010import RATapi .wrappers
11- from RATapi .inputs import check_indices , make_controls , make_input , make_problem
11+ from RATapi .inputs import FileHandles , check_indices , make_controls , make_input , make_problem
1212from RATapi .rat_core import Checks , Control , Limits , Priors , ProblemDefinition
1313from RATapi .utils .enums import (
1414 BackgroundActions ,
@@ -133,6 +133,7 @@ def standard_layers_problem():
133133 problem .contrastScalefactors = [1 ]
134134 problem .contrastBackgroundParams = [[1 ]]
135135 problem .contrastBackgroundActions = [BackgroundActions .Add ]
136+ problem .contrastBackgroundTypes = ["constant" ]
136137 problem .contrastResolutionParams = [1 ]
137138 problem .contrastCustomFiles = [float ("NaN" )]
138139 problem .contrastDomainRatios = [0 ]
@@ -155,6 +156,7 @@ def standard_layers_problem():
155156 [6.2e-06 , 6.35e-06 ],
156157 [0.01 , 0.05 ],
157158 ]
159+ problem .customFiles = FileHandles ([])
158160
159161 return problem
160162
@@ -181,6 +183,7 @@ def domains_problem():
181183 problem .contrastScalefactors = [1 ]
182184 problem .contrastBackgroundParams = [[1 ]]
183185 problem .contrastBackgroundActions = [BackgroundActions .Add ]
186+ problem .contrastBackgroundTypes = ["constant" ]
184187 problem .contrastResolutionParams = [1 ]
185188 problem .contrastCustomFiles = [float ("NaN" )]
186189 problem .contrastDomainRatios = [1 ]
@@ -204,6 +207,7 @@ def domains_problem():
204207 [0.01 , 0.05 ],
205208 [0.4 , 0.6 ],
206209 ]
210+ problem .customFiles = FileHandles ([])
207211
208212 return problem
209213
@@ -230,6 +234,7 @@ def custom_xy_problem():
230234 problem .contrastScalefactors = [1 ]
231235 problem .contrastBackgroundParams = [[1 ]]
232236 problem .contrastBackgroundActions = [BackgroundActions .Add ]
237+ problem .contrastBackgroundTypes = ["constant" ]
233238 problem .contrastResolutionParams = [1 ]
234239 problem .contrastCustomFiles = [1 ]
235240 problem .contrastDomainRatios = [0 ]
@@ -252,6 +257,9 @@ def custom_xy_problem():
252257 [6.2e-06 , 6.35e-06 ],
253258 [0.01 , 0.05 ],
254259 ]
260+ problem .customFiles = FileHandles (
261+ [RATapi .models .CustomFile (name = "Test Custom File" , filename = "cpp_test.dll" , language = "cpp" )]
262+ )
255263
256264 return problem
257265
@@ -552,81 +560,98 @@ def test_make_problem(test_project, test_problem, test_check, request) -> None:
552560 check_problem_equal (problem , test_problem )
553561
554562
555- @pytest .mark .parametrize (
556- "test_problem" ,
557- [
558- "standard_layers_problem" ,
559- "custom_xy_problem" ,
560- "domains_problem" ,
561- ],
562- )
563- def test_check_indices (test_problem , request ) -> None :
564- """The check_indices routine should not raise errors for a properly defined ProblemDefinition object."""
565- test_problem = request .getfixturevalue (test_problem )
566-
567- check_indices (test_problem )
568-
569-
570- @pytest .mark .parametrize (
571- ["test_problem" , "index_list" , "bad_value" ],
572- [
573- ("standard_layers_problem" , "contrastBulkIns" , [0.0 ]),
574- ("standard_layers_problem" , "contrastBulkIns" , [2.0 ]),
575- ("standard_layers_problem" , "contrastBulkOuts" , [0.0 ]),
576- ("standard_layers_problem" , "contrastBulkOuts" , [2.0 ]),
577- ("standard_layers_problem" , "contrastScalefactors" , [0.0 ]),
578- ("standard_layers_problem" , "contrastScalefactors" , [2.0 ]),
579- # ("standard_layers_problem", "contrastBackgroundParams", [0.0]),
580- # ("standard_layers_problem", "contrastBackgroundParams", [2.0]),
581- ("standard_layers_problem" , "contrastResolutionParams" , [0.0 ]),
582- ("standard_layers_problem" , "contrastResolutionParams" , [2.0 ]),
583- ("custom_xy_problem" , "contrastBulkIns" , [0.0 ]),
584- ("custom_xy_problem" , "contrastBulkIns" , [2.0 ]),
585- ("custom_xy_problem" , "contrastBulkOuts" , [0.0 ]),
586- ("custom_xy_problem" , "contrastBulkOuts" , [2.0 ]),
587- ("custom_xy_problem" , "contrastScalefactors" , [0.0 ]),
588- ("custom_xy_problem" , "contrastScalefactors" , [2.0 ]),
589- # ("custom_xy_problem", "contrastBackgroundParams", [0.0]),
590- # ("custom_xy_problem", "contrastBackgroundParams", [2.0]),
591- ("custom_xy_problem" , "contrastResolutionParams" , [0.0 ]),
592- ("custom_xy_problem" , "contrastResolutionParams" , [2.0 ]),
593- ("domains_problem" , "contrastBulkIns" , [0.0 ]),
594- ("domains_problem" , "contrastBulkIns" , [2.0 ]),
595- ("domains_problem" , "contrastBulkOuts" , [0.0 ]),
596- ("domains_problem" , "contrastBulkOuts" , [2.0 ]),
597- ("domains_problem" , "contrastScalefactors" , [0.0 ]),
598- ("domains_problem" , "contrastScalefactors" , [2.0 ]),
599- ("domains_problem" , "contrastDomainRatios" , [0.0 ]),
600- ("domains_problem" , "contrastDomainRatios" , [2.0 ]),
601- # ("domains_problem", "contrastBackgroundParams", [0.0]),
602- # ("domains_problem", "contrastBackgroundParams", [2.0]),
603- ("domains_problem" , "contrastResolutionParams" , [0.0 ]),
604- ("domains_problem" , "contrastResolutionParams" , [2.0 ]),
605- ],
606- )
607- def test_check_indices_error (test_problem , index_list , bad_value , request ) -> None :
608- """The check_indices routine should raise an IndexError if a contrast list contains an index that is out of the
609- range of the corresponding parameter list in a ProblemDefinition object.
610- """
611- param_list = {
612- "contrastBulkIns" : "bulkIns" ,
613- "contrastBulkOuts" : "bulkOuts" ,
614- "contrastScalefactors" : "scalefactors" ,
615- "contrastDomainRatios" : "domainRatios" ,
616- "contrastBackgroundParams" : "backgroundParams" ,
617- "contrastResolutionParams" : "resolutionParams" ,
618- }
563+ @pytest .mark .parametrize ("test_problem" , ["standard_layers_problem" , "custom_xy_problem" , "domains_problem" ])
564+ class TestCheckIndices :
565+ """Tests for check_indices over a set of three test problems."""
619566
620- test_problem = request .getfixturevalue (test_problem )
621- setattr (test_problem , index_list , bad_value )
567+ def test_check_indices (self , test_problem , request ) -> None :
568+ """The check_indices routine should not raise errors for a properly defined ProblemDefinition object."""
569+ test_problem = request .getfixturevalue (test_problem )
622570
623- with pytest .raises (
624- IndexError ,
625- match = f'The problem field "{ index_list } " contains: { bad_value [0 ]} , which lie '
626- f'outside of the range of "{ param_list [index_list ]} "' ,
627- ):
628571 check_indices (test_problem )
629572
573+ @pytest .mark .parametrize (
574+ "index_list" ,
575+ [
576+ "contrastBulkIns" ,
577+ "contrastBulkOuts" ,
578+ "contrastScalefactors" ,
579+ "contrastDomainRatios" ,
580+ "contrastResolutionParams" ,
581+ ],
582+ )
583+ @pytest .mark .parametrize ("bad_value" , ([0.0 ], [2.0 ]))
584+ def test_check_indices_error (self , test_problem , index_list , bad_value , request ) -> None :
585+ """The check_indices routine should raise an IndexError if a contrast list contains an index that is out of the
586+ range of the corresponding parameter list in a ProblemDefinition object.
587+ """
588+ param_list = {
589+ "contrastBulkIns" : "bulkIns" ,
590+ "contrastBulkOuts" : "bulkOuts" ,
591+ "contrastScalefactors" : "scalefactors" ,
592+ "contrastDomainRatios" : "domainRatios" ,
593+ "contrastResolutionParams" : "resolutionParams" ,
594+ }
595+ if (test_problem != "domains_problem" ) and (index_list == "contrastDomainRatios" ):
596+ # we expect this to not raise an error for non-domains problems as domainRatios is empty
597+ pytest .xfail ()
598+
599+ test_problem = request .getfixturevalue (test_problem )
600+ setattr (test_problem , index_list , bad_value )
601+
602+ with pytest .raises (
603+ IndexError ,
604+ match = f'The problem field "{ index_list } " contains: { bad_value [0 ]} , which lies '
605+ f'outside of the range of "{ param_list [index_list ]} "' ,
606+ ):
607+ check_indices (test_problem )
608+
609+ @pytest .mark .parametrize ("background_type" , ["constant" , "data" , "function" ])
610+ @pytest .mark .parametrize ("bad_value" , ([[0.0 ]], [[2.0 ]]))
611+ def test_background_params_source_indices (self , test_problem , background_type , bad_value , request ):
612+ """check_indices should raise an IndexError for bad sources in the nested list contrastBackgroundParams."""
613+ test_problem = request .getfixturevalue (test_problem )
614+ test_problem .contrastBackgroundParams = bad_value
615+ test_problem .contrastBackgroundTypes = [background_type ]
616+
617+ source_param_lists = {
618+ "constant" : "backgroundParams" ,
619+ "data" : "data" ,
620+ "function" : "customFiles" ,
621+ }
622+
623+ with pytest .raises (
624+ IndexError ,
625+ match = f'Entry 0 of contrastBackgroundParams has type "{ background_type } " '
626+ f"and source index { bad_value [0 ][0 ]} , "
627+ f'which is outside the range of "{ source_param_lists [background_type ]} ".' ,
628+ ):
629+ check_indices (test_problem )
630+
631+ @pytest .mark .parametrize (
632+ "bad_value" ,
633+ (
634+ [[1.0 , 0.0 ]],
635+ [[1.0 , 2.0 ]],
636+ [[1.0 , 1.0 , 2.0 ]],
637+ [[1.0 ], [1.0 , 0.0 ]],
638+ ),
639+ )
640+ def test_background_params_value_indices (self , test_problem , bad_value , request ):
641+ """check_indices should raise an IndexError for bad values in the nested list contrastBackgroundParams."""
642+ test_problem = request .getfixturevalue (test_problem )
643+ test_problem .contrastBackgroundParams = bad_value
644+
645+ if len (bad_value ) > 1 :
646+ test_problem .contrastBackgroundTypes .append ("constant" )
647+
648+ with pytest .raises (
649+ IndexError ,
650+ match = f"Entry { len (bad_value )- 1 } of contrastBackgroundParams contains: { bad_value [- 1 ][- 1 ]} "
651+ f', which lies outside of the range of "backgroundParams"' ,
652+ ):
653+ check_indices (test_problem )
654+
630655
631656def test_get_python_handle ():
632657 path = pathlib .Path (__file__ ).parent .resolve ()
0 commit comments