|
1 | | -from dataclasses import dataclass, field |
2 | | -from typing import Literal, Union |
| 1 | +import warnings |
3 | 2 |
|
4 | 3 | import prettytable |
5 | | -from pydantic import BaseModel, Field, ValidationError, field_validator |
| 4 | +from pydantic import ( |
| 5 | + BaseModel, |
| 6 | + Field, |
| 7 | + ValidationError, |
| 8 | + ValidatorFunctionWrapHandler, |
| 9 | + field_validator, |
| 10 | + model_serializer, |
| 11 | + model_validator, |
| 12 | +) |
6 | 13 |
|
7 | 14 | from RATpy.utils.custom_errors import custom_pydantic_validation_error |
8 | 15 | from RATpy.utils.enums import BoundHandling, Display, Parallel, Procedures, Strategies |
9 | 16 |
|
10 | | - |
11 | | -@dataclass(frozen=True) |
12 | | -class Controls: |
13 | | - """The full set of controls parameters required for the compiled RAT code.""" |
| 17 | +common_fields = ["procedure", "parallel", "calcSldDuringFit", "resampleParams", "display"] |
| 18 | +update_fields = ["updateFreq", "updatePlotFreq"] |
| 19 | +fields = { |
| 20 | + "calculate": common_fields, |
| 21 | + "simplex": [*common_fields, "xTolerance", "funcTolerance", "maxFuncEvals", "maxIterations", *update_fields], |
| 22 | + "de": [ |
| 23 | + *common_fields, |
| 24 | + "populationSize", |
| 25 | + "fWeight", |
| 26 | + "crossoverProbability", |
| 27 | + "strategy", |
| 28 | + "targetValue", |
| 29 | + "numGenerations", |
| 30 | + *update_fields, |
| 31 | + ], |
| 32 | + "ns": [*common_fields, "nLive", "nMCMC", "propScale", "nsTolerance"], |
| 33 | + "dream": [*common_fields, "nSamples", "nChains", "jumpProbability", "pUnitGamma", "boundHandling", "adaptPCR"], |
| 34 | +} |
| 35 | + |
| 36 | + |
| 37 | +class Controls(BaseModel, validate_assignment=True, extra="forbid"): |
| 38 | + """The full set of controls parameters for all five procedures that are required for the compiled RAT code.""" |
14 | 39 |
|
15 | 40 | # All Procedures |
16 | 41 | procedure: Procedures = Procedures.Calculate |
17 | 42 | parallel: Parallel = Parallel.Single |
18 | 43 | calcSldDuringFit: bool = False |
19 | | - resampleParams: list[float] = field(default_factory=list[0.9, 50.0]) |
20 | | - display: Display = Display.Iter |
21 | | - # Simplex |
22 | | - xTolerance: float = 1.0e-6 |
23 | | - funcTolerance: float = 1.0e-6 |
24 | | - maxFuncEvals: int = 10000 |
25 | | - maxIterations: int = 1000 |
26 | | - updateFreq: int = -1 |
27 | | - updatePlotFreq: int = 1 |
28 | | - # DE |
29 | | - populationSize: int = 20 |
30 | | - fWeight: float = 0.5 |
31 | | - crossoverProbability: float = 0.8 |
32 | | - strategy: Strategies = Strategies.RandomWithPerVectorDither.value |
33 | | - targetValue: float = 1.0 |
34 | | - numGenerations: int = 500 |
35 | | - # NS |
36 | | - nLive: int = 150 |
37 | | - nMCMC: float = 0.0 |
38 | | - propScale: float = 0.1 |
39 | | - nsTolerance: float = 0.1 |
40 | | - # Dream |
41 | | - nSamples: int = 20000 |
42 | | - nChains: int = 10 |
43 | | - jumpProbability: float = 0.5 |
44 | | - pUnitGamma: float = 0.2 |
45 | | - boundHandling: BoundHandling = BoundHandling.Reflect |
46 | | - adaptPCR: bool = True |
47 | | - |
48 | | - |
49 | | -class Calculate(BaseModel, validate_assignment=True, extra="forbid"): |
50 | | - """Defines the class for the calculate procedure, which includes the properties used in all five procedures.""" |
51 | | - |
52 | | - procedure: Literal[Procedures.Calculate] = Procedures.Calculate |
53 | | - parallel: Parallel = Parallel.Single |
54 | | - calcSldDuringFit: bool = False |
55 | 44 | resampleParams: list[float] = Field([0.9, 50], min_length=2, max_length=2) |
56 | 45 | display: Display = Display.Iter |
57 | | - |
58 | | - @field_validator("resampleParams") |
59 | | - @classmethod |
60 | | - def check_resample_params(cls, resampleParams): |
61 | | - if not 0 < resampleParams[0] < 1: |
62 | | - raise ValueError("resampleParams[0] must be between 0 and 1") |
63 | | - if resampleParams[1] < 0: |
64 | | - raise ValueError("resampleParams[1] must be greater than or equal to 0") |
65 | | - return resampleParams |
66 | | - |
67 | | - def __repr__(self) -> str: |
68 | | - table = prettytable.PrettyTable() |
69 | | - table.field_names = ["Property", "Value"] |
70 | | - table.add_rows([[k, v] for k, v in self.__dict__.items()]) |
71 | | - return table.get_string() |
72 | | - |
73 | | - |
74 | | -class Simplex(Calculate): |
75 | | - """Defines the additional fields for the simplex procedure.""" |
76 | | - |
77 | | - procedure: Literal[Procedures.Simplex] = Procedures.Simplex |
| 46 | + # Simplex |
78 | 47 | xTolerance: float = Field(1.0e-6, gt=0.0) |
79 | 48 | funcTolerance: float = Field(1.0e-6, gt=0.0) |
80 | 49 | maxFuncEvals: int = Field(10000, gt=0) |
81 | 50 | maxIterations: int = Field(1000, gt=0) |
82 | | - updateFreq: int = -1 |
83 | | - updatePlotFreq: int = 1 |
84 | | - |
85 | | - |
86 | | -class DE(Calculate): |
87 | | - """Defines the additional fields for the Differential Evolution procedure.""" |
88 | | - |
89 | | - procedure: Literal[Procedures.DE] = Procedures.DE |
| 51 | + # Simplex and DE |
| 52 | + updateFreq: int = 1 |
| 53 | + updatePlotFreq: int = 20 |
| 54 | + # DE |
90 | 55 | populationSize: int = Field(20, ge=1) |
91 | 56 | fWeight: float = 0.5 |
92 | 57 | crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0) |
93 | 58 | strategy: Strategies = Strategies.RandomWithPerVectorDither |
94 | 59 | targetValue: float = Field(1.0, ge=1.0) |
95 | 60 | numGenerations: int = Field(500, ge=1) |
96 | | - |
97 | | - |
98 | | -class NS(Calculate): |
99 | | - """Defines the additional fields for the Nested Sampler procedure.""" |
100 | | - |
101 | | - procedure: Literal[Procedures.NS] = Procedures.NS |
| 61 | + # NS |
102 | 62 | nLive: int = Field(150, ge=1) |
103 | 63 | nMCMC: float = Field(0.0, ge=0.0) |
104 | 64 | propScale: float = Field(0.1, gt=0.0, lt=1.0) |
105 | 65 | nsTolerance: float = Field(0.1, ge=0.0) |
106 | | - |
107 | | - |
108 | | -class Dream(Calculate): |
109 | | - """Defines the additional fields for the Dream procedure.""" |
110 | | - |
111 | | - procedure: Literal[Procedures.Dream] = Procedures.Dream |
| 66 | + # Dream |
112 | 67 | nSamples: int = Field(20000, ge=0) |
113 | 68 | nChains: int = Field(10, gt=0) |
114 | 69 | jumpProbability: float = Field(0.5, gt=0.0, lt=1.0) |
115 | 70 | pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0) |
116 | 71 | boundHandling: BoundHandling = BoundHandling.Reflect |
117 | 72 | adaptPCR: bool = True |
118 | 73 |
|
| 74 | + @model_validator(mode="wrap") |
| 75 | + def warn_setting_incorrect_properties(self, handler: ValidatorFunctionWrapHandler) -> "Controls": |
| 76 | + """Raise a warning if the user sets fields that apply to other procedures.""" |
| 77 | + model_input = self |
| 78 | + try: |
| 79 | + input_dict = model_input.__dict__ |
| 80 | + except AttributeError: |
| 81 | + input_dict = model_input |
| 82 | + |
| 83 | + validated_self = None |
| 84 | + try: |
| 85 | + validated_self = handler(self) |
| 86 | + except ValidationError as exc: |
| 87 | + procedure = input_dict.get("procedure", Procedures.Calculate) |
| 88 | + custom_error_msgs = { |
| 89 | + "extra_forbidden": f'Extra inputs are not permitted. The fields for the "{procedure}"' |
| 90 | + f' controls procedure are:\n ' |
| 91 | + f'{", ".join(fields.get("procedure", []))}\n', |
| 92 | + } |
| 93 | + custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs) |
| 94 | + raise ValidationError.from_exception_data(exc.title, custom_error_list) from None |
| 95 | + |
| 96 | + if isinstance(model_input, validated_self.__class__): |
| 97 | + # This is for changing fields in a defined model |
| 98 | + changed_fields = [key for key in input_dict if input_dict[key] != validated_self.__dict__[key]] |
| 99 | + elif isinstance(model_input, dict): |
| 100 | + # This is for a newly-defined model |
| 101 | + changed_fields = input_dict.keys() |
| 102 | + else: |
| 103 | + raise ValueError('The input to the "Controls" model is invalid.') |
| 104 | + |
| 105 | + new_procedure = validated_self.procedure |
| 106 | + allowed_fields = fields[new_procedure] |
| 107 | + for field in changed_fields: |
| 108 | + if field not in allowed_fields: |
| 109 | + incorrect_procedures = [key for (key, value) in fields.items() if field in value] |
| 110 | + warnings.warn( |
| 111 | + f'\nThe current controls procedure is "{new_procedure}", but the property' |
| 112 | + f' "{field}" applies instead to the {", ".join(incorrect_procedures)} procedure.\n\n' |
| 113 | + f' The fields for the "{new_procedure}" controls procedure are:\n' |
| 114 | + f' {", ".join(fields[new_procedure])}\n', |
| 115 | + stacklevel=2, |
| 116 | + ) |
| 117 | + |
| 118 | + return validated_self |
| 119 | + |
| 120 | + @field_validator("resampleParams") |
| 121 | + @classmethod |
| 122 | + def check_resample_params(cls, values: list[float]) -> list[float]: |
| 123 | + """Make sure each of the two values of resampleParams satisfy their conditions.""" |
| 124 | + if not 0 < values[0] < 1: |
| 125 | + raise ValueError("resampleParams[0] must be between 0 and 1") |
| 126 | + if values[1] < 0: |
| 127 | + raise ValueError("resampleParams[1] must be greater than or equal to 0") |
| 128 | + return values |
119 | 129 |
|
120 | | -def set_controls( |
121 | | - procedure: Procedures = Procedures.Calculate, |
122 | | - **properties, |
123 | | -) -> Union[Calculate, Simplex, DE, NS, Dream]: |
124 | | - """Returns the appropriate controls model given the specified procedure.""" |
125 | | - controls = { |
126 | | - Procedures.Calculate: Calculate, |
127 | | - Procedures.Simplex: Simplex, |
128 | | - Procedures.DE: DE, |
129 | | - Procedures.NS: NS, |
130 | | - Procedures.Dream: Dream, |
131 | | - } |
| 130 | + @model_serializer |
| 131 | + def serialize(self): |
| 132 | + """Filter fields so only those applying to the chosen procedure are serialized.""" |
| 133 | + return {model_field: getattr(self, model_field) for model_field in fields[self.procedure]} |
132 | 134 |
|
133 | | - try: |
134 | | - model = controls[procedure](**properties) |
135 | | - except KeyError: |
136 | | - members = list(Procedures.__members__.values()) |
137 | | - allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}' |
138 | | - raise ValueError(f"The controls procedure must be one of: {allowed_values}") from None |
139 | | - except ValidationError as exc: |
140 | | - custom_error_msgs = { |
141 | | - "extra_forbidden": f'Extra inputs are not permitted. The fields for the {procedure}' |
142 | | - f' controls procedure are:\n ' |
143 | | - f'{", ".join(controls[procedure].model_fields.keys())}\n', |
144 | | - } |
145 | | - custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs) |
146 | | - raise ValidationError.from_exception_data(exc.title, custom_error_list) from None |
| 135 | + def __repr__(self) -> str: |
| 136 | + fields_repr = ", ".join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.model_dump().items()) |
| 137 | + return f"{self.__repr_name__()}({fields_repr})" |
147 | 138 |
|
148 | | - return model |
| 139 | + def __str__(self) -> str: |
| 140 | + table = prettytable.PrettyTable() |
| 141 | + table.field_names = ["Property", "Value"] |
| 142 | + table.add_rows([[k, v] for k, v in self.model_dump().items()]) |
| 143 | + return table.get_string() |
0 commit comments