Skip to content

Commit b38bd3a

Browse files
authored
Adds and refines printed output to users (#50)
* Adds text to run routine, including elapsed time * Adds routine to make sure background and resolution values are checked based on type * Adds __str__ methods to output classes * Hides input in custom errors * Refines and tests __str__ methods for output classes * Addresses review comments * Removes unused function
1 parent 7dcfc1d commit b38bd3a

File tree

9 files changed

+550
-151
lines changed

9 files changed

+550
-151
lines changed

RATapi/controls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def warn_setting_incorrect_properties(self, handler: ValidatorFunctionWrapHandle
9191
f'{", ".join(fields.get("procedure", []))}\n',
9292
}
9393
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
94-
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None
94+
raise ValidationError.from_exception_data(exc.title, custom_error_list, hide_input=True) from None
9595

9696
if isinstance(model_input, validated_self.__class__):
9797
# This is for changing fields in a defined model

RATapi/examples/domains/alloy_domains.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def alloy_domains(params, bulkIn, bulkOut, contrast, domain):
1818
gold = [goldThick, goldSLD, goldRough]
1919

2020
# Make the model depending on which domain we are looking at
21-
if domain == 1:
21+
if domain == 0:
2222
output = [alloyUp, gold]
2323
else:
2424
output = [alloyDn, gold]

RATapi/examples/domains/domains_XY_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def domains_XY_model(params, bulk_in, bulk_out, contrast, domain):
3333
oxSLD = vfOxide * 3.41e-6
3434

3535
# Layer SLD depends on whether we are calculating the domain or not
36-
if domain == 1:
36+
if domain == 0:
3737
laySLD = vfLayer * layerSLD
3838
else:
3939
laySLD = vfLayer * domainSLD

RATapi/outputs.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,73 @@
11
"""Converts outputs from the compiled RAT code to python dataclasses"""
22

33
from dataclasses import dataclass
4-
from typing import Optional, Union
4+
from typing import Any, Optional, Union
55

66
import numpy as np
77

88
import RATapi.rat_core
99
from RATapi.utils.enums import Procedures
1010

1111

12+
def get_field_string(field: str, value: Any, array_limit: int):
13+
"""Returns a string representation of class fields where large and multidimensional arrays are represented by their
14+
shape.
15+
16+
Parameters
17+
----------
18+
field : str
19+
The name of the field in the RAT output class.
20+
value : Any
21+
The value of the given field in the RAT output class.
22+
array_limit : int
23+
The maximum length of 1D arrays which will be fully displayed.
24+
25+
Returns
26+
-------
27+
field_string : str
28+
The string representation of the field in the RAT output class.
29+
"""
30+
array_text = "Data array: "
31+
if isinstance(value, list) and len(value) > 0:
32+
if isinstance(value[0], np.ndarray):
33+
array_strings = [f"{array_text}[{' x '.join(str(i) for i in array.shape)}]" for array in value]
34+
field_string = f"{field} = [{', '.join(str(string) for string in array_strings)}],\n"
35+
elif isinstance(value[0], list) and len(value[0]) > 0 and isinstance(value[0][0], np.ndarray):
36+
array_strings = [
37+
[f"{array_text}[{' x '.join(str(i) for i in array.shape)}]" for array in sub_list] for sub_list in value
38+
]
39+
list_strings = [f"[{', '.join(string for string in list_string)}]" for list_string in array_strings]
40+
field_string = f"{field} = [{', '.join(list_strings)}],\n"
41+
else:
42+
field_string = f"{field} = {str(value)},\n"
43+
elif isinstance(value, np.ndarray):
44+
if value.ndim == 1 and value.size < array_limit:
45+
field_string = f"{field} = {str(value) if value.size > 0 else '[]'},\n"
46+
else:
47+
field_string = f"{field} = {array_text}[{' x '.join(str(i) for i in value.shape)}],\n"
48+
else:
49+
field_string = f"{field} = {str(value)},\n"
50+
51+
return field_string
52+
53+
54+
class RATResult:
55+
def __str__(self):
56+
output = f"{self.__class__.__name__}(\n"
57+
for key, value in self.__dict__.items():
58+
output += "\t" + get_field_string(key, value, 100)
59+
output += ")"
60+
return output
61+
62+
1263
@dataclass
13-
class CalculationResults:
64+
class CalculationResults(RATResult):
1465
chiValues: np.ndarray
1566
sumChi: float
1667

1768

1869
@dataclass
19-
class ContrastParams:
70+
class ContrastParams(RATResult):
2071
backgroundParams: np.ndarray
2172
scalefactors: np.ndarray
2273
bulkIn: np.ndarray
@@ -39,9 +90,15 @@ class Results:
3990
fitParams: np.ndarray
4091
fitNames: list[str]
4192

93+
def __str__(self):
94+
output = ""
95+
for key, value in self.__dict__.items():
96+
output += get_field_string(key, value, 100)
97+
return output
98+
4299

43100
@dataclass
44-
class PredictionIntervals:
101+
class PredictionIntervals(RATResult):
45102
reflectivity: list
46103
sld: list
47104
reflectivityXData: list
@@ -50,14 +107,14 @@ class PredictionIntervals:
50107

51108

52109
@dataclass
53-
class ConfidenceIntervals:
110+
class ConfidenceIntervals(RATResult):
54111
percentile95: np.ndarray
55112
percentile65: np.ndarray
56113
mean: np.ndarray
57114

58115

59116
@dataclass
60-
class DreamParams:
117+
class DreamParams(RATResult):
61118
nParams: float
62119
nChains: float
63120
nGenerations: float
@@ -80,7 +137,7 @@ class DreamParams:
80137

81138

82139
@dataclass
83-
class DreamOutput:
140+
class DreamOutput(RATResult):
84141
allChains: np.ndarray
85142
outlierChains: np.ndarray
86143
runtime: float
@@ -92,7 +149,7 @@ class DreamOutput:
92149

93150

94151
@dataclass
95-
class NestedSamplerOutput:
152+
class NestedSamplerOutput(RATResult):
96153
logZ: float
97154
nestSamples: np.ndarray
98155
postSamples: np.ndarray

RATapi/project.py

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,26 @@
3333
}
3434

