Skip to content

Commit 5d9529f

Browse files
committed
updated check_indices to be nested
1 parent f609b7e commit 5d9529f

File tree

2 files changed

+163
-85
lines changed

2 files changed

+163
-85
lines changed

RATapi/inputs.py

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class FileHandles:
5050

5151
def __init__(self, files=None):
5252
self.index = 0
53-
self.files = [] if files is None else [file.dict() for file in files]
53+
self.files = [] if files is None else [file.model_dump() for file in files]
5454

5555
def __iter__(self):
5656
self.index = 0
@@ -90,6 +90,9 @@ def __next__(self):
9090
else:
9191
raise StopIteration
9292

93+
def __len__(self):
94+
return len(self.files)
95+
9396

9497
def make_input(project: RATapi.Project, controls: RATapi.Controls) -> tuple[ProblemDefinition, Limits, Priors, Control]:
9598
"""Constructs the inputs required for the compiled RAT code using the data defined in the input project and
@@ -106,8 +109,6 @@ def make_input(project: RATapi.Project, controls: RATapi.Controls) -> tuple[Prob
106109
-------
107110
problem : RAT.rat_core.ProblemDefinition
108111
The problem input used in the compiled RAT code.
109-
cells : RAT.rat_core.Cells
110-
The set of inputs that are defined in MATLAB as cell arrays.
111112
limits : RAT.rat_core.Limits
112113
A list of min/max values for each parameter defined in the project.
113114
priors : RAT.rat_core.Priors
@@ -232,10 +233,31 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
232233
for contrast in project.contrasts:
233234
background = project.backgrounds[contrast.background]
234235
contrast_background_types.append(background.type)
236+
contrast_background_param = []
235237
if background.type == TypeOptions.Data:
236-
contrast_background_params.append([-1])
238+
contrast_background_param.append(project.data.index(background.source, True))
239+
if background.value_1 != "":
240+
contrast_background_param.append(project.background_parameters.index(background.value_1))
241+
elif background.type == TypeOptions.Function:
242+
contrast_background_param.append(project.custom_files.index(background.source, True))
243+
contrast_background_param.extend(
244+
[
245+
project.background_parameters.index(value, True)
246+
for value in [
247+
background.value_1,
248+
background.value_2,
249+
background.value_3,
250+
background.value_4,
251+
background.value_5,
252+
]
253+
if value != ""
254+
]
255+
)
256+
237257
else:
238-
contrast_background_params.append([project.background_parameters.index(background.source, True)])
258+
contrast_background_param.append(project.background_parameters.index(background.source, True))
259+
260+
contrast_background_params.append(contrast_background_param)
239261

240262
# Set resolution parameters, with -1 used to indicate a data resolution
241263
all_data = []
@@ -412,27 +434,58 @@ def check_indices(problem: ProblemDefinition) -> None:
412434
"bulkOuts": "contrastBulkOuts",
413435
"scalefactors": "contrastScalefactors",
414436
"domainRatios": "contrastDomainRatios",
415-
# "backgroundParams": "contrastBackgroundParams",
416437
"resolutionParams": "contrastResolutionParams",
417438
}
418439

419440
# Check the indices -- note we have switched to 1-based indexing at this point
420441
for params in index_list:
421442
param_list = getattr(problem, params)
422-
if len(param_list) > 0 and not all(
423-
(element > 0 or element == -1) and element <= len(param_list)
424-
for element in getattr(problem, index_list[params])
425-
):
443+
if len(param_list) > 0:
426444
elements = [
427445
element
428446
for element in getattr(problem, index_list[params])
429-
if not ((element > 0 or element == -1) and element <= len(param_list))
447+
if (element != -1) and not (0 < element <= len(param_list))
430448
]
449+
if elements:
450+
raise IndexError(
451+
f'The problem field "{index_list[params]}" contains: {", ".join(str(i) for i in elements)}'
452+
f', which lie{"s"*(len(elements)==1)} outside of the range of "{params}"',
453+
)
454+
455+
# backgroundParams has a different structure, so is handled separately:
456+
# it is of type list[list[int]], where each list[int] is the indices for
457+
# source, value_1, value_2, value_3, value_4, value_5 where they are defined
458+
# e.g. for a data background with offset it is [source value_1], for a function
459+
# with 3 values it is [source value_1 value_2 value_3], etc.
460+
461+
source_param_lists = {
462+
"constant": "backgroundParams",
463+
"data": "data",
464+
"function": "customFiles",
465+
}
466+
467+
for i, background_data in enumerate(problem.contrastBackgroundParams):
468+
background_type = problem.contrastBackgroundTypes[i]
469+
470+
# check source param is in range of the relevant parameter list
471+
param_list = getattr(problem, source_param_lists[background_type])
472+
source_index = background_data[0]
473+
if not 0 < source_index <= len(param_list):
431474
raise IndexError(
432-
f'The problem field "{index_list[params]}" contains: {", ".join(str(i) for i in elements)}'
433-
f', which lie outside of the range of "{params}"',
475+
f'Entry {i} of contrastBackgroundParams has type "{background_type}" '
476+
f"and source index {source_index}, "
477+
f'which is outside the range of "{source_param_lists[background_type]}".'
434478
)
435479

480+
# check value params are in range for background params
481+
if len(background_data) > 1:
482+
elements = [element for element in background_data[1:] if not 0 < element <= len(problem.backgroundParams)]
483+
if elements:
484+
raise IndexError(
485+
f'Entry {i} of contrastBackgroundParams contains: {", ".join(str(i) for i in elements)}'
486+
f', which lie{"s"*(len(elements)==1)} outside of the range of "backgroundParams"',
487+
)
488+
436489

437490
def make_controls(input_controls: RATapi.Controls) -> Control:
438491
"""Converts the controls object to the format required by the compiled RAT code.

