Skip to content

Commit 558ad61

Browse files
authored
Constructs python outputs and adds file "run.py" (#35)
* Adds file "run.py" * Adds file "outputs.py", which constructs python results objects in "run.py" * Adds test "test_outputs.py" * Rewrites "plot_ref_sld" to use python objects * Adds test cases to "test_outputs.py" * Reorganises output tests and fixtures into "test_run.py" and "conftest.py" * Fixes bug for "data" input to "project.py" and "test_inputs.py" * Updates submodule * Removes "bestFitsMean" from "bayesResults"
1 parent 62f1281 commit 558ad61

File tree

15 files changed

+2529
-53
lines changed

15 files changed

+2529
-53
lines changed

RAT/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from RAT.classlist import ClassList
33
from RAT.project import Project
44
from RAT.controls import set_controls
5+
from RAT.run import run
56
import RAT.models
67

78
dir_path = os.path.dirname(os.path.realpath(__file__))

RAT/inputs.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def make_problem(project: RAT.Project) -> ProblemDefinition:
133133
problem.contrastBackgroundActions = [action_id[contrast.background_action] for contrast in project.contrasts]
134134
problem.contrastResolutions = [project.resolutions.index(contrast.resolution, True) for contrast in project.contrasts]
135135
problem.contrastCustomFiles = contrast_custom_files
136-
problem.resample = [contrast.resample for contrast in project.contrasts]
137-
problem.dataPresent = [1 if contrast.data else 0 for contrast in project.contrasts]
136+
problem.resample = make_resample(project)
137+
problem.dataPresent = make_data_present(project)
138138
problem.oilChiDataPresent = [0] * len(project.contrasts)
139139
problem.numberOfContrasts = len(project.contrasts)
140140
problem.numberOfLayers = len(project.layers)
@@ -151,6 +151,38 @@ def make_problem(project: RAT.Project) -> ProblemDefinition:
151151
return problem
152152

153153

154+
def make_resample(project: RAT.Project) -> list[int]:
155+
"""Constructs the "resample" field of the problem input required for the compiled RAT code.
156+
157+
Parameters
158+
----------
159+
project : RAT.Project
160+
The project model, which defines the physical system under study.
161+
162+
Returns
163+
-------
164+
: list[int]
165+
The "resample" field of the problem input used in the compiled RAT code.
166+
"""
167+
return [contrast.resample for contrast in project.contrasts]
168+
169+
170+
def make_data_present(project: RAT.Project) -> list[int]:
171+
"""Constructs the "dataPresent" field of the problem input required for the compiled RAT code.
172+
173+
Parameters
174+
----------
175+
project : RAT.Project
176+
The project model, which defines the physical system under study.
177+
178+
Returns
179+
-------
180+
: list[int]
181+
The "dataPresent" field of the problem input used in the compiled RAT code.
182+
"""
183+
return [1 if project.data[project.data.index(contrast.data)].data.size != 0 else 0 for contrast in project.contrasts]
184+
185+
154186
def make_cells(project: RAT.Project) -> Cells:
155187
"""Constructs the cells input required for the compiled RAT code.
156188
@@ -199,14 +231,10 @@ def make_cells(project: RAT.Project) -> Cells:
199231

200232
data_index = project.data.index(contrast.data)
201233

202-
if 'data' in project.data[data_index].model_fields_set:
203-
all_data.append(project.data[data_index].data)
204-
data_limits.append(project.data[data_index].data_range)
205-
simulation_limits.append(project.data[data_index].simulation_range)
206-
else:
207-
all_data.append([0.0, 0.0, 0.0])
208-
data_limits.append([0.0, 0.0])
209-
simulation_limits.append([0.0, 0.0])
234+
all_data.append(project.data[data_index].data)
235+
data_limits.append(project.data[data_index].data_range)
236+
simulation_limits.append(project.data[data_index].simulation_range)
237+
210238

211239
# Populate the set of cells
212240
cells = Cells()

RAT/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def model_post_init(self, __context: Any) -> None:
123123
"""If the "data_range" and "simulation_range" fields are not set, but "data" is supplied, the ranges should be
124124
set to the min and max values of the first column (assumed to be q) of the supplied data.
125125
"""
126-
if len(self.data[:, 0]) > 0:
126+
if self.data.shape[0] > 0:
127127
data_min = np.min(self.data[:, 0])
128128
data_max = np.max(self.data[:, 0])
129129
for field in ["data_range", "simulation_range"]:
@@ -135,7 +135,7 @@ def check_ranges(self) -> 'Data':
135135
"""The limits of the "data_range" field must lie within the range of the supplied data, whilst the limits
136136
of the "simulation_range" field must lie outside the range of the supplied data.
137137
"""
138-
if len(self.data[:, 0]) > 0:
138+
if self.data.shape[0] > 0:
139139
data_min = np.min(self.data[:, 0])
140140
data_max = np.max(self.data[:, 0])
141141
if "data_range" in self.model_fields_set and (self.data_range[0] < data_min or

RAT/outputs.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
"""Converts outputs from the compiled RAT code to python dataclasses"""
2+
3+
from dataclasses import dataclass
4+
import numpy as np
5+
from typing import Optional, Union
6+
from RAT.utils.enums import Procedures
7+
import RAT.rat_core
8+
9+
10+
@dataclass
11+
class CalculationResults:
12+
chiValues: np.ndarray
13+
sumChi: float
14+
15+
16+
@dataclass
17+
class ContrastParams:
18+
backgroundParams: np.ndarray
19+
scalefactors: np.ndarray
20+
bulkIn: np.ndarray
21+
bulkOut: np.ndarray
22+
resolutionParams: np.ndarray
23+
subRoughs: np.ndarray
24+
resample: np.ndarray
25+
26+
27+
@dataclass
28+
class Results:
29+
reflectivity: list
30+
simulation: list
31+
shiftedData: list
32+
layerSlds: list
33+
sldProfiles: list
34+
resampledLayers: list
35+
calculationResults: CalculationResults
36+
contrastParams: ContrastParams
37+
fitParams: np.ndarray
38+
fitNames: list[str]
39+
40+
41+
@dataclass
42+
class PredictionIntervals:
43+
reflectivity: list
44+
sld: list
45+
reflectivityXData: list
46+
sldXData: list
47+
sampleChi: np.ndarray
48+
49+
50+
@dataclass
51+
class ConfidenceIntervals:
52+
percentile95: np.ndarray
53+
percentile65: np.ndarray
54+
mean: np.ndarray
55+
56+
57+
@dataclass
58+
class DreamParams:
59+
nParams: float
60+
nChains: float
61+
nGenerations: float
62+
parallel: bool
63+
CPU: float
64+
jumpProbability: float
65+
pUnitGamma: float
66+
nCR: float
67+
delta: float
68+
steps: float
69+
zeta: float
70+
outlier: str
71+
adaptPCR: bool
72+
thinning: float
73+
epsilon: float
74+
ABC: bool
75+
IO: bool
76+
storeOutput: bool
77+
R: np.ndarray
78+
79+
80+
@dataclass
81+
class DreamOutput:
82+
allChains: np.ndarray
83+
outlierChains: np.ndarray
84+
runtime: float
85+
iteration: float
86+
modelOutput: float
87+
AR: np.ndarray
88+
R_stat: np.ndarray
89+
CR: np.ndarray
90+
91+
92+
@dataclass
93+
class NestedSamplerOutput:
94+
logZ: float
95+
nestSamples: np.ndarray
96+
postSamples: np.ndarray
97+
98+
99+
@dataclass
100+
class BayesResults(Results):
101+
predictionIntervals: PredictionIntervals
102+
confidenceIntervals: ConfidenceIntervals
103+
dreamParams: DreamParams
104+
dreamOutput: DreamOutput
105+
nestedSamplerOutput: NestedSamplerOutput
106+
chain: np.ndarray
107+
108+
109+
def make_results(procedure: Procedures, output_results: RAT.rat_core.OutputResult,
110+
bayes_results: Optional[RAT.rat_core.BayesResults] = None) -> Union[Results, BayesResults]:
111+
"""Initialise a python Results or BayesResults object using the outputs from a RAT calculation."""
112+
113+
calculation_results = CalculationResults(chiValues=output_results.calculationResults.chiValues,
114+
sumChi=output_results.calculationResults.sumChi
115+
)
116+
contrast_params = ContrastParams(
117+
backgroundParams=output_results.contrastParams.backgroundParams,
118+
scalefactors=output_results.contrastParams.scalefactors,
119+
bulkIn=output_results.contrastParams.bulkIn,
120+
bulkOut=output_results.contrastParams.bulkOut,
121+
resolutionParams=output_results.contrastParams.resolutionParams,
122+
subRoughs=output_results.contrastParams.subRoughs,
123+
resample=output_results.contrastParams.resample
124+
)
125+
126+
if procedure in [Procedures.NS, Procedures.Dream]:
127+
128+
prediction_intervals = PredictionIntervals(
129+
reflectivity=bayes_results.predictionIntervals.reflectivity,
130+
sld=bayes_results.predictionIntervals.sld,
131+
reflectivityXData=bayes_results.predictionIntervals.reflectivityXData,
132+
sldXData=bayes_results.predictionIntervals.sldXData,
133+
sampleChi=bayes_results.predictionIntervals.sampleChi
134+
)
135+
136+
confidence_intervals = ConfidenceIntervals(
137+
percentile95=bayes_results.confidenceIntervals.percentile95,
138+
percentile65=bayes_results.confidenceIntervals.percentile65,
139+
mean=bayes_results.confidenceIntervals.mean
140+
)
141+
142+
dream_params = DreamParams(
143+
nParams=bayes_results.dreamParams.nParams,
144+
nChains=bayes_results.dreamParams.nChains,
145+
nGenerations=bayes_results.dreamParams.nGenerations,
146+
parallel=bool(bayes_results.dreamParams.parallel),
147+
CPU=bayes_results.dreamParams.CPU,
148+
jumpProbability=bayes_results.dreamParams.jumpProbability,
149+
pUnitGamma=bayes_results.dreamParams.pUnitGamma,
150+
nCR=bayes_results.dreamParams.nCR,
151+
delta=bayes_results.dreamParams.delta,
152+
steps=bayes_results.dreamParams.steps,
153+
zeta=bayes_results.dreamParams.zeta,
154+
outlier=bayes_results.dreamParams.outlier,
155+
adaptPCR=bool(bayes_results.dreamParams.adaptPCR),
156+
thinning=bayes_results.dreamParams.thinning,
157+
epsilon=bayes_results.dreamParams.epsilon,
158+
ABC=bool(bayes_results.dreamParams.ABC),
159+
IO=bool(bayes_results.dreamParams.IO),
160+
storeOutput=bool(bayes_results.dreamParams.storeOutput),
161+
R=bayes_results.dreamParams.R
162+
)
163+
164+
dream_output = DreamOutput(
165+
allChains=bayes_results.dreamOutput.allChains,
166+
outlierChains=bayes_results.dreamOutput.outlierChains,
167+
runtime=bayes_results.dreamOutput.runtime,
168+
iteration=bayes_results.dreamOutput.iteration,
169+
modelOutput=bayes_results.dreamOutput.modelOutput,
170+
AR=bayes_results.dreamOutput.AR,
171+
R_stat=bayes_results.dreamOutput.R_stat,
172+
CR=bayes_results.dreamOutput.CR
173+
)
174+
175+
nested_sampler_output = NestedSamplerOutput(
176+
logZ=bayes_results.nestedSamplerOutput.logZ,
177+
nestSamples=bayes_results.nestedSamplerOutput.nestSamples,
178+
postSamples=bayes_results.nestedSamplerOutput.postSamples
179+
)
180+
181+
results = BayesResults(
182+
reflectivity=output_results.reflectivity,
183+
simulation=output_results.simulation,
184+
shiftedData=output_results.shiftedData,
185+
layerSlds=output_results.layerSlds,
186+
sldProfiles=output_results.sldProfiles,
187+
resampledLayers=output_results.resampledLayers,
188+
calculationResults=calculation_results,
189+
contrastParams=contrast_params,
190+
fitParams=output_results.fitParams,
191+
fitNames=output_results.fitNames,
192+
predictionIntervals=prediction_intervals,
193+
confidenceIntervals=confidence_intervals,
194+
dreamParams=dream_params,
195+
dreamOutput=dream_output,
196+
nestedSamplerOutput=nested_sampler_output,
197+
chain=bayes_results.chain
198+
)
199+
200+
else:
201+
202+
results = Results(
203+
reflectivity=output_results.reflectivity,
204+
simulation=output_results.simulation,
205+
shiftedData=output_results.shiftedData,
206+
layerSlds=output_results.layerSlds,
207+
sldProfiles=output_results.sldProfiles,
208+
resampledLayers=output_results.resampledLayers,
209+
calculationResults=calculation_results,
210+
contrastParams=contrast_params,
211+
fitParams=output_results.fitParams,
212+
fitNames=output_results.fitNames
213+
)
214+
215+
return results

RAT/project.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_typ
124124
value_1='Resolution Param 1'))
125125

126126
custom_files: ClassList = ClassList()
127-
data: ClassList = ClassList(RAT.models.Data(name='Simulation'))
127+
data: ClassList = ClassList()
128128
layers: ClassList = ClassList()
129129
domain_contrasts: ClassList = ClassList()
130130
contrasts: ClassList = ClassList()
@@ -187,6 +187,9 @@ def model_post_init(self, __context: Any) -> None:
187187
self.parameters.remove('Substrate Roughness')
188188
self.parameters.insert(0, RAT.models.ProtectedParameter(**substrate_roughness_values))
189189

190+
if 'Simulation' not in self.data.get_names():
191+
self.data.insert(0, RAT.models.Data(name='Simulation'))
192+
190193
self._all_names = self.get_all_names()
191194
self._contrast_model_field = self.get_contrast_model_field()
192195
self._protected_parameters = self.get_all_protected_parameters()

RAT/run.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from RAT.inputs import make_input
2+
from RAT.outputs import make_results
3+
import RAT.rat_core
4+
5+
6+
def run(project, controls):
7+
"""Run RAT for the given project and controls inputs."""
8+
9+
parameter_field = {'parameters': 'params',
10+
'bulk_in': 'bulkIn',
11+
'bulk_out': 'bulkOut',
12+
'scalefactors': 'scalefactors',
13+
'domain_ratios': 'domainRatio',
14+
'background_parameters': 'backgroundParams',
15+
'resolution_parameters': 'resolutionParams',
16+
}
17+
18+
problem_definition, cells, limits, priors, cpp_controls = make_input(project, controls)
19+
20+
problem_definition, output_results, bayes_results = RAT.rat_core.RATMain(problem_definition, cells, limits,
21+
cpp_controls, priors)
22+
23+
results = RAT.outputs.make_results(controls.procedure, output_results, bayes_results)
24+
25+
# Update parameter values in project
26+
for class_list in RAT.project.parameter_class_lists:
27+
for (index, value) in enumerate(getattr(problem_definition, parameter_field[class_list])):
28+
setattr(getattr(project, class_list)[index], 'value', value)
29+
30+
return project, results

0 commit comments

Comments
 (0)