Skip to content

Commit 7f2ad53

Browse files
authored
Refactors the controls module, adding custom errors (#17)
* Changes "controlsClass" to factory function "set_controls". * Removes "baseControls" class and fixes bug when initialising procedure field * Adds error check to "set_controls" * Adds formatted error for extra fields to "set_controls" * Introduces logging for error reporting * Substitutes logging for raising errors directly * Adds routine "custom_pydantic_validation_error" to introduce custom error messages when raising a ValidationError
1 parent 101365b commit 7f2ad53

File tree

7 files changed

+380
-340
lines changed

7 files changed

+380
-340
lines changed

RAT/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from RAT.classlist import ClassList
2-
from RAT.controls import Controls
32
from RAT.project import Project
3+
import RAT.controls
4+
import RAT.models

RAT/controls.py

Lines changed: 49 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import prettytable
2-
from pydantic import BaseModel, Field, field_validator
3-
from typing import Union
2+
from pydantic import BaseModel, Field, field_validator, ValidationError
3+
from typing import Literal, Union
44

55
from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions
6+
from RAT.utils.custom_errors import custom_pydantic_validation_error
67

78

8-
class BaseProcedure(BaseModel, validate_assignment=True, extra='forbid'):
9-
"""Defines the base class with properties used in all five procedures."""
9+
class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
10+
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""
11+
procedure: Literal[Procedures.Calculate] = Procedures.Calculate
1012
parallel: ParallelOptions = ParallelOptions.Single
1113
calcSldDuringFit: bool = False
1214
resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2)
@@ -21,15 +23,16 @@ def check_resamPars(cls, resamPars):
2123
raise ValueError('resamPars[1] must be greater than or equal to 0')
2224
return resamPars
2325

24-
25-
class Calculate(BaseProcedure, validate_assignment=True, extra='forbid'):
26-
"""Defines the class for the calculate procedure."""
27-
procedure: Procedures = Field(Procedures.Calculate, frozen=True)
26+
def __repr__(self) -> str:
27+
table = prettytable.PrettyTable()
28+
table.field_names = ['Property', 'Value']
29+
table.add_rows([[k, v] for k, v in self.__dict__.items()])
30+
return table.get_string()
2831

2932

30-
class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
31-
"""Defines the class for the simplex procedure."""
32-
procedure: Procedures = Field(Procedures.Simplex, frozen=True)
33+
class Simplex(Calculate, validate_assignment=True, extra='forbid'):
34+
"""Defines the additional fields for the simplex procedure."""
35+
procedure: Literal[Procedures.Simplex] = Procedures.Simplex
3336
tolX: float = Field(1.0e-6, gt=0.0)
3437
tolFun: float = Field(1.0e-6, gt=0.0)
3538
maxFunEvals: int = Field(10000, gt=0)
@@ -38,9 +41,9 @@ class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
3841
updatePlotFreq: int = -1
3942

4043

41-
class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
42-
"""Defines the class for the Differential Evolution procedure."""
43-
procedure: Procedures = Field(Procedures.DE, frozen=True)
44+
class DE(Calculate, validate_assignment=True, extra='forbid'):
45+
"""Defines the additional fields for the Differential Evolution procedure."""
46+
procedure: Literal[Procedures.DE] = Procedures.DE
4447
populationSize: int = Field(20, ge=1)
4548
fWeight: float = 0.5
4649
crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0)
@@ -49,52 +52,48 @@ class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
4952
numGenerations: int = Field(500, ge=1)
5053

5154

52-
class NS(BaseProcedure, validate_assignment=True, extra='forbid'):
53-
"""Defines the class for the Nested Sampler procedure."""
54-
procedure: Procedures = Field(Procedures.NS, frozen=True)
55+
class NS(Calculate, validate_assignment=True, extra='forbid'):
56+
"""Defines the additional fields for the Nested Sampler procedure."""
57+
procedure: Literal[Procedures.NS] = Procedures.NS
5558
Nlive: int = Field(150, ge=1)
5659
Nmcmc: float = Field(0.0, ge=0.0)
5760
propScale: float = Field(0.1, gt=0.0, lt=1.0)
5861
nsTolerance: float = Field(0.1, ge=0.0)
5962

6063

61-
class Dream(BaseProcedure, validate_assignment=True, extra='forbid'):
62-
"""Defines the class for the Dream procedure."""
63-
procedure: Procedures = Field(Procedures.Dream, frozen=True)
64+
class Dream(Calculate, validate_assignment=True, extra='forbid'):
65+
"""Defines the additional fields for the Dream procedure."""
66+
procedure: Literal[Procedures.Dream] = Procedures.Dream
6467
nSamples: int = Field(50000, ge=0)
6568
nChains: int = Field(10, gt=0)
6669
jumpProb: float = Field(0.5, gt=0.0, lt=1.0)
6770
pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0)
6871
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold
6972

7073

