22
33import numpy as np
44from pydantic import BaseModel , Field , ValidationInfo , field_validator , model_validator
5+ from typing import Any
56
67try :
78 from enum import StrEnum
@@ -43,7 +44,6 @@ class Languages(StrEnum):
4344class Priors (StrEnum ):
4445 Uniform = 'uniform'
4546 Gaussian = 'gaussian'
46- Jeffreys = 'jeffreys'
4747
4848
4949class Types (StrEnum ):
@@ -52,7 +52,19 @@ class Types(StrEnum):
5252 Function = 'function'
5353
5454
55- class Background (BaseModel , validate_assignment = True , extra = 'forbid' ):
55+ class RATModel (BaseModel ):
56+ """A BaseModel where enums are represented by their value."""
57+ def __repr__ (self ):
58+ fields_repr = (', ' .join (repr (v ) if a is None else
59+ f'{ a } ={ v .value !r} ' if isinstance (v , StrEnum ) else
60+ f'{ a } ={ v !r} '
61+ for a , v in self .__repr_args__ ()
62+ )
63+ )
64+ return f'{ self .__repr_name__ ()} ({ fields_repr } )'
65+
66+
67+ class Background (RATModel , validate_assignment = True , extra = 'forbid' ):
5668 """Defines the Backgrounds in RAT."""
5769 name : str = Field (default_factory = lambda : 'New Background ' + next (background_number ), min_length = 1 )
5870 type : Types = Types .Constant
@@ -63,7 +75,7 @@ class Background(BaseModel, validate_assignment=True, extra='forbid'):
6375 value_5 : str = ''
6476
6577
66- class Contrast (BaseModel , validate_assignment = True , extra = 'forbid' ):
78+ class Contrast (RATModel , validate_assignment = True , extra = 'forbid' ):
6779 """Groups together all of the components of the model."""
6880 name : str = Field (default_factory = lambda : 'New Contrast ' + next (contrast_number ), min_length = 1 )
6981 data : str = ''
@@ -76,7 +88,7 @@ class Contrast(BaseModel, validate_assignment=True, extra='forbid'):
7688 model : list [str ] = []
7789
7890
79- class ContrastWithRatio (BaseModel , validate_assignment = True , extra = 'forbid' ):
91+ class ContrastWithRatio (RATModel , validate_assignment = True , extra = 'forbid' ):
8092 """Groups together all of the components of the model including domain terms."""
8193 name : str = Field (default_factory = lambda : 'New Contrast ' + next (contrast_number ), min_length = 1 )
8294 data : str = ''
@@ -90,20 +102,20 @@ class ContrastWithRatio(BaseModel, validate_assignment=True, extra='forbid'):
90102 model : list [str ] = []
91103
92104
93- class CustomFile (BaseModel , validate_assignment = True , extra = 'forbid' ):
105+ class CustomFile (RATModel , validate_assignment = True , extra = 'forbid' ):
94106 """Defines the files containing functions to run when using custom models."""
95107 name : str = Field (default_factory = lambda : 'New Custom File ' + next (custom_file_number ), min_length = 1 )
96108 filename : str = ''
97109 language : Languages = Languages .Python
98110 path : str = 'pwd' # Should later expand to find current file path
99111
100112
101- class Data (BaseModel , validate_assignment = True , extra = 'forbid' , arbitrary_types_allowed = True ):
113+ class Data (RATModel , validate_assignment = True , extra = 'forbid' , arbitrary_types_allowed = True ):
102114 """Defines the dataset required for each contrast."""
103115 name : str = Field (default_factory = lambda : 'New Data ' + next (data_number ), min_length = 1 )
104- data : np .ndarray [float ] = np .empty ([0 , 3 ])
105- data_range : list [float ] = []
106- simulation_range : list [float ] = [ 0.005 , 0.7 ]
116+ data : np .ndarray [np . float64 ] = np .empty ([0 , 3 ])
117+ data_range : list [float ] = Field ( default = [], min_length = 2 , max_length = 2 )
118+ simulation_range : list [float ] = Field ( default = [], min_length = 2 , max_length = 2 )
107119
108120 @field_validator ('data' )
109121 @classmethod
@@ -120,22 +132,79 @@ def check_data_dimension(cls, data: np.ndarray[float]) -> np.ndarray[float]:
120132
121133 @field_validator ('data_range' , 'simulation_range' )
122134 @classmethod
123- def check_list_elements (cls , limits : list [float ], info : ValidationInfo ) -> list [float ]:
124- """The data range and simulation range must contain exactly two parameters ."""
125- if len ( limits ) != 2 :
126- raise ValueError (f'{ info .field_name } must contain exactly two values ' )
135+ def check_min_max (cls , limits : list [float ], info : ValidationInfo ) -> list [float ]:
136+ """The data range and simulation range maximum must be greater than the minimum ."""
137+ if limits [ 0 ] > limits [ 1 ] :
138+ raise ValueError (f'{ info .field_name } "min" value is greater than the "max" value ' )
127139 return limits
128140
129- # Also need model validators for data range compared to data etc -- need more details.
141+ def model_post_init (self , __context : Any ) -> None :
142+ """If the "data_range" and "simulation_range" fields are not set, but "data" is supplied, the ranges should be
143+ set to the min and max values of the first column (assumed to be q) of the supplied data.
144+ """
145+ if len (self .data [:, 0 ]) > 0 :
146+ data_min = np .min (self .data [:, 0 ])
147+ data_max = np .max (self .data [:, 0 ])
148+ for field in ["data_range" , "simulation_range" ]:
149+ if field not in self .model_fields_set :
150+ getattr (self , field ).extend ([data_min , data_max ])
151+
152+ @model_validator (mode = 'after' )
153+ def check_ranges (self ) -> 'Data' :
154+ """The limits of the "data_range" field must lie within the range of the supplied data, whilst the limits
155+ of the "simulation_range" field must lie outside of the range of the supplied data.
156+ """
157+ if len (self .data [:, 0 ]) > 0 :
158+ data_min = np .min (self .data [:, 0 ])
159+ data_max = np .max (self .data [:, 0 ])
160+ if "data_range" in self .model_fields_set and (self .data_range [0 ] < data_min or
161+ self .data_range [1 ] > data_max ):
162+ raise ValueError (f'The data_range value of: { self .data_range } must lie within the min/max values of '
163+ f'the data: [{ data_min } , { data_max } ]' )
164+ if "simulation_range" in self .model_fields_set and (self .simulation_range [0 ] > data_min or
165+ self .simulation_range [1 ] < data_max ):
166+ raise ValueError (f'The simulation_range value of: { self .simulation_range } must lie outside of the '
167+ f'min/max values of the data: [{ data_min } , { data_max } ]' )
168+ return self
169+
170+ def __eq__ (self , other : Any ) -> bool :
171+ if isinstance (other , BaseModel ):
172+ # When comparing instances of generic types for equality, as long as all field values are equal,
173+ # only require their generic origin types to be equal, rather than exact type equality.
174+ # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
175+ self_type = self .__pydantic_generic_metadata__ ['origin' ] or self .__class__
176+ other_type = other .__pydantic_generic_metadata__ ['origin' ] or other .__class__
177+
178+ return (
179+ self_type == other_type
180+ and self .name == other .name
181+ and (self .data == other .data ).all ()
182+ and self .data_range == other .data_range
183+ and self .simulation_range == other .simulation_range
184+ and self .__pydantic_private__ == other .__pydantic_private__
185+ and self .__pydantic_extra__ == other .__pydantic_extra__
186+ )
187+ else :
188+ return NotImplemented # delegate to the other item in the comparison
189+
190+ def __repr__ (self ):
191+ """Only include the name if the data is empty."""
192+ fields_repr = (f"name={ self .name !r} " if self .data .size == 0 else
193+ ", " .join (repr (v ) if a is None else
194+ f"{ a } ={ v !r} "
195+ for a , v in self .__repr_args__ ()
196+ )
197+ )
198+ return f'{ self .__repr_name__ ()} ({ fields_repr } )'
130199
131200
132- class DomainContrast (BaseModel , validate_assignment = True , extra = 'forbid' ):
201+ class DomainContrast (RATModel , validate_assignment = True , extra = 'forbid' ):
133202 """Groups together the layers required for each domain."""
134203 name : str = Field (default_factory = lambda : 'New Domain Contrast ' + next (domain_contrast_number ), min_length = 1 )
135204 model : list [str ] = []
136205
137206
138- class Layer (BaseModel , validate_assignment = True , extra = 'forbid' , populate_by_name = True ):
207+ class Layer (RATModel , validate_assignment = True , extra = 'forbid' , populate_by_name = True ):
139208 """Combines parameters into defined layers."""
140209 name : str = Field (default_factory = lambda : 'New Layer ' + next (layer_number ), min_length = 1 )
141210 thickness : str = ''
@@ -145,7 +214,7 @@ class Layer(BaseModel, validate_assignment=True, extra='forbid', populate_by_nam
145214 hydrate_with : Hydration = Hydration .BulkOut
146215
147216
148- class AbsorptionLayer (BaseModel , validate_assignment = True , extra = 'forbid' , populate_by_name = True ):
217+ class AbsorptionLayer (RATModel , validate_assignment = True , extra = 'forbid' , populate_by_name = True ):
149218 """Combines parameters into defined layers including absorption terms."""
150219 name : str = Field (default_factory = lambda : 'New Layer ' + next (layer_number ), min_length = 1 )
151220 thickness : str = ''
@@ -156,7 +225,7 @@ class AbsorptionLayer(BaseModel, validate_assignment=True, extra='forbid', popul
156225 hydrate_with : Hydration = Hydration .BulkOut
157226
158227
159- class Parameter (BaseModel , validate_assignment = True , extra = 'forbid' ):
228+ class Parameter (RATModel , validate_assignment = True , extra = 'forbid' ):
160229 """Defines parameters needed to specify the model."""
161230 name : str = Field (default_factory = lambda : 'New Parameter ' + next (parameter_number ), min_length = 1 )
162231 min : float = 0.0
@@ -180,7 +249,7 @@ class ProtectedParameter(Parameter, validate_assignment=True, extra='forbid'):
180249 name : str = Field (frozen = True , min_length = 1 )
181250
182251
183- class Resolution (BaseModel , validate_assignment = True , extra = 'forbid' ):
252+ class Resolution (RATModel , validate_assignment = True , extra = 'forbid' ):
184253 """Defines Resolutions in RAT."""
185254 name : str = Field (default_factory = lambda : 'New Resolution ' + next (resolution_number ), min_length = 1 )
186255 type : Types = Types .Constant
0 commit comments