Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion RATapi/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def warn_setting_incorrect_properties(self, handler: ValidatorFunctionWrapHandle
f" controls procedure are:\n "
f"{', '.join(fields.get('procedure', []))}\n",
}
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
custom_error_list = custom_pydantic_validation_error(exc.errors(include_url=False), custom_error_msgs)
raise ValidationError.from_exception_data(exc.title, custom_error_list, hide_input=True) from None

if isinstance(model_input, validated_self.__class__):
Expand Down
121 changes: 89 additions & 32 deletions RATapi/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,6 @@ def update_renamed_models(self) -> "Project":
for index, param in all_matches:
if param in params:
setattr(project_field[index], param, new_name)
self._all_names = self.get_all_names()
return self

@model_validator(mode="after")
Expand All @@ -566,28 +565,45 @@ def cross_check_model_values(self) -> "Project":
values = ["value_1", "value_2", "value_3", "value_4", "value_5"]
for field in ["backgrounds", "resolutions"]:
self.check_allowed_source(field)
self.check_allowed_values(field, values, getattr(self, f"{field[:-1]}_parameters").get_names())
self.check_allowed_values(
field,
values,
getattr(self, f"{field[:-1]}_parameters").get_names(),
self._all_names[f"{field[:-1]}_parameters"],
)

self.check_allowed_values(
"layers",
["thickness", "SLD", "SLD_real", "SLD_imaginary", "roughness"],
self.parameters.get_names(),
self._all_names["parameters"],
)

self.check_allowed_values("contrasts", ["data"], self.data.get_names())
self.check_allowed_values("contrasts", ["background"], self.backgrounds.get_names())
self.check_allowed_values("contrasts", ["bulk_in"], self.bulk_in.get_names())
self.check_allowed_values("contrasts", ["bulk_out"], self.bulk_out.get_names())
self.check_allowed_values("contrasts", ["scalefactor"], self.scalefactors.get_names())
self.check_allowed_values("contrasts", ["resolution"], self.resolutions.get_names())
self.check_allowed_values("contrasts", ["domain_ratio"], self.domain_ratios.get_names())
self.check_allowed_values("contrasts", ["data"], self.data.get_names(), self._all_names["data"])
self.check_allowed_values(
"contrasts", ["background"], self.backgrounds.get_names(), self._all_names["backgrounds"]
)
self.check_allowed_values("contrasts", ["bulk_in"], self.bulk_in.get_names(), self._all_names["bulk_in"])
self.check_allowed_values("contrasts", ["bulk_out"], self.bulk_out.get_names(), self._all_names["bulk_out"])
self.check_allowed_values(
"contrasts", ["scalefactor"], self.scalefactors.get_names(), self._all_names["scalefactors"]
)
self.check_allowed_values(
"contrasts", ["resolution"], self.resolutions.get_names(), self._all_names["resolutions"]
)
self.check_allowed_values(
"contrasts", ["domain_ratio"], self.domain_ratios.get_names(), self._all_names["domain_ratios"]
)

self.check_contrast_model_allowed_values(
"contrasts",
getattr(self, self._contrast_model_field).get_names(),
self._all_names[self._contrast_model_field],
self._contrast_model_field,
)
self.check_contrast_model_allowed_values("domain_contrasts", self.layers.get_names(), "layers")
self.check_contrast_model_allowed_values(
"domain_contrasts", self.layers.get_names(), self._all_names["layers"], "layers"
)
return self

@model_validator(mode="after")
Expand All @@ -606,6 +622,12 @@ def check_protected_parameters(self) -> "Project":
self._protected_parameters = self.get_all_protected_parameters()
return self

@model_validator(mode="after")
def update_names(self) -> "Project":
"""Following validation, update the list of all parameter names."""
self._all_names = self.get_all_names()
return self

def __str__(self):
output = ""
for key, value in self.__dict__.items():
Expand All @@ -630,7 +652,9 @@ def get_all_protected_parameters(self):
for class_list in parameter_class_lists
}

def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None:
def check_allowed_values(
self, attribute: str, field_list: list[str], allowed_values: list[str], previous_values: list[str]
) -> None:
"""Check the values of the given fields in the given model are in the supplied list of allowed values.

Parameters
Expand All @@ -641,6 +665,8 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va
The fields of the attribute to be checked for valid values.
allowed_values : list [str]
The list of allowed values for the fields given in field_list.
previous_values : list [str]
The list of allowed values for the fields given in field_list after the previous validation.

Raises
------
Expand All @@ -649,14 +675,23 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va

"""
class_list = getattr(self, attribute)
for model in class_list:
for index, model in enumerate(class_list):
for field in field_list:
value = getattr(model, field, "")
if value and value not in allowed_values:
raise ValueError(
f'The value "{value}" in the "{field}" field of "{attribute}" must be defined in '
f'"{values_defined_in[f"{attribute}.{field}"]}".',
)
if value in previous_values:
raise ValueError(
f'The value "{value}" used in the "{field}" field at index {index} of "{attribute}" '
f'must be defined in "{values_defined_in[f"{attribute}.{field}"]}". Please remove '
f'"{value}" from "{attribute}{index}.{field}" before attempting to delete it.',
)
else:
raise ValueError(
f'The value "{value}" used in the "{field}" field at index {index} of "{attribute}" '
f'must be defined in "{values_defined_in[f"{attribute}.{field}"]}". Please add '
f'"{value}" to "{values_defined_in[f"{attribute}.{field}"]}" before including it in '
f'"{attribute}".',
)

