|
| 1 | +import tabulate |
| 2 | +from typing import Union |
| 3 | +from pydantic import BaseModel, Field, field_validator |
| 4 | +from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions |
| 5 | + |
| 6 | + |
| 7 | +class BaseProcedure(BaseModel, validate_assignment = True, extra = 'forbid'): |
| 8 | + """ |
| 9 | + Defines the base class with properties used in all five procedures. |
| 10 | + """ |
| 11 | + parallel: ParallelOptions = ParallelOptions.Single |
| 12 | + calcSldDuringFit: bool = False |
| 13 | + resamPars: list[float] = Field([0.9, 50], min_length = 2, max_length = 2) |
| 14 | + display: DisplayOptions = DisplayOptions.Iter |
| 15 | + |
| 16 | + @field_validator("resamPars") |
| 17 | + def check_resamPars(cls, resamPars): |
| 18 | + if not 0 < resamPars[0] < 1: |
| 19 | + raise ValueError('resamPars[0] must be between 0 and 1') |
| 20 | + if resamPars[1] < 0: |
| 21 | + raise ValueError('resamPars[1] must be greater than or equal to 0') |
| 22 | + return resamPars |
| 23 | + |
| 24 | + |
| 25 | +class Calculate(BaseProcedure, validate_assignment = True, extra = 'forbid'): |
| 26 | + """ |
| 27 | + Defines the class for the calculate procedure. |
| 28 | + """ |
| 29 | + procedure: Procedures = Field(Procedures.Calculate, frozen = True) |
| 30 | + |
| 31 | + |
| 32 | +class Simplex(BaseProcedure, validate_assignment = True, extra = 'forbid'): |
| 33 | + """ |
| 34 | + Defines the class for the simplex procedure. |
| 35 | + """ |
| 36 | + procedure: Procedures = Field(Procedures.Simplex, frozen = True) |
| 37 | + tolX: float = Field(1e-6, gt = 0) |
| 38 | + tolFun: float = Field(1e-6, gt = 0) |
| 39 | + maxFunEvals: int = Field(10000, gt = 0) |
| 40 | + maxIter: int = Field(1000, gt = 0) |
| 41 | + updateFreq: int = -1 |
| 42 | + updatePlotFreq: int = -1 |
| 43 | + |
| 44 | + |
| 45 | +class DE(BaseProcedure, validate_assignment = True, extra = 'forbid'): |
| 46 | + """ |
| 47 | + Defines the class for the Differential Evolution procedure. |
| 48 | + """ |
| 49 | + procedure: Procedures = Field(Procedures.DE, frozen = True) |
| 50 | + populationSize: int = Field(20, ge = 1) |
| 51 | + fWeight: float = 0.5 |
| 52 | + crossoverProbability: float = Field(0.8, gt = 0, lt = 1) |
| 53 | + strategy: StrategyOptions = StrategyOptions.RandomWithPerVectorDither |
| 54 | + targetValue: float = Field(1.0, ge = 1) |
| 55 | + numGenerations: int = Field(500, ge = 1) |
| 56 | + |
| 57 | + |
| 58 | +class NS(BaseProcedure, validate_assignment = True, extra = 'forbid'): |
| 59 | + """ |
| 60 | + Defines the class for the Nested Sampler procedure. |
| 61 | + """ |
| 62 | + procedure: Procedures = Field(Procedures.NS, frozen = True) |
| 63 | + Nlive: int = Field(150, ge = 1) |
| 64 | + Nmcmc: float = Field(0.0, ge = 0) |
| 65 | + propScale: float = Field(0.1, gt = 0, lt = 1) |
| 66 | + nsTolerance: float = Field(0.1, ge = 0) |
| 67 | + |
| 68 | + |
| 69 | +class Dream(BaseProcedure, validate_assignment = True, extra = 'forbid'): |
| 70 | + """ |
| 71 | + Defines the class for the Dream procedure |
| 72 | + """ |
| 73 | + procedure: Procedures = Field(Procedures.Dream, frozen = True) |
| 74 | + nSamples: int = Field(50000, ge = 0) |
| 75 | + nChains: int = Field(10, gt = 0) |
| 76 | + jumpProb: float = Field(0.5, gt = 0, lt = 1) |
| 77 | + pUnitGamma:float = Field(0.2, gt = 0, lt = 1) |
| 78 | + boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold |
| 79 | + |
| 80 | + |
| 81 | +class ControlsClass: |
| 82 | + |
| 83 | + def __init__(self, |
| 84 | + procedure: Procedures = Procedures.Calculate, |
| 85 | + **properties) -> None: |
| 86 | + |
| 87 | + if procedure == Procedures.Calculate: |
| 88 | + self.controls = Calculate(**properties) |
| 89 | + elif procedure == Procedures.Simplex: |
| 90 | + self.controls = Simplex(**properties) |
| 91 | + elif procedure == Procedures.DE: |
| 92 | + self.controls = DE(**properties) |
| 93 | + elif procedure == Procedures.NS: |
| 94 | + self.controls = NS(**properties) |
| 95 | + elif procedure == Procedures.Dream: |
| 96 | + self.controls = Dream(**properties) |
| 97 | + |
| 98 | + @property |
| 99 | + def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]: |
| 100 | + return self._controls |
| 101 | + |
| 102 | + @controls.setter |
| 103 | + def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None: |
| 104 | + self._controls = value |
| 105 | + |
| 106 | + def __repr__(self) -> str: |
| 107 | + properties = [["Property", "Value"]] +\ |
| 108 | + [[k, v] for k, v in self._controls.__dict__.items()] |
| 109 | + table = tabulate.tabulate(properties, headers="firstrow") |
| 110 | + return table |
0 commit comments