tests/test_inputs.py

Lines changed: 97 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import RATapi
1010
import RATapi.wrappers
11-
from RATapi.inputs import check_indices, make_controls, make_input, make_problem
11+
from RATapi.inputs import FileHandles, check_indices, make_controls, make_input, make_problem
1212
from RATapi.rat_core import Checks, Control, Limits, Priors, ProblemDefinition
1313
from RATapi.utils.enums import (
1414
BackgroundActions,
@@ -133,6 +133,7 @@ def standard_layers_problem():
133133
problem.contrastScalefactors = [1]
134134
problem.contrastBackgroundParams = [[1]]
135135
problem.contrastBackgroundActions = [BackgroundActions.Add]
136+
problem.contrastBackgroundTypes = ["constant"]
136137
problem.contrastResolutionParams = [1]
137138
problem.contrastCustomFiles = [float("NaN")]
138139
problem.contrastDomainRatios = [0]
@@ -155,6 +156,7 @@ def standard_layers_problem():
155156
[6.2e-06, 6.35e-06],
156157
[0.01, 0.05],
157158
]
159+
problem.customFiles = FileHandles([])
158160

159161
return problem
160162

@@ -181,6 +183,7 @@ def domains_problem():
181183
problem.contrastScalefactors = [1]
182184
problem.contrastBackgroundParams = [[1]]
183185
problem.contrastBackgroundActions = [BackgroundActions.Add]
186+
problem.contrastBackgroundTypes = ["constant"]
184187
problem.contrastResolutionParams = [1]
185188
problem.contrastCustomFiles = [float("NaN")]
186189
problem.contrastDomainRatios = [1]
@@ -204,6 +207,7 @@ def domains_problem():
204207
[0.01, 0.05],
205208
[0.4, 0.6],
206209
]
210+
problem.customFiles = FileHandles([])
207211

208212
return problem
209213

@@ -230,6 +234,7 @@ def custom_xy_problem():
230234
problem.contrastScalefactors = [1]
231235
problem.contrastBackgroundParams = [[1]]
232236
problem.contrastBackgroundActions = [BackgroundActions.Add]
237+
problem.contrastBackgroundTypes = ["constant"]
233238
problem.contrastResolutionParams = [1]
234239
problem.contrastCustomFiles = [1]
235240
problem.contrastDomainRatios = [0]
@@ -252,6 +257,9 @@ def custom_xy_problem():
252257
[6.2e-06, 6.35e-06],
253258
[0.01, 0.05],
254259
]
260+
problem.customFiles = FileHandles(
261+
[RATapi.models.CustomFile(name="Test Custom File", filename="cpp_test.dll", language="cpp")]
262+
)
255263

