Skip to content

Commit 2c923e2

Browse files
committed
Adds names to "test_inputs.py"
1 parent f852ebc commit 2c923e2

File tree

1 file changed

+60
-21
lines changed

1 file changed

+60
-21
lines changed

tests/test_inputs.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import RATapi
1010
import RATapi.wrappers
1111
from RATapi.inputs import FileHandles, check_indices, make_controls, make_input, make_problem
12-
from RATapi.rat_core import Checks, Control, Limits, Priors, ProblemDefinition
12+
from RATapi.rat_core import Checks, Control, Limits, NameStore, Priors, ProblemDefinition
1313
from RATapi.utils.enums import (
1414
BackgroundActions,
1515
BoundHandling,
@@ -112,7 +112,40 @@ def custom_xy_project():
112112

113113

114114
@pytest.fixture
115-
def standard_layers_problem():
115+
def test_names():
116+
"""The expected NameStore object from "standard_layers_project", "domains_project" and "custom_xy_project"."""
117+
names = NameStore()
118+
names.params = ["Substrate Roughness", "Test Thickness", "Test SLD", "Test Roughness"]
119+
names.backgroundParams = ["Background Param 1"]
120+
names.scalefactors = ["Scalefactor 1"]
121+
names.qzshifts = []
122+
names.bulkIns = ["SLD Air"]
123+
names.bulkOuts = ["SLD D2O"]
124+
names.resolutionParams = ["Resolution Param 1"]
125+
names.domainRatios = []
126+
names.contrasts = ["Test Contrast"]
127+
128+
return names
129+
130+
131+
@pytest.fixture
132+
def test_checks():
133+
"""The expected checks object from "standard_layers_project", "domains_project" and "custom_xy_project"."""
134+
checks = Checks()
135+
checks.params = [1, 0, 0, 0]
136+
checks.backgroundParams = [0]
137+
checks.scalefactors = [0]
138+
checks.qzshifts = []
139+
checks.bulkIns = [0]
140+
checks.bulkOuts = [0]
141+
checks.resolutionParams = [0]
142+
checks.domainRatios = []
143+
144+
return checks
145+
146+
147+
@pytest.fixture
148+
def standard_layers_problem(test_names, test_checks):
116149
"""The expected problem object from "standard_layers_project"."""
117150
problem = ProblemDefinition()
118151
problem.TF = Calculations.Normal
@@ -165,12 +198,14 @@ def standard_layers_problem():
165198
[0.01, 0.05],
166199
]
167200
problem.customFiles = FileHandles([])
201+
problem.names = test_names
202+
problem.checks = test_checks
168203

169204
return problem
170205

171206

172207
@pytest.fixture
173-
def domains_problem():
208+
def domains_problem(test_names, test_checks):
174209
"""The expected problem object from "domains_project"."""
175210
problem = ProblemDefinition()
176211
problem.TF = Calculations.Domains
@@ -224,12 +259,15 @@ def domains_problem():
224259
[0.4, 0.6],
225260
]
226261
problem.customFiles = FileHandles([])
262+
problem.names = test_names
263+
problem.names.domainRatios = ["Domain Ratio 1"]
264+
problem.checks = test_checks
227265

228266
return problem
229267

230268

231269
@pytest.fixture
232-
def custom_xy_problem():
270+
def custom_xy_problem(test_names, test_checks):
233271
"""The expected problem object from "custom_xy_project"."""
234272
problem = ProblemDefinition()
235273
problem.TF = Calculations.Normal
@@ -284,6 +322,8 @@ def custom_xy_problem():
284322
problem.customFiles = FileHandles(
285323
[RATapi.models.CustomFile(name="Test Custom File", filename="cpp_test.dll", language="cpp")]
286324
)
325+
problem.names = test_names
326+
problem.checks = test_checks
287327

288328
return problem
289329

@@ -484,22 +524,6 @@ def custom_xy_controls():
484524
return controls
485525

486526

487-
@pytest.fixture
488-
def test_checks():
489-
"""The expected checks object from "standard_layers_project", "domains_project" and "custom_xy_project"."""
490-
checks = Checks()
491-
checks.params = [1, 0, 0, 0]
492-
checks.backgroundParams = [0]
493-
checks.scalefactors = [0]
494-
checks.qzshifts = []
495-
checks.bulkIns = [0]
496-
checks.bulkOuts = [0]
497-
checks.resolutionParams = [0]
498-
checks.domainRatios = []
499-
500-
return checks
501-
502-
503527
@pytest.mark.parametrize(
504528
["test_project", "test_problem", "test_limits", "test_priors", "test_controls"],
505529
[
@@ -548,7 +572,7 @@ def test_make_input(test_project, test_problem, test_limits, test_priors, test_c
548572
]
549573

550574
problem, limits, priors, controls = make_input(test_project, RATapi.Controls())
551-
problem = pickle.loads(pickle.dumps(problem))
575+
# problem = pickle.loads(pickle.dumps(problem))
552576
check_problem_equal(problem, test_problem)
553577

554578
limits = pickle.loads(pickle.dumps(limits))
@@ -757,11 +781,26 @@ def check_problem_equal(actual_problem, expected_problem) -> None:
757781
"fitLimits",
758782
"otherLimits",
759783
]
784+
checks_fields = [
785+
"params",
786+
"backgroundParams",
787+
"scalefactors",
788+
"qzshifts",
789+
"bulkIns",
790+
"bulkOuts",
791+
"resolutionParams",
792+
"domainRatios",
793+
]
794+
names_fields = [*checks_fields, "contrasts"]
760795

761796
for scalar_field in scalar_fields:
762797
assert getattr(actual_problem, scalar_field) == getattr(expected_problem, scalar_field)
763798
for array_field in array_fields:
764799
assert np.all(getattr(actual_problem, array_field) == getattr(expected_problem, array_field))
800+
for field in names_fields:
801+
assert getattr(actual_problem.names, field) == getattr(expected_problem.names, field)
802+
for field in checks_fields:
803+
assert (getattr(actual_problem.checks, field) == getattr(expected_problem.checks, field)).all()
765804

766805
# Data field is a numpy array
767806
assert [

0 commit comments

Comments
 (0)