11import prettytable
2- from pydantic import BaseModel , Field , field_validator
3- from typing import Union
2+ from pydantic import BaseModel , Field , field_validator , ValidationError
3+ from typing import Literal , Union
44
55from RAT .utils .enums import ParallelOptions , Procedures , DisplayOptions , BoundHandlingOptions , StrategyOptions
6+ from RAT .utils .custom_errors import custom_pydantic_validation_error
67
78
8- class BaseProcedure (BaseModel , validate_assignment = True , extra = 'forbid' ):
9- """Defines the base class with properties used in all five procedures."""
9+ class Calculate (BaseModel , validate_assignment = True , extra = 'forbid' ):
10+ """Defines the class for the calculate procedure, which includes the properties used in all five procedures."""
11+ procedure : Literal [Procedures .Calculate ] = Procedures .Calculate
1012 parallel : ParallelOptions = ParallelOptions .Single
1113 calcSldDuringFit : bool = False
1214 resamPars : list [float ] = Field ([0.9 , 50 ], min_length = 2 , max_length = 2 )
@@ -21,15 +23,16 @@ def check_resamPars(cls, resamPars):
2123 raise ValueError ('resamPars[1] must be greater than or equal to 0' )
2224 return resamPars
2325
24-
25- class Calculate (BaseProcedure , validate_assignment = True , extra = 'forbid' ):
26- """Defines the class for the calculate procedure."""
27- procedure : Procedures = Field (Procedures .Calculate , frozen = True )
26+ def __repr__ (self ) -> str :
27+ table = prettytable .PrettyTable ()
28+ table .field_names = ['Property' , 'Value' ]
29+ table .add_rows ([[k , v ] for k , v in self .__dict__ .items ()])
30+ return table .get_string ()
2831
2932
30- class Simplex (BaseProcedure , validate_assignment = True , extra = 'forbid' ):
31- """Defines the class for the simplex procedure."""
32- procedure : Procedures = Field ( Procedures .Simplex , frozen = True )
33+ class Simplex (Calculate , validate_assignment = True , extra = 'forbid' ):
34+ """Defines the additional fields for the simplex procedure."""
35+ procedure : Literal [ Procedures . Simplex ] = Procedures .Simplex
3336 tolX : float = Field (1.0e-6 , gt = 0.0 )
3437 tolFun : float = Field (1.0e-6 , gt = 0.0 )
3538 maxFunEvals : int = Field (10000 , gt = 0 )
@@ -38,9 +41,9 @@ class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
3841 updatePlotFreq : int = - 1
3942
4043
41- class DE (BaseProcedure , validate_assignment = True , extra = 'forbid' ):
42- """Defines the class for the Differential Evolution procedure."""
43- procedure : Procedures = Field ( Procedures .DE , frozen = True )
44+ class DE (Calculate , validate_assignment = True , extra = 'forbid' ):
45+ """Defines the additional fields for the Differential Evolution procedure."""
46+ procedure : Literal [ Procedures . DE ] = Procedures .DE
4447 populationSize : int = Field (20 , ge = 1 )
4548 fWeight : float = 0.5
4649 crossoverProbability : float = Field (0.8 , gt = 0.0 , lt = 1.0 )
@@ -49,52 +52,48 @@ class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
4952 numGenerations : int = Field (500 , ge = 1 )
5053
5154
52- class NS (BaseProcedure , validate_assignment = True , extra = 'forbid' ):
53- """Defines the class for the Nested Sampler procedure."""
54- procedure : Procedures = Field ( Procedures .NS , frozen = True )
55+ class NS (Calculate , validate_assignment = True , extra = 'forbid' ):
56+ """Defines the additional fields for the Nested Sampler procedure."""
57+ procedure : Literal [ Procedures . NS ] = Procedures .NS
5558 Nlive : int = Field (150 , ge = 1 )
5659 Nmcmc : float = Field (0.0 , ge = 0.0 )
5760 propScale : float = Field (0.1 , gt = 0.0 , lt = 1.0 )
5861 nsTolerance : float = Field (0.1 , ge = 0.0 )
5962
6063
61- class Dream (BaseProcedure , validate_assignment = True , extra = 'forbid' ):
62- """Defines the class for the Dream procedure."""
63- procedure : Procedures = Field ( Procedures .Dream , frozen = True )
64+ class Dream (Calculate , validate_assignment = True , extra = 'forbid' ):
65+ """Defines the additional fields for the Dream procedure."""
66+ procedure : Literal [ Procedures . Dream ] = Procedures .Dream
6467 nSamples : int = Field (50000 , ge = 0 )
6568 nChains : int = Field (10 , gt = 0 )
6669 jumpProb : float = Field (0.5 , gt = 0.0 , lt = 1.0 )
6770 pUnitGamma : float = Field (0.2 , gt = 0.0 , lt = 1.0 )
6871 boundHandling : BoundHandlingOptions = BoundHandlingOptions .Fold
6972
7073
71- class Controls :
72-
73- def __init__ (self ,
74- procedure : Procedures = Procedures .Calculate ,
75- ** properties ) -> None :
76-
77- if procedure == Procedures .Calculate :
78- self .controls = Calculate (** properties )
79- elif procedure == Procedures .Simplex :
80- self .controls = Simplex (** properties )
81- elif procedure == Procedures .DE :
82- self .controls = DE (** properties )
83- elif procedure == Procedures .NS :
84- self .controls = NS (** properties )
85- elif procedure == Procedures .Dream :
86- self .controls = Dream (** properties )
87-
88- @property
89- def controls (self ) -> Union [Calculate , Simplex , DE , NS , Dream ]:
90- return self ._controls
91-
92- @controls .setter
93- def controls (self , value : Union [Calculate , Simplex , DE , NS , Dream ]) -> None :
94- self ._controls = value
95-
96- def __repr__ (self ) -> str :
97- table = prettytable .PrettyTable ()
98- table .field_names = ['Property' , 'Value' ]
99- table .add_rows ([[k , v ] for k , v in self ._controls .__dict__ .items ()])
100- return table .get_string ()
74+ def set_controls (procedure : Procedures = Procedures .Calculate , ** properties )\
75+ -> Union [Calculate , Simplex , DE , NS , Dream ]:
76+ """Returns the appropriate controls model given the specified procedure."""
77+ controls = {
78+ Procedures .Calculate : Calculate ,
79+ Procedures .Simplex : Simplex ,
80+ Procedures .DE : DE ,
81+ Procedures .NS : NS ,
82+ Procedures .Dream : Dream
83+ }
84+
85+ try :
86+ model = controls [procedure ](** properties )
87+ except KeyError :
88+ members = list (Procedures .__members__ .values ())
89+ allowed_values = f'{ ", " .join ([repr (member .value ) for member in members [:- 1 ]])} or { members [- 1 ].value !r} '
90+ raise ValueError (f'The controls procedure must be one of: { allowed_values } ' ) from None
91+ except ValidationError as exc :
92+ custom_error_msgs = {'extra_forbidden' : f'Extra inputs are not permitted. The fields for the { procedure } '
93+ f' controls procedure are:\n '
94+ f'{ ", " .join (controls [procedure ].model_fields .keys ())} \n '
95+ }
96+ custom_error_list = custom_pydantic_validation_error (exc .errors (), custom_error_msgs )
97+ raise ValidationError .from_exception_data (exc .title , custom_error_list ) from None
98+
99+ return model
0 commit comments