Skip to content

Commit c580435

Browse files
authored
Add code to handle calculation options (#14)
* Adds new model "AbsorptionLayer" and model validator "set_absorption" to use the correct layer. * Changes "AbsorptionLayer" model to use same default counter as "Layer" model * Ensures models cannot have zero-character length names. * Refines "update_renamed_models" in "project.py" to enable focus on specific fields * Add "domain_contrasts", "domain_ratios" to "project.py and "ContrastWithRatio" to "models.py" for domains calculations * Adds tests for domain calculation models * Adds model validators to handle the "model" field of "contrasts" to "project.py" * Tidies up code * Addresses review comments
1 parent d7b1f2c commit c580435

File tree

4 files changed

+594
-96
lines changed

4 files changed

+594
-96
lines changed

RAT/models.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""The models module. Contains the pydantic models used by RAT to store project parameters."""
22

33
import numpy as np
4-
from pydantic import BaseModel, Field, FieldValidationInfo, field_validator, model_validator
4+
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
55

66
try:
77
from enum import StrEnum
@@ -54,7 +54,7 @@ class Types(StrEnum):
5454

5555
class Background(BaseModel, validate_assignment=True, extra='forbid'):
5656
"""Defines the Backgrounds in RAT."""
57-
name: str = Field(default_factory=lambda: 'New Background ' + next(background_number))
57+
name: str = Field(default_factory=lambda: 'New Background ' + next(background_number), min_length=1)
5858
type: Types = Types.Constant
5959
value_1: str = ''
6060
value_2: str = ''
@@ -65,28 +65,42 @@ class Background(BaseModel, validate_assignment=True, extra='forbid'):
6565

6666
class Contrast(BaseModel, validate_assignment=True, extra='forbid'):
6767
"""Groups together all of the components of the model."""
68-
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number))
68+
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
6969
data: str = ''
7070
background: str = ''
7171
nba: str = ''
7272
nbs: str = ''
7373
scalefactor: str = ''
7474
resolution: str = ''
7575
resample: bool = False
76-
model: list[str] = [] # But how many strings? How to deal with this?
76+
model: list[str] = []
77+
78+
79+
class ContrastWithRatio(BaseModel, validate_assignment=True, extra='forbid'):
80+
"""Groups together all of the components of the model including domain terms."""
81+
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
82+
data: str = ''
83+
background: str = ''
84+
nba: str = ''
85+
nbs: str = ''
86+
scalefactor: str = ''
87+
resolution: str = ''
88+
resample: bool = False
89+
domain_ratio: str = ''
90+
model: list[str] = []
7791

7892

7993
class CustomFile(BaseModel, validate_assignment=True, extra='forbid'):
8094
"""Defines the files containing functions to run when using custom models."""
81-
name: str = Field(default_factory=lambda: 'New Custom File ' + next(custom_file_number))
95+
name: str = Field(default_factory=lambda: 'New Custom File ' + next(custom_file_number), min_length=1)
8296
filename: str = ''
8397
language: Languages = Languages.Python
8498
path: str = 'pwd' # Should later expand to find current file path
8599

86100

87101
class Data(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
88102
"""Defines the dataset required for each contrast."""
89-
name: str = Field(default_factory=lambda: 'New Data ' + next(data_number))
103+
name: str = Field(default_factory=lambda: 'New Data ' + next(data_number), min_length=1)
90104
data: np.ndarray[float] = np.empty([0, 3])
91105
data_range: list[float] = []
92106
simulation_range: list[float] = [0.005, 0.7]
@@ -106,7 +120,7 @@ def check_data_dimension(cls, data: np.ndarray[float]) -> np.ndarray[float]:
106120

107121
@field_validator('data_range', 'simulation_range')
108122
@classmethod
109-
def check_list_elements(cls, limits: list[float], info: FieldValidationInfo) -> list[float]:
123+
def check_list_elements(cls, limits: list[float], info: ValidationInfo) -> list[float]:
110124
"""The data range and simulation range must contain exactly two parameters."""
111125
if len(limits) != 2:
112126
raise ValueError(f'{info.field_name} must contain exactly two values')
@@ -117,23 +131,34 @@ def check_list_elements(cls, limits: list[float], info: FieldValidationInfo) ->
117131

118132
class DomainContrast(BaseModel, validate_assignment=True, extra='forbid'):
119133
"""Groups together the layers required for each domain."""
120-
name: str = Field(default_factory=lambda: 'New Domain Contrast ' + next(domain_contrast_number))
134+
name: str = Field(default_factory=lambda: 'New Domain Contrast ' + next(domain_contrast_number), min_length=1)
121135
model: list[str] = []
122136

123137

124-
class Layer(BaseModel, validate_assignment=True, extra='forbid'):
138+
class Layer(BaseModel, validate_assignment=True, extra='forbid', populate_by_name=True):
125139
"""Combines parameters into defined layers."""
126-
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number))
140+
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number), min_length=1)
141+
thickness: str = ''
142+
SLD: str = Field('', validation_alias='SLD_real')
143+
roughness: str = ''
144+
hydration: str = ''
145+
hydrate_with: Hydration = Hydration.BulkOut
146+
147+
148+
class AbsorptionLayer(BaseModel, validate_assignment=True, extra='forbid', populate_by_name=True):
149+
"""Combines parameters into defined layers including absorption terms."""
150+
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number), min_length=1)
127151
thickness: str = ''
128-
SLD: str = ''
152+
SLD_real: str = Field('', validation_alias='SLD')
153+
SLD_imaginary: str = ''
129154
roughness: str = ''
130155
hydration: str = ''
131156
hydrate_with: Hydration = Hydration.BulkOut
132157

133158

134159
class Parameter(BaseModel, validate_assignment=True, extra='forbid'):
135-
"""Defines parameters needed to specify the model"""
136-
name: str = Field(default_factory=lambda: 'New Parameter ' + next(parameter_number))
160+
"""Defines parameters needed to specify the model."""
161+
name: str = Field(default_factory=lambda: 'New Parameter ' + next(parameter_number), min_length=1)
137162
min: float = 0.0
138163
value: float = 0.0
139164
max: float = 0.0
@@ -152,12 +177,12 @@ def check_value_in_range(self) -> 'Parameter':
152177

153178
class ProtectedParameter(Parameter, validate_assignment=True, extra='forbid'):
154179
"""A Parameter with a fixed name."""
155-
name: str = Field(frozen=True)
180+
name: str = Field(frozen=True, min_length=1)
156181

157182

158183
class Resolution(BaseModel, validate_assignment=True, extra='forbid'):
159184
"""Defines Resolutions in RAT."""
160-
name: str = Field(default_factory=lambda: 'New Resolution ' + next(resolution_number))
185+
name: str = Field(default_factory=lambda: 'New Resolution ' + next(resolution_number), min_length=1)
161186
type: Types = Types.Constant
162187
value_1: str = ''
163188
value_2: str = ''

0 commit comments

Comments
 (0)