Skip to content

Commit 6f329ed

Browse files
authored
added the controlsClass (#10)
* added the test file for controls class * adds enums * adds the procedure classes with doc strings and typings * added controls class with input validation * added the display methods for all the classes * added tests for the procedure classes * added tests for controls class * updating typing in control classes * updated enums in controls and tests * updated typings to literal and added tests to check property types and updated docs * converted procedures to pydantic classes * added model verification * added __repr__ method for control class * addressed the review comments * addressed comments
1 parent eb2ec43 commit 6f329ed

File tree

3 files changed

+671
-0
lines changed

3 files changed

+671
-0
lines changed

RAT/controls.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import tabulate
2+
from typing import Union
3+
from pydantic import BaseModel, Field, field_validator
4+
from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions
5+
6+
7+
class BaseProcedure(BaseModel, validate_assignment = True, extra = 'forbid'):
8+
"""
9+
Defines the base class with properties used in all five procedures.
10+
"""
11+
parallel: ParallelOptions = ParallelOptions.Single
12+
calcSldDuringFit: bool = False
13+
resamPars: list[float] = Field([0.9, 50], min_length = 2, max_length = 2)
14+
display: DisplayOptions = DisplayOptions.Iter
15+
16+
@field_validator("resamPars")
17+
def check_resamPars(cls, resamPars):
18+
if not 0 < resamPars[0] < 1:
19+
raise ValueError('resamPars[0] must be between 0 and 1')
20+
if resamPars[1] < 0:
21+
raise ValueError('resamPars[1] must be greater than or equal to 0')
22+
return resamPars
23+
24+
25+
class Calculate(BaseProcedure, validate_assignment = True, extra = 'forbid'):
26+
"""
27+
Defines the class for the calculate procedure.
28+
"""
29+
procedure: Procedures = Field(Procedures.Calculate, frozen = True)
30+
31+
32+
class Simplex(BaseProcedure, validate_assignment = True, extra = 'forbid'):
33+
"""
34+
Defines the class for the simplex procedure.
35+
"""
36+
procedure: Procedures = Field(Procedures.Simplex, frozen = True)
37+
tolX: float = Field(1e-6, gt = 0)
38+
tolFun: float = Field(1e-6, gt = 0)
39+
maxFunEvals: int = Field(10000, gt = 0)
40+
maxIter: int = Field(1000, gt = 0)
41+
updateFreq: int = -1
42+
updatePlotFreq: int = -1
43+
44+
45+
class DE(BaseProcedure, validate_assignment = True, extra = 'forbid'):
46+
"""
47+
Defines the class for the Differential Evolution procedure.
48+
"""
49+
procedure: Procedures = Field(Procedures.DE, frozen = True)
50+
populationSize: int = Field(20, ge = 1)
51+
fWeight: float = 0.5
52+
crossoverProbability: float = Field(0.8, gt = 0, lt = 1)
53+
strategy: StrategyOptions = StrategyOptions.RandomWithPerVectorDither
54+
targetValue: float = Field(1.0, ge = 1)
55+
numGenerations: int = Field(500, ge = 1)
56+
57+
58+
class NS(BaseProcedure, validate_assignment = True, extra = 'forbid'):
59+
"""
60+
Defines the class for the Nested Sampler procedure.
61+
"""
62+
procedure: Procedures = Field(Procedures.NS, frozen = True)
63+
Nlive: int = Field(150, ge = 1)
64+
Nmcmc: float = Field(0.0, ge = 0)
65+
propScale: float = Field(0.1, gt = 0, lt = 1)
66+
nsTolerance: float = Field(0.1, ge = 0)
67+
68+
69+
class Dream(BaseProcedure, validate_assignment = True, extra = 'forbid'):
70+
"""
71+
Defines the class for the Dream procedure
72+
"""
73+
procedure: Procedures = Field(Procedures.Dream, frozen = True)
74+
nSamples: int = Field(50000, ge = 0)
75+
nChains: int = Field(10, gt = 0)
76+
jumpProb: float = Field(0.5, gt = 0, lt = 1)
77+
pUnitGamma:float = Field(0.2, gt = 0, lt = 1)
78+
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold
79+
80+
81+
class ControlsClass:
82+
83+
def __init__(self,
84+
procedure: Procedures = Procedures.Calculate,
85+
**properties) -> None:
86+
87+
if procedure == Procedures.Calculate:
88+
self.controls = Calculate(**properties)
89+
elif procedure == Procedures.Simplex:
90+
self.controls = Simplex(**properties)
91+
elif procedure == Procedures.DE:
92+
self.controls = DE(**properties)
93+
elif procedure == Procedures.NS:
94+
self.controls = NS(**properties)
95+
elif procedure == Procedures.Dream:
96+
self.controls = Dream(**properties)
97+
98+
@property
99+
def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]:
100+
return self._controls
101+
102+
@controls.setter
103+
def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None:
104+
self._controls = value
105+
106+
def __repr__(self) -> str:
107+
properties = [["Property", "Value"]] +\
108+
[[k, v] for k, v in self._controls.__dict__.items()]
109+
table = tabulate.tabulate(properties, headers="firstrow")
110+
return table

RAT/utils/enums.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from enum import Enum
2+
try:
3+
from enum import StrEnum
4+
except ImportError:
5+
from strenum import StrEnum
6+
7+
8+
class ParallelOptions(StrEnum):
9+
"""Defines the avaliable options for parallelization"""
10+
Single = 'single'
11+
Points = 'points'
12+
Contrasts = 'contrasts'
13+
All = 'all'
14+
15+
16+
class Procedures(StrEnum):
17+
"""Defines the avaliable options for procedures"""
18+
Calculate = 'calculate'
19+
Simplex = 'simplex'
20+
DE = 'de'
21+
NS = 'ns'
22+
Dream = 'dream'
23+
24+
25+
class DisplayOptions(StrEnum):
26+
"""Defines the avaliable options for display"""
27+
Off = 'off'
28+
Iter = 'iter'
29+
Notify = 'notify'
30+
Final = 'final'
31+
32+
33+
class BoundHandlingOptions(StrEnum):
34+
"""Defines the avaliable options for bound handling"""
35+
Off = 'off'
36+
Reflect = 'reflect'
37+
Bound = 'bound'
38+
Fold = 'fold'
39+
40+
41+
class StrategyOptions(Enum):
42+
"""Defines the avaliable options for strategies"""
43+
Random = 1
44+
LocalToBest = 2
45+
BestWithJitter = 3
46+
RandomWithPerVectorDither = 4
47+
RandomWithPerGenerationDither = 5
48+
RandomEitherOrAlgorithm = 6

0 commit comments

Comments
 (0)