|
9 | 9 | import RATapi |
10 | 10 | import RATapi.wrappers |
11 | 11 | from RATapi.inputs import FileHandles, check_indices, make_controls, make_input, make_problem |
12 | | -from RATapi.rat_core import Checks, Control, Limits, Priors, ProblemDefinition |
| 12 | +from RATapi.rat_core import Checks, Control, Limits, NameStore, Priors, ProblemDefinition |
13 | 13 | from RATapi.utils.enums import ( |
14 | 14 | BackgroundActions, |
15 | 15 | BoundHandling, |
@@ -112,7 +112,40 @@ def custom_xy_project(): |
112 | 112 |
|
113 | 113 |
|
114 | 114 | @pytest.fixture |
115 | | -def standard_layers_problem(): |
| 115 | +def test_names(): |
| 116 | + """The expected NameStore object from "standard_layers_project", "domains_project" and "custom_xy_project".""" |
| 117 | + names = NameStore() |
| 118 | + names.params = ["Substrate Roughness", "Test Thickness", "Test SLD", "Test Roughness"] |
| 119 | + names.backgroundParams = ["Background Param 1"] |
| 120 | + names.scalefactors = ["Scalefactor 1"] |
| 121 | + names.qzshifts = [] |
| 122 | + names.bulkIns = ["SLD Air"] |
| 123 | + names.bulkOuts = ["SLD D2O"] |
| 124 | + names.resolutionParams = ["Resolution Param 1"] |
| 125 | + names.domainRatios = [] |
| 126 | + names.contrasts = ["Test Contrast"] |
| 127 | + |
| 128 | + return names |
| 129 | + |
| 130 | + |
| 131 | +@pytest.fixture |
| 132 | +def test_checks(): |
| 133 | + """The expected checks object from "standard_layers_project", "domains_project" and "custom_xy_project".""" |
| 134 | + checks = Checks() |
| 135 | + checks.params = [1, 0, 0, 0] |
| 136 | + checks.backgroundParams = [0] |
| 137 | + checks.scalefactors = [0] |
| 138 | + checks.qzshifts = [] |
| 139 | + checks.bulkIns = [0] |
| 140 | + checks.bulkOuts = [0] |
| 141 | + checks.resolutionParams = [0] |
| 142 | + checks.domainRatios = [] |
| 143 | + |
| 144 | + return checks |
| 145 | + |
| 146 | + |
| 147 | +@pytest.fixture |
| 148 | +def standard_layers_problem(test_names, test_checks): |
116 | 149 | """The expected problem object from "standard_layers_project".""" |
117 | 150 | problem = ProblemDefinition() |
118 | 151 | problem.TF = Calculations.Normal |
@@ -165,12 +198,14 @@ def standard_layers_problem(): |
165 | 198 | [0.01, 0.05], |
166 | 199 | ] |
167 | 200 | problem.customFiles = FileHandles([]) |
| 201 | + problem.names = test_names |
| 202 | + problem.checks = test_checks |
168 | 203 |
|
169 | 204 | return problem |
170 | 205 |
|
171 | 206 |
|
172 | 207 | @pytest.fixture |
173 | | -def domains_problem(): |
| 208 | +def domains_problem(test_names, test_checks): |
174 | 209 | """The expected problem object from "domains_project".""" |
175 | 210 | problem = ProblemDefinition() |
176 | 211 | problem.TF = Calculations.Domains |
@@ -224,12 +259,15 @@ def domains_problem(): |
224 | 259 | [0.4, 0.6], |
225 | 260 | ] |
226 | 261 | problem.customFiles = FileHandles([]) |
| 262 | + problem.names = test_names |
| 263 | + problem.names.domainRatios = ["Domain Ratio 1"] |
| 264 | + problem.checks = test_checks |
227 | 265 |
|
228 | 266 | return problem |
229 | 267 |
|
230 | 268 |
|
231 | 269 | @pytest.fixture |
232 | | -def custom_xy_problem(): |
| 270 | +def custom_xy_problem(test_names, test_checks): |
233 | 271 | """The expected problem object from "custom_xy_project".""" |
234 | 272 | problem = ProblemDefinition() |
235 | 273 | problem.TF = Calculations.Normal |
@@ -284,6 +322,8 @@ def custom_xy_problem(): |
284 | 322 | problem.customFiles = FileHandles( |
285 | 323 | [RATapi.models.CustomFile(name="Test Custom File", filename="cpp_test.dll", language="cpp")] |
286 | 324 | ) |
| 325 | + problem.names = test_names |
| 326 | + problem.checks = test_checks |
287 | 327 |
|
288 | 328 | return problem |
289 | 329 |
|
@@ -484,22 +524,6 @@ def custom_xy_controls(): |
484 | 524 | return controls |
485 | 525 |
|
486 | 526 |
|
487 | | -@pytest.fixture |
488 | | -def test_checks(): |
489 | | - """The expected checks object from "standard_layers_project", "domains_project" and "custom_xy_project".""" |
490 | | - checks = Checks() |
491 | | - checks.params = [1, 0, 0, 0] |
492 | | - checks.backgroundParams = [0] |
493 | | - checks.scalefactors = [0] |
494 | | - checks.qzshifts = [] |
495 | | - checks.bulkIns = [0] |
496 | | - checks.bulkOuts = [0] |
497 | | - checks.resolutionParams = [0] |
498 | | - checks.domainRatios = [] |
499 | | - |
500 | | - return checks |
501 | | - |
502 | | - |
503 | 527 | @pytest.mark.parametrize( |
504 | 528 | ["test_project", "test_problem", "test_limits", "test_priors", "test_controls"], |
505 | 529 | [ |
@@ -548,7 +572,7 @@ def test_make_input(test_project, test_problem, test_limits, test_priors, test_c |
548 | 572 | ] |
549 | 573 |
|
550 | 574 | problem, limits, priors, controls = make_input(test_project, RATapi.Controls()) |
551 | | - problem = pickle.loads(pickle.dumps(problem)) |
| 575 | + # problem = pickle.loads(pickle.dumps(problem)) |
552 | 576 | check_problem_equal(problem, test_problem) |
553 | 577 |
|
554 | 578 | limits = pickle.loads(pickle.dumps(limits)) |
@@ -757,11 +781,26 @@ def check_problem_equal(actual_problem, expected_problem) -> None: |
757 | 781 | "fitLimits", |
758 | 782 | "otherLimits", |
759 | 783 | ] |
| 784 | + checks_fields = [ |
| 785 | + "params", |
| 786 | + "backgroundParams", |
| 787 | + "scalefactors", |
| 788 | + "qzshifts", |
| 789 | + "bulkIns", |
| 790 | + "bulkOuts", |
| 791 | + "resolutionParams", |
| 792 | + "domainRatios", |
| 793 | + ] |
| 794 | + names_fields = [*checks_fields, "contrasts"] |
760 | 795 |
|
761 | 796 | for scalar_field in scalar_fields: |
762 | 797 | assert getattr(actual_problem, scalar_field) == getattr(expected_problem, scalar_field) |
763 | 798 | for array_field in array_fields: |
764 | 799 | assert np.all(getattr(actual_problem, array_field) == getattr(expected_problem, array_field)) |
| 800 | + for field in names_fields: |
| 801 | + assert getattr(actual_problem.names, field) == getattr(expected_problem.names, field) |
| 802 | + for field in checks_fields: |
| 803 | + assert (getattr(actual_problem.checks, field) == getattr(expected_problem.checks, field)).all() |
765 | 804 |
|
766 | 805 | # Data field is a numpy array |
767 | 806 | assert [ |
|
0 commit comments