3535
values_defined_in = {
36-
"backgrounds.value_1": "background_parameters",
37-
"backgrounds.value_2": "background_parameters",
38-
"backgrounds.value_3": "background_parameters",
39-
"backgrounds.value_4": "background_parameters",
40-
"backgrounds.value_5": "background_parameters",
41-
"resolutions.value_1": "resolution_parameters",
42-
"resolutions.value_2": "resolution_parameters",
43-
"resolutions.value_3": "resolution_parameters",
44-
"resolutions.value_4": "resolution_parameters",
45-
"resolutions.value_5": "resolution_parameters",
36+
"backgrounds.constant.value_1": "background_parameters",
37+
"backgrounds.constant.value_2": "background_parameters",
38+
"backgrounds.constant.value_3": "background_parameters",
39+
"backgrounds.constant.value_4": "background_parameters",
40+
"backgrounds.constant.value_5": "background_parameters",
41+
"backgrounds.data.value_1": "data",
42+
"backgrounds.data.value_2": "data",
43+
"backgrounds.data.value_3": "data",
44+
"backgrounds.data.value_4": "data",
45+
"backgrounds.data.value_5": "data",
46+
"resolutions.constant.value_1": "resolution_parameters",
47+
"resolutions.constant.value_2": "resolution_parameters",
48+
"resolutions.constant.value_3": "resolution_parameters",
49+
"resolutions.constant.value_4": "resolution_parameters",
50+
"resolutions.constant.value_5": "resolution_parameters",
51+
"resolutions.data.value_1": "data",
52+
"resolutions.data.value_2": "data",
53+
"resolutions.data.value_3": "data",
54+
"resolutions.data.value_4": "data",
55+
"resolutions.data.value_5": "data",
4656
"layers.thickness": "parameters",
4757
"layers.SLD": "parameters",
4858
"layers.SLD_real": "parameters",
@@ -434,8 +444,13 @@ def update_renamed_models(self) -> "Project":
434444
def cross_check_model_values(self) -> "Project":
435445
"""Certain model fields should contain values defined elsewhere in the project."""
436446
value_fields = ["value_1", "value_2", "value_3", "value_4", "value_5"]
437-
self.check_allowed_values("backgrounds", value_fields, self.background_parameters.get_names())
438-
self.check_allowed_values("resolutions", value_fields, self.resolution_parameters.get_names())
447+
self.check_allowed_background_resolution_values(
448+
"backgrounds", value_fields, self.background_parameters.get_names(), self.data.get_names()
449+
)
450+
self.check_allowed_background_resolution_values(
451+
"resolutions", value_fields, self.resolution_parameters.get_names(), self.data.get_names()
452+
)
453+
439454
self.check_allowed_values(
440455
"layers",
441456
["thickness", "SLD", "SLD_real", "SLD_imaginary", "roughness"],
@@ -526,6 +541,49 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va
526541
f'"{values_defined_in[f"{attribute}.{field}"]}".',
527542
)
528543

544+
def check_allowed_background_resolution_values(
545+
self, attribute: str, field_list: list[str], allowed_constants: list[str], allowed_data: list[str]
546+
) -> None:
547+
"""Check the values of the given fields in the given model are in the supplied list of allowed values.
548+
549+
For backgrounds and resolutions, the list of allowed values depends on whether the type of the
550+
background/resolution is "constant" or "data".
551+
552+
Parameters
553+
----------
554+
attribute : str
555+
The attribute of Project being validated.
556+
field_list : list [str]
557+
The fields of the attribute to be checked for valid values.
558+
allowed_constants : list [str]
559+
The list of allowed values for the fields given in field_list if the type is "constant".
560+
allowed_data : list [str]
561+
The list of allowed values for the fields given in field_list if the type is "data".
562+
563+
Raises
564+
------
565+
ValueError
566+
Raised if any field in field_list has a value not specified in allowed_constants or allowed_data as
567+
appropriate.
568+
569+
"""
570+
class_list = getattr(self, attribute)
571+
for model in class_list:
572+
if model.type == TypeOptions.Constant:
573+
allowed_values = allowed_constants
574+
elif model.type == TypeOptions.Data:
575+
allowed_values = allowed_data
576+
else:
577+
raise ValueError('"Function" type backgrounds and resolutions are not yet supported.')
578+
579+
for field in field_list:
580+
value = getattr(model, field, "")
581+
if value and value not in allowed_values:
582+
raise ValueError(
583+
f'The value "{value}" in the "{field}" field of "{attribute}" must be defined in '
584+
f'"{values_defined_in[f"{attribute}.{model.type}.{field}"]}".',
585+
)
586+
529587
def check_contrast_model_allowed_values(
530588
self,
531589
contrast_attribute: str,
@@ -648,7 +706,7 @@ def wrapped_func(*args, **kwargs):
648706
except ValidationError as exc:
649707
class_list.data = previous_state
650708
custom_error_list = custom_pydantic_validation_error(exc.errors())
651-
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None
709+
raise ValidationError.from_exception_data(exc.title, custom_error_list, hide_input=True) from None
652710
except (TypeError, ValueError):
653711
class_list.data = previous_state
654712
raise

RATapi/run.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import time
2+
13
import RATapi.rat_core
24
from RATapi.inputs import make_input
35
from RATapi.outputs import make_results
6+
from RATapi.utils.enums import Display
47

58

69
def run(project, controls):
@@ -15,15 +18,25 @@ def run(project, controls):
1518
"resolution_parameters": "resolutionParams",
1619
}
1720

21+
horizontal_line = "\u2500" * 107 + "\n"
22+
1823
problem_definition, cells, limits, priors, cpp_controls = make_input(project, controls)
1924

25+
if controls.display != Display.Off:
26+
print("Starting RAT " + horizontal_line)
27+
28+
start = time.time()
2029
problem_definition, output_results, bayes_results = RATapi.rat_core.RATMain(
2130
problem_definition,
2231
cells,
2332
limits,
2433
cpp_controls,
2534
priors,
2635
)
36+
end = time.time()
37+
38+
if controls.display != Display.Off:
39+
print(f"Elapsed time is {end-start:.3f} seconds\n")
2740

2841
results = make_results(controls.procedure, output_results, bayes_results)
2942

@@ -32,4 +45,7 @@ def run(project, controls):
3245
for index, value in enumerate(getattr(problem_definition, parameter_field[class_list])):
3346
getattr(project, class_list)[index].value = value
3447

48+
if controls.display != Display.Off:
49+
print("Finished RAT " + horizontal_line)
50+
3551
return project, results

cpp/rat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,7 @@ BayesResults bayesResultsFromStruct8T(const RAT::struct8_T results)
11461146
bayesResults.dreamParams.delta = results.dreamParams.delta;
11471147
bayesResults.dreamParams.steps = results.dreamParams.steps;
11481148
bayesResults.dreamParams.zeta = results.dreamParams.zeta;
1149-
bayesResults.dreamParams.outlier = std::string(results.dreamParams.outlier);
1149+
bayesResults.dreamParams.outlier = std::string(results.dreamParams.outlier, 3);
11501150
bayesResults.dreamParams.adaptPCR = results.dreamParams.adaptPCR;
11511151
bayesResults.dreamParams.thinning = results.dreamParams.thinning;
11521152
bayesResults.dreamParams.epsilon = results.dreamParams.epsilon;

0 commit comments

Comments
 (0)