Skip to content

Commit f635416

Browse files
authored
Adds code to construct inputs for the compiled code (#25)
* Tidies up configs for pydantic models * Moves enums into "utils/enums.py" * First draft of "inputs.py" code to construct inputs for compiled code * Renames enums * Moves definition of dataclasses to "utils/dataclasses.py" * Tidies up "inputs" module with "make_problem"2" and "make_cells" routines * Adds offset to the "index" routine in "classList.py" * Adds "NaNList" class to "tests/utils.py" * Adds "test_inputs.py", and corrects code in "inputs.py" * Adjusts tests to fit with pybind example * Updates code to ensure tests pass * Renames parameters to match matlab updates * Converts "inputs.py" to use C++ objects directly. * Renames "misc.py" as "wrappers.py" and "Calc" enum as "Calculations" * Adds background actions to the contrast model * Adds file wrappers to "make"cells" * Adds additional examples to "test_inputs.py" to improve test coverage" * Adds code to support recording custom files in "make_cells" * Updates submodule and tidying up * Updates requirements * . . .and "pyproject.toml" * Fixes pydantic to version 2.6.4 * . . . and "setup.py" * Sort out version requirements * Addresses review comments and import statements * Changes parameters from optional to compulsory in "Layer" and "AbsorptionLayer" models * Enables optional hydration in layer models
1 parent b62c5d4 commit f635416

22 files changed

+1595
-627
lines changed

CONTRIBUTING.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ And finally create a separate branch to begin work
1616

1717
git checkout -b new-feature
1818

19+
If there are updates to the C++ RAT submodule, run the following command to update the local branch
20+
21+
git submodule update --remote
22+
1923
Once complete submit a [pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) via GitHub.
2024
Ensure to rebase your branch to include the latest changes on your branch and resolve possible merge conflicts.
2125

RAT/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import os
22
from RAT.classlist import ClassList
33
from RAT.project import Project
4-
import RAT.controls
4+
from RAT.controls import set_controls
55
import RAT.models
66

7-
87
dir_path = os.path.dirname(os.path.realpath(__file__))
98
os.environ["RAT_PATH"] = os.path.join(dir_path, '')

RAT/classlist.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@ def __repr__(self):
6363
output = repr(self.data)
6464
return output
6565

66-
def __setitem__(self, index: int, item: 'RAT.models') -> None:
66+
def __setitem__(self, index: int, item: object) -> None:
6767
"""Replace the object at an existing index of the ClassList."""
6868
self._setitem(index, item)
6969

70-
def _setitem(self, index: int, item: 'RAT.models') -> None:
70+
def _setitem(self, index: int, item: object) -> None:
7171
"""Auxiliary routine of "__setitem__" used to enable wrapping."""
7272
self._check_classes(self + [item])
7373
self._check_unique_name_fields(self + [item])
@@ -171,7 +171,7 @@ def insert(self, index: int, obj: object = None, **kwargs) -> None:
171171
inserted into the ClassList and the keyword arguments are discarded.
172172
"""
173173
if obj and kwargs:
174-
warnings.warn('ClassList.insert() called with both object and keyword arguments. '
174+
warnings.warn('ClassList.insert() called with both an object and keyword arguments. '
175175
'The keyword arguments will be ignored.', SyntaxWarning)
176176
if obj:
177177
if not hasattr(self, '_class_handle'):
@@ -193,15 +193,17 @@ def remove(self, item: Union[object, str]) -> None:
193193

194194
def count(self, item: Union[object, str]) -> int:
195195
"""Return the number of times an object appears in the ClassList using either the object itself or its
196-
name_field value."""
196+
name_field value.
197+
"""
197198
item = self._get_item_from_name_field(item)
198199
return self.data.count(item)
199200

200-
def index(self, item: Union[object, str], *args) -> int:
201+
def index(self, item: Union[object, str], offset: bool = False, *args) -> int:
201202
"""Return the index of a particular object in the ClassList using either the object itself or its
202-
name_field value."""
203+
name_field value. If offset is specified, add one to the index. This is used to account for one-based indexing.
204+
"""
203205
item = self._get_item_from_name_field(item)
204-
return self.data.index(item, *args)
206+
return self.data.index(item, *args) + int(offset)
205207

206208
def extend(self, other: Sequence[object]) -> None:
207209
"""Extend the ClassList by adding another sequence."""

RAT/controls.py

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,65 @@
1+
from dataclasses import dataclass, field
12
import prettytable
23
from pydantic import BaseModel, Field, field_validator, ValidationError
34
from typing import Literal, Union
45

5-
from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions
6+
from RAT.utils.enums import Parallel, Procedures, Display, BoundHandling, Strategies
67
from RAT.utils.custom_errors import custom_pydantic_validation_error
78

89

10+
@dataclass(frozen=True)
11+
class Controls:
12+
"""The full set of controls parameters required for the compiled RAT code."""
13+
# All Procedures
14+
procedure: Procedures = Procedures.Calculate
15+
parallel: Parallel = Parallel.Single
16+
calcSldDuringFit: bool = False
17+
resampleParams: list[float] = field(default_factory=list[0.9, 50.0])
18+
display: Display = Display.Iter
19+
# Simplex
20+
xTolerance: float = 1.0e-6
21+
funcTolerance: float = 1.0e-6
22+
maxFuncEvals: int = 10000
23+
maxIterations: int = 1000
24+
updateFreq: int = -1
25+
updatePlotFreq: int = 1
26+
# DE
27+
populationSize: int = 20
28+
fWeight: float = 0.5
29+
crossoverProbability: float = 0.8
30+
strategy: Strategies = Strategies.RandomWithPerVectorDither.value
31+
targetValue: float = 1.0
32+
numGenerations: int = 500
33+
# NS
34+
nLive: int = 150
35+
nMCMC: float = 0.0
36+
propScale: float = 0.1
37+
nsTolerance: float = 0.1
38+
# Dream
39+
nSamples: int = 50000
40+
nChains: int = 10
41+
jumpProbability: float = 0.5
42+
pUnitGamma: float = 0.2
43+
boundHandling: BoundHandling = BoundHandling.Fold
44+
adaptPCR: bool = False
45+
46+
947
class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
1048
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""
1149
procedure: Literal[Procedures.Calculate] = Procedures.Calculate
12-
parallel: ParallelOptions = ParallelOptions.Single
50+
parallel: Parallel = Parallel.Single
1351
calcSldDuringFit: bool = False
14-
resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2)
15-
display: DisplayOptions = DisplayOptions.Iter
52+
resampleParams: list[float] = Field([0.9, 50], min_length=2, max_length=2)
53+
display: Display = Display.Iter
1654

17-
@field_validator("resamPars")
55+
@field_validator("resampleParams")
1856
@classmethod
19-
def check_resamPars(cls, resamPars):
20-
if not 0 < resamPars[0] < 1:
21-
raise ValueError('resamPars[0] must be between 0 and 1')
22-
if resamPars[1] < 0:
23-
raise ValueError('resamPars[1] must be greater than or equal to 0')
24-
return resamPars
57+
def check_resample_params(cls, resampleParams):
58+
if not 0 < resampleParams[0] < 1:
59+
raise ValueError('resampleParams[0] must be between 0 and 1')
60+
if resampleParams[1] < 0:
61+
raise ValueError('resampleParams[1] must be greater than or equal to 0')
62+
return resampleParams
2563

2664
def __repr__(self) -> str:
2765
table = prettytable.PrettyTable()
@@ -30,45 +68,46 @@ def __repr__(self) -> str:
3068
return table.get_string()
3169

3270

33-
class Simplex(Calculate, validate_assignment=True, extra='forbid'):
71+
class Simplex(Calculate):
3472
"""Defines the additional fields for the simplex procedure."""
3573
procedure: Literal[Procedures.Simplex] = Procedures.Simplex
36-
tolX: float = Field(1.0e-6, gt=0.0)
37-
tolFun: float = Field(1.0e-6, gt=0.0)
38-
maxFunEvals: int = Field(10000, gt=0)
39-
maxIter: int = Field(1000, gt=0)
74+
xTolerance: float = Field(1.0e-6, gt=0.0)
75+
funcTolerance: float = Field(1.0e-6, gt=0.0)
76+
maxFuncEvals: int = Field(10000, gt=0)
77+
maxIterations: int = Field(1000, gt=0)
4078
updateFreq: int = -1
41-
updatePlotFreq: int = -1
79+
updatePlotFreq: int = 1
4280

4381

44-
class DE(Calculate, validate_assignment=True, extra='forbid'):
82+
class DE(Calculate):
4583
"""Defines the additional fields for the Differential Evolution procedure."""
4684
procedure: Literal[Procedures.DE] = Procedures.DE
4785
populationSize: int = Field(20, ge=1)
4886
fWeight: float = 0.5
4987
crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0)
50-
strategy: StrategyOptions = StrategyOptions.RandomWithPerVectorDither
88+
strategy: Strategies = Strategies.RandomWithPerVectorDither
5189
targetValue: float = Field(1.0, ge=1.0)
5290
numGenerations: int = Field(500, ge=1)
5391

5492

55-
class NS(Calculate, validate_assignment=True, extra='forbid'):
93+
class NS(Calculate):
5694
"""Defines the additional fields for the Nested Sampler procedure."""
5795
procedure: Literal[Procedures.NS] = Procedures.NS
58-
Nlive: int = Field(150, ge=1)
59-
Nmcmc: float = Field(0.0, ge=0.0)
96+
nLive: int = Field(150, ge=1)
97+
nMCMC: float = Field(0.0, ge=0.0)
6098
propScale: float = Field(0.1, gt=0.0, lt=1.0)
6199
nsTolerance: float = Field(0.1, ge=0.0)
62100

63101

64-
class Dream(Calculate, validate_assignment=True, extra='forbid'):
102+
class Dream(Calculate):
65103
"""Defines the additional fields for the Dream procedure."""
66104
procedure: Literal[Procedures.Dream] = Procedures.Dream
67105
nSamples: int = Field(50000, ge=0)
68106
nChains: int = Field(10, gt=0)
69-
jumpProb: float = Field(0.5, gt=0.0, lt=1.0)
107+
jumpProbability: float = Field(0.5, gt=0.0, lt=1.0)
70108
pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0)
71-
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold
109+
boundHandling: BoundHandling = BoundHandling.Fold
110+
adaptPCR: bool = False
72111

73112

74113
def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\

RAT/events.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Callable, Union, List
2-
import RAT.rat_core
3-
from RAT.rat_core import EventTypes, PlotEventData, ProgressEventData
2+
from RAT.rat_core import EventBridge, EventTypes, PlotEventData, ProgressEventData
43

54

65
def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEventData]) -> None:
@@ -18,6 +17,7 @@ def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEvent
1817
for callback in callbacks:
1918
callback(data)
2019

20+
2121
def get_event_callback(event_type: EventTypes) -> List[Callable[[Union[str, PlotEventData, ProgressEventData]], None]]:
2222
"""Returns all callbacks registered for the given event type.
2323
@@ -59,5 +59,5 @@ def clear() -> None:
5959
__event_callbacks[key] = set()
6060

6161

62-
__event_impl = RAT.rat_core.EventBridge(notify)
62+
__event_impl = EventBridge(notify)
6363
__event_callbacks = {EventTypes.Message: set(), EventTypes.Plot: set(), EventTypes.Progress: set()}

0 commit comments

Comments
 (0)