71-
class Controls:
72-
73-
def __init__(self,
74-
procedure: Procedures = Procedures.Calculate,
75-
**properties) -> None:
76-
77-
if procedure == Procedures.Calculate:
78-
self.controls = Calculate(**properties)
79-
elif procedure == Procedures.Simplex:
80-
self.controls = Simplex(**properties)
81-
elif procedure == Procedures.DE:
82-
self.controls = DE(**properties)
83-
elif procedure == Procedures.NS:
84-
self.controls = NS(**properties)
85-
elif procedure == Procedures.Dream:
86-
self.controls = Dream(**properties)
87-
88-
@property
89-
def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]:
90-
return self._controls
91-
92-
@controls.setter
93-
def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None:
94-
self._controls = value
95-
96-
def __repr__(self) -> str:
97-
table = prettytable.PrettyTable()
98-
table.field_names = ['Property', 'Value']
99-
table.add_rows([[k, v] for k, v in self._controls.__dict__.items()])
100-
return table.get_string()
74+
def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\
75+
-> Union[Calculate, Simplex, DE, NS, Dream]:
76+
"""Returns the appropriate controls model given the specified procedure."""
77+
controls = {
78+
Procedures.Calculate: Calculate,
79+
Procedures.Simplex: Simplex,
80+
Procedures.DE: DE,
81+
Procedures.NS: NS,
82+
Procedures.Dream: Dream
83+
}
84+
85+
try:
86+
model = controls[procedure](**properties)
87+
except KeyError:
88+
members = list(Procedures.__members__.values())
89+
allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}'
90+
raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None
91+
except ValidationError as exc:
92+
custom_error_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure}'
93+
f' controls procedure are:\n '
94+
f'{", ".join(controls[procedure].model_fields.keys())}\n'
95+
}
96+
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
97+
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None
98+
99+
return model

RAT/project.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from RAT.classlist import ClassList
1212
import RAT.models
13-
from RAT.utils.custom_errors import formatted_pydantic_error
13+
from RAT.utils.custom_errors import custom_pydantic_validation_error
1414

1515
try:
1616
from enum import StrEnum
@@ -524,11 +524,10 @@ def wrapped_func(*args, **kwargs):
524524
try:
525525
return_value = func(*args, **kwargs)
526526
Project.model_validate(self)
527-
except ValidationError as e:
527+
except ValidationError as exc:
528528
setattr(class_list, 'data', previous_state)
529-
error_string = formatted_pydantic_error(e)
530-
# Use ANSI escape sequences to print error text in red
531-
print('\033[31m' + error_string + '\033[0m')
529+
custom_error_list = custom_pydantic_validation_error(exc.errors())
530+
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None
532531
except (TypeError, ValueError):
533532
setattr(class_list, 'data', previous_state)
534533
raise

RAT/utils/custom_errors.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,36 @@
11
"""Defines routines for custom error handling in RAT."""
2+
import pydantic_core
23

3-
from pydantic import ValidationError
44

5+
def custom_pydantic_validation_error(error_list: list[pydantic_core.ErrorDetails], custom_errors: dict[str, str] = None
6+
) -> list[pydantic_core.ErrorDetails]:
7+
"""Run through the list of errors generated from a pydantic ValidationError, substituting the standard error for a
8+
PydanticCustomError for a given set of error types.
59
6-
def formatted_pydantic_error(error: ValidationError) -> str:
7-
"""Write a custom string format for pydantic validation errors.
10+
For errors that do not have a custom error message defined, we redefine them using a PydanticCustomError to remove
11+
the url from the error message.
812
913
Parameters
1014
----------
11-
error : pydantic.ValidationError
12-
A ValidationError produced by a pydantic model
15+
error_list : list[pydantic_core.ErrorDetails]
16+
A list of errors produced by pydantic.ValidationError.errors().
17+
custom_errors: dict[str, str], optional
18+
A dict of custom error messages for given error types.
1319
1420
Returns
1521
-------
16-
error_str : str
17-
A string giving details of the ValidationError in a custom format.
22+
new_error : list[pydantic_core.ErrorDetails]
23+
A list of errors including PydanticCustomErrors in place of the error types in custom_errors.
1824
"""
19-
num_errors = error.error_count()
20-
error_str = f'{num_errors} validation error{"s"[:num_errors!=1]} for {error.title}'
21-
for this_error in error.errors():
22-
error_str += '\n'
23-
if this_error['loc']:
24-
error_str += ' '.join(this_error['loc']) + '\n'
25-
error_str += ' ' + this_error['msg']
26-
return error_str
25+
if custom_errors is None:
26+
custom_errors = {}
27+
custom_error_list = []
28+
for error in error_list:
29+
if error['type'] in custom_errors:
30+
RAT_custom_error = pydantic_core.PydanticCustomError(error['type'], custom_errors[error['type']])
31+
else:
32+
RAT_custom_error = pydantic_core.PydanticCustomError(error['type'], error['msg'])
33+
error['type'] = RAT_custom_error
34+
custom_error_list.append(error)
35+
36+
return custom_error_list

0 commit comments

Comments
 (0)