Skip to content

Commit e148415

Browse files
authored
Removes priors struct, adding relevant fields to the project (RascalSoftware#135)
* Removes priors struct, adding relevant fields to the project * Addresses review comments
1 parent 148d0d9 commit e148415

File tree

6 files changed

+219
-331
lines changed

6 files changed

+219
-331
lines changed

RATapi/examples/normal_reflectivity/background_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

33

4-
def backgroundFunction(xdata, params):
4+
def background_function(xdata, params):
55
# Split up the params array
66
Ao = params[0]
77
k = params[1]

RATapi/inputs.py

Lines changed: 37 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,19 @@
1010
import RATapi
1111
import RATapi.controls
1212
import RATapi.wrappers
13-
from RATapi.rat_core import Checks, Control, Limits, NameStore, Priors, ProblemDefinition
13+
from RATapi.rat_core import Checks, Control, Limits, NameStore, ProblemDefinition
1414
from RATapi.utils.enums import Calculations, Languages, LayerModels, TypeOptions
1515

16+
parameter_field = {
17+
"parameters": "params",
18+
"bulk_in": "bulkIns",
19+
"bulk_out": "bulkOuts",
20+
"scalefactors": "scalefactors",
21+
"domain_ratios": "domainRatios",
22+
"background_parameters": "backgroundParams",
23+
"resolution_parameters": "resolutionParams",
24+
}
25+
1626

1727
def get_python_handle(file_name: str, function_name: str, path: Union[str, pathlib.Path] = "") -> Callable:
1828
"""Get the function handle from a function defined in a python module located anywhere within the filesystem.
@@ -94,7 +104,7 @@ def __len__(self):
94104
return len(self.files)
95105

96106

97-
def make_input(project: RATapi.Project, controls: RATapi.Controls) -> tuple[ProblemDefinition, Limits, Priors, Control]:
107+
def make_input(project: RATapi.Project, controls: RATapi.Controls) -> tuple[ProblemDefinition, Limits, Control]:
98108
"""Constructs the inputs required for the compiled RAT code using the data defined in the input project and
99109
controls.
100110
@@ -111,65 +121,32 @@ def make_input(project: RATapi.Project, controls: RATapi.Controls) -> tuple[Prob
111121
The problem input used in the compiled RAT code.
112122
limits : RAT.rat_core.Limits
113123
A list of min/max values for each parameter defined in the project.
114-
priors : RAT.rat_core.Priors
115-
The priors defined for each parameter in the project.
116124
cpp_controls : RAT.rat_core.Control
117125
The controls object used in the compiled RAT code.
118126
119127
"""
120-
parameter_field = {
121-
"parameters": "params",
122-
"bulk_in": "bulkIns",
123-
"bulk_out": "bulkOuts",
124-
"scalefactors": "scalefactors",
125-
"domain_ratios": "domainRatios",
126-
"background_parameters": "backgroundParams",
127-
"resolution_parameters": "resolutionParams",
128-
}
129-
130-
prior_id = {"uniform": 1, "gaussian": 2, "jeffreys": 3}
131-
132-
checks = Checks()
133128
limits = Limits()
134-
priors = Priors()
135129

136130
for class_list in RATapi.project.parameter_class_lists:
137-
setattr(checks, parameter_field[class_list], [int(element.fit) for element in getattr(project, class_list)])
138131
setattr(
139132
limits,
140133
parameter_field[class_list],
141134
[[element.min, element.max] for element in getattr(project, class_list)],
142135
)
143-
setattr(
144-
priors,
145-
parameter_field[class_list],
146-
[[element.name, element.prior_type, element.mu, element.sigma] for element in getattr(project, class_list)],
147-
)
148136

149-
# Use dummy values for qzshifts
150-
checks.qzshifts = []
137+
# Use dummy value for qzshifts
151138
limits.qzshifts = []
152-
priors.qzshifts = []
153-
154-
priors.priorNames = [
155-
param.name for class_list in RATapi.project.parameter_class_lists for param in getattr(project, class_list)
156-
]
157-
priors.priorValues = [
158-
[prior_id[param.prior_type], param.mu, param.sigma]
159-
for class_list in RATapi.project.parameter_class_lists
160-
for param in getattr(project, class_list)
161-
]
162139

163140
if project.model == LayerModels.CustomXY:
164141
controls.calcSldDuringFit = True
165142

166-
problem = make_problem(project, checks)
143+
problem = make_problem(project)
167144
cpp_controls = make_controls(controls)
168145

169-
return problem, limits, priors, cpp_controls
146+
return problem, limits, cpp_controls
170147

171148

172-
def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
149+
def make_problem(project: RATapi.Project) -> ProblemDefinition:
173150
"""Constructs the problem input required for the compiled RAT code.
174151
175152
Parameters
@@ -184,6 +161,7 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
184161
185162
"""
186163
hydrate_id = {"bulk in": 1, "bulk out": 2}
164+
prior_id = {"uniform": 1, "gaussian": 2, "jeffreys": 3}
187165

188166
# Set contrast parameters according to model type
189167
if project.model == LayerModels.StandardLayers:
@@ -384,21 +362,32 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
384362
for param in getattr(project, class_list)
385363
if not param.fit
386364
]
365+
problem.priorNames = [
366+
param.name for class_list in RATapi.project.parameter_class_lists for param in getattr(project, class_list)
367+
]
368+
problem.priorValues = [
369+
[prior_id[param.prior_type], param.mu, param.sigma]
370+
for class_list in RATapi.project.parameter_class_lists
371+
for param in getattr(project, class_list)
372+
]
387373

388374
# Names
389375
problem.names = NameStore()
390-
problem.names.params = [param.name for param in project.parameters]
391-
problem.names.backgroundParams = [param.name for param in project.background_parameters]
392-
problem.names.scalefactors = [param.name for param in project.scalefactors]
393-
problem.names.qzshifts = [] # Placeholder for qzshifts
394-
problem.names.bulkIns = [param.name for param in project.bulk_in]
395-
problem.names.bulkOuts = [param.name for param in project.bulk_out]
396-
problem.names.resolutionParams = [param.name for param in project.resolution_parameters]
397-
problem.names.domainRatios = [param.name for param in project.domain_ratios]
376+
for class_list in RATapi.project.parameter_class_lists:
377+
setattr(problem.names, parameter_field[class_list], [param.name for param in getattr(project, class_list)])
398378
problem.names.contrasts = [contrast.name for contrast in project.contrasts]
399379

400380
# Checks
401-
problem.checks = checks
381+
problem.checks = Checks()
382+
for class_list in RATapi.project.parameter_class_lists:
383+
setattr(
384+
problem.checks, parameter_field[class_list], [int(element.fit) for element in getattr(project, class_list)]
385+
)
386+
387+
# Use dummy values for qz shifts
388+
problem.names.qzshifts = []
389+
problem.checks.qzshifts = []
390+
402391
check_indices(problem)
403392

404393
return problem

RATapi/run.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def run(project, controls):
104104

105105
horizontal_line = "\u2500" * 107 + "\n"
106106
display_on = controls.display != Display.Off
107-
problem_definition, limits, priors, cpp_controls = make_input(project, controls)
107+
problem_definition, limits, cpp_controls = make_input(project, controls)
108108

109109
if display_on:
110110
print("Starting RAT " + horizontal_line)
@@ -115,7 +115,6 @@ def run(project, controls):
115115
problem_definition,
116116
limits,
117117
cpp_controls,
118-
priors,
119118
)
120119
end = time.time()
121120

cpp/RAT

Submodule RAT updated 115 files

0 commit comments

Comments
 (0)