Skip to content

Commit e99b285

Browse files
committed
data backgrounds add to contrast data
1 parent 5d9529f commit e99b285

File tree

2 files changed

+71
-26
lines changed

2 files changed

+71
-26
lines changed

RATapi/inputs.py

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,6 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
214214
else:
215215
contrast_custom_files = [project.custom_files.index(contrast.model[0], True) for contrast in project.contrasts]
216216

217-
# Set background parameters, with -1 used to indicate a data background
218-
contrast_background_params = []
219-
contrast_background_types = []
220-
221217
# Get details of defined layers
222218
layer_details = []
223219
for layer in project.layers:
@@ -230,14 +226,42 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
230226

231227
layer_details.append(layer_params)
232228

229+
contrast_background_params = []
230+
contrast_background_types = []
231+
all_data = []
232+
data_limits = []
233+
simulation_limits = []
234+
contrast_resolution_params = []
235+
236+
# set data, background and resolution for each contrast
233237
for contrast in project.contrasts:
238+
# set data
239+
data_index = project.data.index(contrast.data)
240+
data = project.data[data_index].data
241+
data_range = project.data[data_index].data_range
242+
simulation_range = project.data[data_index].simulation_range
243+
244+
if data_range:
245+
data_limits.append(data_range)
246+
else:
247+
data_limits.append([0.0, 0.0])
248+
249+
if simulation_range:
250+
simulation_limits.append(simulation_range)
251+
else:
252+
simulation_limits.append([0.0, 0.0])
253+
254+
# set background parameters
234255
background = project.backgrounds[contrast.background]
235256
contrast_background_types.append(background.type)
236257
contrast_background_param = []
237258
if background.type == TypeOptions.Data:
238259
contrast_background_param.append(project.data.index(background.source, True))
239260
if background.value_1 != "":
240261
contrast_background_param.append(project.background_parameters.index(background.value_1))
262+
# if we are using a data background, we add the background data to the contrast data
263+
data = append_data_background(data, project.data[background.source].data)
264+
241265
elif background.type == TypeOptions.Function:
242266
contrast_background_param.append(project.custom_files.index(background.source, True))
243267
contrast_background_param.extend(
@@ -259,35 +283,16 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
259283

260284
contrast_background_params.append(contrast_background_param)
261285

262-
# Set resolution parameters, with -1 used to indicate a data resolution
263-
all_data = []
264-
data_limits = []
265-
simulation_limits = []
266-
contrast_resolution_params = []
286+
# contrast data has exactly six columns to include background data if relevant
287+
all_data.append(np.column_stack((data, np.zeros((data.shape[0], 6 - data.shape[1])))))
267288

268-
for contrast in project.contrasts:
289+
# Set resolution parameters, with -1 used to indicate a data resolution
269290
resolution = project.resolutions[contrast.resolution]
270291
if resolution.type == TypeOptions.Data:
271292
contrast_resolution_params.append(-1)
272293
else:
273294
contrast_resolution_params.append(project.resolution_parameters.index(resolution.source, True))
274295

275-
data_index = project.data.index(contrast.data)
276-
data = project.data[data_index].data
277-
all_data.append(np.column_stack((data, np.zeros((data.shape[0], 6 - data.shape[1])))))
278-
data_range = project.data[data_index].data_range
279-
simulation_range = project.data[data_index].simulation_range
280-
281-
if data_range:
282-
data_limits.append(data_range)
283-
else:
284-
data_limits.append([0.0, 0.0])
285-
286-
if simulation_range:
287-
simulation_limits.append(simulation_range)
288-
else:
289-
simulation_limits.append([0.0, 0.0])
290-
291296
problem = ProblemDefinition()
292297

293298
problem.TF = project.calculation
@@ -487,6 +492,28 @@ def check_indices(problem: ProblemDefinition) -> None:
487492
)
488493

489494

495+
def append_data_background(data: np.array, background: np.array) -> np.array:
496+
"""Add background data to contrast data.
497+
498+
Parameters
499+
----------
500+
data : np.array
501+
The contrast data to which we are adding a background.
502+
background : np.array
503+
The background data to add to the contrast.
504+
505+
Returns
506+
-------
507+
np.array
508+
The contrast data with background data added as two additional columns.
509+
510+
"""
511+
if not np.allclose(data[:, 0], background[:, 0]):
512+
raise ValueError("The q-values of the data and background must be equal.")
513+
514+
return np.hstack((data, background[:, 1:]))
515+
516+
490517
def make_controls(input_controls: RATapi.Controls) -> Control:
491518
"""Converts the controls object to the format required by the compiled RAT code.
492519

tests/test_inputs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,24 @@ def test_background_params_value_indices(self, test_problem, bad_value, request)
653653
check_indices(test_problem)
654654

655655

656+
def test_append_data_background():
657+
"""Test that background data is correctly added to contrast data."""
658+
data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
659+
background = np.array([[1, 10, 11], [4, 12, 13], [7, 14, 15]])
660+
661+
result = RATapi.inputs.append_data_background(data, background)
662+
np.testing.assert_allclose(result, np.array([[1, 2, 3, 10, 11], [4, 5, 6, 12, 13], [7, 8, 9, 14, 15]]))
663+
664+
665+
def test_append_data_background_error():
666+
"""Test that append_data_background raises an error if the q-values are not equal."""
667+
data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
668+
background = np.array([[56, 10, 11], [41, 12, 13], [7, 14, 15]])
669+
670+
with pytest.raises(ValueError, match=("The q-values of the data and background must be equal.")):
671+
RATapi.inputs.append_data_background(data, background)
672+
673+
656674
def test_get_python_handle():
657675
path = pathlib.Path(__file__).parent.resolve()
658676
assert RATapi.inputs.get_python_handle("utils.py", "dummy_function", path).__code__ == dummy_function.__code__

0 commit comments

Comments
 (0)