256264
return problem
257265

@@ -552,81 +560,98 @@ def test_make_problem(test_project, test_problem, test_check, request) -> None:
552560
check_problem_equal(problem, test_problem)
553561

554562

555-
@pytest.mark.parametrize(
556-
"test_problem",
557-
[
558-
"standard_layers_problem",
559-
"custom_xy_problem",
560-
"domains_problem",
561-
],
562-
)
563-
def test_check_indices(test_problem, request) -> None:
564-
"""The check_indices routine should not raise errors for a properly defined ProblemDefinition object."""
565-
test_problem = request.getfixturevalue(test_problem)
566-
567-
check_indices(test_problem)
568-
569-
570-
@pytest.mark.parametrize(
571-
["test_problem", "index_list", "bad_value"],
572-
[
573-
("standard_layers_problem", "contrastBulkIns", [0.0]),
574-
("standard_layers_problem", "contrastBulkIns", [2.0]),
575-
("standard_layers_problem", "contrastBulkOuts", [0.0]),
576-
("standard_layers_problem", "contrastBulkOuts", [2.0]),
577-
("standard_layers_problem", "contrastScalefactors", [0.0]),
578-
("standard_layers_problem", "contrastScalefactors", [2.0]),
579-
# ("standard_layers_problem", "contrastBackgroundParams", [0.0]),
580-
# ("standard_layers_problem", "contrastBackgroundParams", [2.0]),
581-
("standard_layers_problem", "contrastResolutionParams", [0.0]),
582-
("standard_layers_problem", "contrastResolutionParams", [2.0]),
583-
("custom_xy_problem", "contrastBulkIns", [0.0]),
584-
("custom_xy_problem", "contrastBulkIns", [2.0]),
585-
("custom_xy_problem", "contrastBulkOuts", [0.0]),
586-
("custom_xy_problem", "contrastBulkOuts", [2.0]),
587-
("custom_xy_problem", "contrastScalefactors", [0.0]),
588-
("custom_xy_problem", "contrastScalefactors", [2.0]),
589-
# ("custom_xy_problem", "contrastBackgroundParams", [0.0]),
590-
# ("custom_xy_problem", "contrastBackgroundParams", [2.0]),
591-
("custom_xy_problem", "contrastResolutionParams", [0.0]),
592-
("custom_xy_problem", "contrastResolutionParams", [2.0]),
593-
("domains_problem", "contrastBulkIns", [0.0]),
594-
("domains_problem", "contrastBulkIns", [2.0]),
595-
("domains_problem", "contrastBulkOuts", [0.0]),
596-
("domains_problem", "contrastBulkOuts", [2.0]),
597-
("domains_problem", "contrastScalefactors", [0.0]),
598-
("domains_problem", "contrastScalefactors", [2.0]),
599-
("domains_problem", "contrastDomainRatios", [0.0]),
600-
("domains_problem", "contrastDomainRatios", [2.0]),
601-
# ("domains_problem", "contrastBackgroundParams", [0.0]),
602-
# ("domains_problem", "contrastBackgroundParams", [2.0]),
603-
("domains_problem", "contrastResolutionParams", [0.0]),
604-
("domains_problem", "contrastResolutionParams", [2.0]),
605-
],
606-
)
607-
def test_check_indices_error(test_problem, index_list, bad_value, request) -> None:
608-
"""The check_indices routine should raise an IndexError if a contrast list contains an index that is out of the
609-
range of the corresponding parameter list in a ProblemDefinition object.
610-
"""
611-
param_list = {
612-
"contrastBulkIns": "bulkIns",
613-
"contrastBulkOuts": "bulkOuts",
614-
"contrastScalefactors": "scalefactors",
615-
"contrastDomainRatios": "domainRatios",
616-
"contrastBackgroundParams": "backgroundParams",
617-
"contrastResolutionParams": "resolutionParams",
618-
}
563+
@pytest.mark.parametrize("test_problem", ["standard_layers_problem", "custom_xy_problem", "domains_problem"])
564+
class TestCheckIndices:
565+
"""Tests for check_indices over a set of three test problems."""
619566