def check_allowed_source(self, attribute: str) -> None:
"""Check that the source of a background or resolution is defined in the relevant field for its type.
Expand All @@ -679,24 +714,37 @@ def check_allowed_source(self, attribute: str) -> None:

"""
class_list = getattr(self, attribute)
for model in class_list:
for index, model in enumerate(class_list):
if model.type == TypeOptions.Constant:
allowed_values = getattr(self, f"{attribute[:-1]}_parameters").get_names()
previous_values = self._all_names[f"{attribute[:-1]}_parameters"]
elif model.type == TypeOptions.Data:
allowed_values = self.data.get_names()
previous_values = self._all_names["data"]
else:
allowed_values = self.custom_files.get_names()
previous_values = self._all_names["custom_files"]

if (value := model.source) != "" and value not in allowed_values:
raise ValueError(
f'The value "{value}" in the "source" field of "{attribute}" must be defined in '
f'"{values_defined_in[f"{attribute}.{model.type}.source"]}".',
)
if value in previous_values:
raise ValueError(
f'The value "{value}" used in the "source" field at index {index} of "{attribute}" '
f'must be defined in "{values_defined_in[f"{attribute}.{model.type}.source"]}". Please remove '
f'"{value}" from "{attribute}{index}.source" before attempting to delete it.',
)
else:
raise ValueError(
f'The value "{value}" used in the "source" field at index {index} of "{attribute}" '
f'must be defined in "{values_defined_in[f"{attribute}.{model.type}.source"]}". Please add '
f'"{value}" to "{values_defined_in[f"{attribute}.{model.type}.source"]}" before including it '
f'in "{attribute}".',
)

def check_contrast_model_allowed_values(
self,
contrast_attribute: str,
allowed_values: list[str],
previous_values: list[str],
allowed_field: str,
) -> None:
"""Ensure the contents of the ``model`` for a contrast or domain contrast exist in the required project fields.
Expand All @@ -707,6 +755,8 @@ def check_contrast_model_allowed_values(
The specific contrast attribute of Project being validated (either "contrasts" or "domain_contrasts").
allowed_values : list [str]
The list of allowed values for the model of the contrast_attribute.
previous_values : list [str]
The list of allowed values for the model of the contrast_attribute after the previous validation.
allowed_field : str
The name of the field in the project in which the allowed_values are defined.

Expand All @@ -717,13 +767,21 @@ def check_contrast_model_allowed_values(

"""
class_list = getattr(self, contrast_attribute)
for contrast in class_list:
model_values = contrast.model
if model_values and not all(value in allowed_values for value in model_values):
raise ValueError(
f'The values: "{", ".join(str(i) for i in model_values)}" in the "model" field of '
f'"{contrast_attribute}" must be defined in "{allowed_field}".',
)
for index, contrast in enumerate(class_list):
if (model_values := contrast.model) and not all(value in allowed_values for value in model_values):
if all(value in previous_values for value in model_values):
raise ValueError(
f'The values: "{", ".join(str(i) for i in model_values)}" used in the "model" field at index '
f'{index} of "{contrast_attribute}" must be defined in "{allowed_field}". Please remove '
f'all unnecessary values from "model" before attempting to delete them.',
)
else:
raise ValueError(
f'The values: "{", ".join(str(i) for i in model_values)}" used in the "model" field at index '
f'{index} of "{contrast_attribute}" must be defined in "{allowed_field}". Please add '
f'all required values to "{allowed_field}" '
f'before including them in "{contrast_attribute}".',
)

def get_contrast_model_field(self):
"""Get the field used to define the contents of the "model" field in contrasts.
Expand Down Expand Up @@ -945,7 +1003,7 @@ def wrapped_func(*args, **kwargs):
Project.model_validate(self)
except ValidationError as exc:
class_list.data = previous_state
custom_error_list = custom_pydantic_validation_error(exc.errors())
custom_error_list = custom_pydantic_validation_error(exc.errors(include_url=False))
raise ValidationError.from_exception_data(exc.title, custom_error_list, hide_input=True) from None
except (TypeError, ValueError):
class_list.data = previous_state
Expand Down Expand Up @@ -980,9 +1038,8 @@ def try_relative_to(path: Path, relative_to: Path) -> str:
else:
warnings.warn(
"Could not save custom file path as relative to the project directory, "
"which means that it may not work on other devices."
"If you would like to share your project, make sure your custom files "
"are in a subfolder of the project save location.",
"which means that it may not work on other devices. If you would like to share your project, "
"make sure your custom files are in a subfolder of the project save location.",
stacklevel=2,
)
return str(path.resolve())
Loading