Skip to content

Commit c8b4f77

Browse files
authored
Updates cpp with resolution and background functions (#120)
1 parent 9905295 commit c8b4f77

File tree

13 files changed

+186
-113
lines changed

13 files changed

+186
-113
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import numpy as np
2+
3+
4+
def backgroundFunction(xdata, params):
5+
# Split up the params array
6+
Ao = params[0]
7+
k = params[1]
8+
back_const = params[2]
9+
10+
# Make an exponential decay background
11+
background = np.zeros(len(xdata))
12+
for i in range(0, len(xdata)):
13+
background[i] = Ao * np.exp(-k * xdata[i]) + back_const
14+
15+
return background

RATapi/inputs.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
223223
data_limits = []
224224
simulation_limits = []
225225
contrast_resolution_params = []
226+
contrast_resolution_types = []
226227

227228
# set data, background and resolution for each contrast
228229
for contrast in project.contrasts:
@@ -278,11 +279,12 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
278279
all_data.append(np.column_stack((data, np.zeros((data.shape[0], 6 - data.shape[1])))))
279280

280281
# Set resolution parameters, with -1 used to indicate a data resolution
282+
contrast_resolution_param = []
281283
resolution = project.resolutions[contrast.resolution]
282-
if resolution.type == TypeOptions.Data:
283-
contrast_resolution_params.append(-1)
284-
else:
285-
contrast_resolution_params.append(project.resolution_parameters.index(resolution.source, True))
284+
contrast_resolution_types.append(resolution.type)
285+
if resolution.source:
286+
contrast_resolution_param.append(project.resolution_parameters.index(resolution.source, True))
287+
contrast_resolution_params.append(contrast_resolution_param)
286288

287289
problem = ProblemDefinition()
288290

@@ -306,7 +308,10 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
306308
]
307309
problem.contrastBulkIns = [project.bulk_in.index(contrast.bulk_in, True) for contrast in project.contrasts]
308310
problem.contrastBulkOuts = [project.bulk_out.index(contrast.bulk_out, True) for contrast in project.contrasts]
311+
309312
problem.contrastResolutionParams = contrast_resolution_params
313+
problem.contrastResolutionTypes = contrast_resolution_types
314+
310315
problem.backgroundParams = [param.value for param in project.background_parameters]
311316
problem.qzshifts = [0.0]
312317
problem.scalefactors = [param.value for param in project.scalefactors]
@@ -429,7 +434,6 @@ def check_indices(problem: ProblemDefinition) -> None:
429434
"scalefactors": "contrastScalefactors",
430435
"bulkIns": "contrastBulkIns",
431436
"bulkOuts": "contrastBulkOuts",
432-
"resolutionParams": "contrastResolutionParams",
433437
"domainRatios": "contrastDomainRatios",
434438
}
435439

RATapi/outputs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ class ContrastParams(RATResult):
7171
scalefactors: np.ndarray
7272
bulkIn: np.ndarray
7373
bulkOut: np.ndarray
74-
resolutionParams: np.ndarray
7574
subRoughs: np.ndarray
7675
resample: np.ndarray
7776

@@ -178,7 +177,6 @@ def make_results(
178177
scalefactors=output_results.contrastParams.scalefactors,
179178
bulkIn=output_results.contrastParams.bulkIn,
180179
bulkOut=output_results.contrastParams.bulkOut,
181-
resolutionParams=output_results.contrastParams.resolutionParams,
182180
subRoughs=output_results.contrastParams.subRoughs,
183181
resample=output_results.contrastParams.resample,
184182
)

RATapi/wrappers.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,24 @@ def getHandle(self) -> Callable[[ArrayLike, ArrayLike, ArrayLike, int, int], tup
5757
5858
"""
5959

60-
def handle(params, bulk_in, bulk_out, contrast, domain=-1):
61-
if domain == -1:
62-
output, sub_rough = getattr(self.engine, self.function_name)(
63-
np.array(params, "float"),
64-
np.array(bulk_in, "float"),
65-
np.array(bulk_out, "float"),
66-
float(contrast + 1),
67-
nargout=2,
60+
def handle(*args):
61+
if len(args) == 2:
62+
output = getattr(self.engine, self.function_name)(
63+
np.array(args[0], "float"), # xdata
64+
np.array(args[1], "float"), # params
65+
nargout=1,
6866
)
67+
return np.array(output, "float").tolist()
6968
else:
7069
output, sub_rough = getattr(self.engine, self.function_name)(
71-
np.array(params, "float"),
72-
np.array(bulk_in, "float"),
73-
np.array(bulk_out, "float"),
74-
float(contrast + 1),
75-
float(domain + 1),
70+
np.array(args[0], "float"), # params
71+
np.array(args[1], "float"), # bulk in
72+
np.array(args[2], "float"), # bulk out
73+
float(args[3] + 1), # contrast
74+
float(-1 if len(args) < 5 else args[4] + 1), # domain
7675
nargout=2,
7776
)
78-
return output, sub_rough
77+
return np.array(output, "float").tolist(), float(sub_rough)
7978

8079
return handle
8180

@@ -105,11 +104,7 @@ def getHandle(self) -> Callable[[ArrayLike, ArrayLike, ArrayLike, int, int], tup
105104
106105
"""
107106

108-
def handle(params, bulk_in, bulk_out, contrast, domain=-1):
109-
if domain == -1:
110-
output, sub_rough = self.engine.invoke(params, bulk_in, bulk_out, contrast)
111-
else:
112-
output, sub_rough = self.engine.invoke(params, bulk_in, bulk_out, contrast, domain)
113-
return output, sub_rough
107+
def handle(*args):
108+
return self.engine.invoke(*args)
114109

115110
return handle

cpp/RAT

Submodule RAT updated 94 files

0 commit comments

Comments
 (0)