620-
test_problem = request.getfixturevalue(test_problem)
621-
setattr(test_problem, index_list, bad_value)
567+
def test_check_indices(self, test_problem, request) -> None:
568+
"""The check_indices routine should not raise errors for a properly defined ProblemDefinition object."""
569+
test_problem = request.getfixturevalue(test_problem)
622570

623-
with pytest.raises(
624-
IndexError,
625-
match=f'The problem field "{index_list}" contains: {bad_value[0]}, which lie '
626-
f'outside of the range of "{param_list[index_list]}"',
627-
):
628571
check_indices(test_problem)
629572

573+
@pytest.mark.parametrize(
574+
"index_list",
575+
[
576+
"contrastBulkIns",
577+
"contrastBulkOuts",
578+
"contrastScalefactors",
579+
"contrastDomainRatios",
580+
"contrastResolutionParams",
581+
],
582+
)
583+
@pytest.mark.parametrize("bad_value", ([0.0], [2.0]))
584+
def test_check_indices_error(self, test_problem, index_list, bad_value, request) -> None:
585+
"""The check_indices routine should raise an IndexError if a contrast list contains an index that is out of the
586+
range of the corresponding parameter list in a ProblemDefinition object.
587+
"""
588+
param_list = {
589+
"contrastBulkIns": "bulkIns",
590+
"contrastBulkOuts": "bulkOuts",
591+
"contrastScalefactors": "scalefactors",
592+
"contrastDomainRatios": "domainRatios",
593+
"contrastResolutionParams": "resolutionParams",
594+
}
595+
if (test_problem != "domains_problem") and (index_list == "contrastDomainRatios"):
596+
# we expect this to not raise an error for non-domains problems as domainRatios is empty
597+
pytest.xfail()
598+
599+
test_problem = request.getfixturevalue(test_problem)
600+
setattr(test_problem, index_list, bad_value)
601+
602+
with pytest.raises(
603+
IndexError,
604+
match=f'The problem field "{index_list}" contains: {bad_value[0]}, which lies '
605+
f'outside of the range of "{param_list[index_list]}"',
606+
):
607+
check_indices(test_problem)
608+
609+
@pytest.mark.parametrize("background_type", ["constant", "data", "function"])
610+
@pytest.mark.parametrize("bad_value", ([[0.0]], [[2.0]]))
611+
def test_background_params_source_indices(self, test_problem, background_type, bad_value, request):
612+
"""check_indices should raise an IndexError for bad sources in the nested list contrastBackgroundParams."""
613+
test_problem = request.getfixturevalue(test_problem)
614+
test_problem.contrastBackgroundParams = bad_value
615+
test_problem.contrastBackgroundTypes = [background_type]
616+
617+
source_param_lists = {
618+
"constant": "backgroundParams",
619+
"data": "data",
620+
"function": "customFiles",
621+
}
622+
623+
with pytest.raises(
624+
IndexError,
625+
match=f'Entry 0 of contrastBackgroundParams has type "{background_type}" '
626+
f"and source index {bad_value[0][0]}, "
627+
f'which is outside the range of "{source_param_lists[background_type]}".',
628+
):
629+
check_indices(test_problem)
630+
631+
@pytest.mark.parametrize(
632+
"bad_value",
633+
(
634+
[[1.0, 0.0]],
635+
[[1.0, 2.0]],
636+
[[1.0, 1.0, 2.0]],
637+
[[1.0], [1.0, 0.0]],
638+
),
639+
)
640+
def test_background_params_value_indices(self, test_problem, bad_value, request):
641+
"""check_indices should raise an IndexError for bad values in the nested list contrastBackgroundParams."""
642+
test_problem = request.getfixturevalue(test_problem)
643+
test_problem.contrastBackgroundParams = bad_value
644+
645+
if len(bad_value) > 1:
646+
test_problem.contrastBackgroundTypes.append("constant")
647+
648+
with pytest.raises(
649+
IndexError,
650+
match=f"Entry {len(bad_value)-1} of contrastBackgroundParams contains: {bad_value[-1][-1]}"
651+
f', which lies outside of the range of "backgroundParams"',
652+
):
653+
check_indices(test_problem)
654+
630655

631656
def test_get_python_handle():
632657
path = pathlib.Path(__file__).parent.resolve()

0 commit comments

Comments
 (0)