Skip to content

Commit 427e289

Browse files
committed
added model renaming for data and function backgrounds/resolutions
1 parent caa40e2 commit 427e289

File tree

2 files changed

+53
-38
lines changed

2 files changed

+53
-38
lines changed

RATapi/project.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -80,20 +80,25 @@ def discriminate_contrasts(contrast_input):
8080

8181
AllFields = collections.namedtuple("AllFields", ["attribute", "fields"])
8282
model_names_used_in = {
83-
"background_parameters": AllFields(
84-
"backgrounds", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"]
85-
),
86-
"resolution_parameters": AllFields(
87-
"resolutions", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"]
88-
),
89-
"parameters": AllFields("layers", ["thickness", "SLD", "SLD_real", "SLD_imaginary", "roughness", "hydration"]),
90-
"data": AllFields("contrasts", ["data"]),
91-
"backgrounds": AllFields("contrasts", ["background"]),
92-
"bulk_in": AllFields("contrasts", ["bulk_in"]),
93-
"bulk_out": AllFields("contrasts", ["bulk_out"]),
94-
"scalefactors": AllFields("contrasts", ["scalefactor"]),
95-
"domain_ratios": AllFields("contrasts", ["domain_ratio"]),
96-
"resolutions": AllFields("contrasts", ["resolution"]),
83+
"background_parameters": [
84+
AllFields("backgrounds", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"])
85+
],
86+
"resolution_parameters": [
87+
AllFields("resolutions", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"])
88+
],
89+
"parameters": [AllFields("layers", ["thickness", "SLD", "SLD_real", "SLD_imaginary", "roughness", "hydration"])],
90+
"data": [
91+
AllFields("contrasts", ["data"]),
92+
AllFields("backgrounds", ["source"]),
93+
AllFields("resolutions", ["source"]),
94+
],
95+
"custom_files": [AllFields("backgrounds", ["source"]), AllFields("resolutions", ["source"])],
96+
"backgrounds": [AllFields("contrasts", ["background"])],
97+
"bulk_in": [AllFields("contrasts", ["bulk_in"])],
98+
"bulk_out": [AllFields("contrasts", ["bulk_out"])],
99+
"scalefactors": [AllFields("contrasts", ["scalefactor"])],
100+
"domain_ratios": [AllFields("contrasts", ["domain_ratio"])],
101+
"resolutions": [AllFields("contrasts", ["resolution"])],
97102
}
98103

99104
# Note that the order of these parameters is hard-coded into RAT
@@ -508,18 +513,19 @@ def set_absorption(self) -> "Project":
508513
@model_validator(mode="after")
509514
def update_renamed_models(self) -> "Project":
510515
"""When models defined in the ClassLists are renamed, we need to update that name elsewhere in the project."""
511-
for class_list in model_names_used_in:
516+
for class_list, fields_to_update in model_names_used_in.items():
512517
old_names = self._all_names[class_list]
513518
new_names = getattr(self, class_list).get_names()
514519
if len(old_names) == len(new_names):
515520
name_diff = [(old, new) for (old, new) in zip(old_names, new_names) if old != new]
516521
for old_name, new_name in name_diff:
517-
model_names_list = getattr(self, model_names_used_in[class_list].attribute)
518-
all_matches = model_names_list.get_all_matches(old_name)
519-
fields = model_names_used_in[class_list].fields
520-
for index, field in all_matches:
521-
if field in fields:
522-
setattr(model_names_list[index], field, new_name)
522+
for field in fields_to_update:
523+
project_field = getattr(self, field.attribute)
524+
all_matches = project_field.get_all_matches(old_name)
525+
params = field.fields
526+
for index, param in all_matches:
527+
if param in params:
528+
setattr(project_field[index], param, new_name)
523529
self._all_names = self.get_all_names()
524530
return self
525531

tests/test_project.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -625,26 +625,35 @@ def test_check_protected_parameters(delete_operation) -> None:
625625

626626

627627
@pytest.mark.parametrize(
628-
["model", "field"],
628+
["model", "fields"],
629629
[
630-
("background_parameters", "source"),
631-
("resolution_parameters", "source"),
632-
("parameters", "roughness"),
633-
("data", "data"),
634-
("backgrounds", "background"),
635-
("bulk_in", "bulk_in"),
636-
("bulk_out", "bulk_out"),
637-
("scalefactors", "scalefactor"),
638-
("resolutions", "resolution"),
630+
("background_parameters", ["source"]),
631+
("resolution_parameters", ["source"]),
632+
("parameters", ["roughness"]),
633+
("data", ["data", "source", "source"]),
634+
("custom_files", ["source", "source"]),
635+
("backgrounds", ["background"]),
636+
("bulk_in", ["bulk_in"]),
637+
("bulk_out", ["bulk_out"]),
638+
("scalefactors", ["scalefactor"]),
639+
("resolutions", ["resolution"]),
639640
],
640641
)
641-
def test_rename_models(test_project, model: str, field: str) -> None:
642+
def test_rename_models(test_project, model: str, fields: list[str]) -> None:
642643
"""When renaming a model in the project, the new name should be recorded when that model is referred to elsewhere
643644
in the project.
644645
"""
646+
if model == "data":
647+
test_project.backgrounds[0] = RATapi.models.Background(type="data", source="Simulation")
648+
test_project.resolutions[0] = RATapi.models.Resolution(type="data", source="Simulation")
649+
if model == "custom_files":
650+
test_project.backgrounds[0] = RATapi.models.Background(type="function", source="Test Custom File")
651+
test_project.resolutions[0] = RATapi.models.Resolution(type="function", source="Test Custom File")
645652
getattr(test_project, model).set_fields(-1, name="New Name")
646-
attribute = RATapi.project.model_names_used_in[model].attribute
647-
assert getattr(getattr(test_project, attribute)[-1], field) == "New Name"
653+
model_name_lists = RATapi.project.model_names_used_in[model]
654+
for model_name_list, field in zip(model_name_lists, fields):
655+
attribute = model_name_list.attribute
656+
assert getattr(getattr(test_project, attribute)[-1], field) == "New Name"
648657

649658

650659
@pytest.mark.parametrize(
@@ -1197,7 +1206,7 @@ def test_wrap_del(test_project, class_list: str, parameter: str, field: str) ->
11971206
pydantic.ValidationError,
11981207
match=f"1 validation error for Project\n Value error, The value "
11991208
f'"{parameter}" in the "{field}" field of '
1200-
f'"{RATapi.project.model_names_used_in[class_list].attribute}" '
1209+
f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" '
12011210
f'must be defined in "{class_list}".',
12021211
):
12031212
del test_attribute[index]
@@ -1405,7 +1414,7 @@ def test_wrap_pop(test_project, class_list: str, parameter: str, field: str) ->
14051414
pydantic.ValidationError,
14061415
match=f"1 validation error for Project\n Value error, The value "
14071416
f'"{parameter}" in the "{field}" field of '
1408-
f'"{RATapi.project.model_names_used_in[class_list].attribute}" '
1417+
f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" '
14091418
f'must be defined in "{class_list}".',
14101419
):
14111420
test_attribute.pop(index)
@@ -1437,7 +1446,7 @@ def test_wrap_remove(test_project, class_list: str, parameter: str, field: str)
14371446
pydantic.ValidationError,
14381447
match=f"1 validation error for Project\n Value error, The value "
14391448
f'"{parameter}" in the "{field}" field of '
1440-
f'"{RATapi.project.model_names_used_in[class_list].attribute}" '
1449+
f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" '
14411450
f'must be defined in "{class_list}".',
14421451
):
14431452
test_attribute.remove(parameter)
@@ -1469,7 +1478,7 @@ def test_wrap_clear(test_project, class_list: str, parameter: str, field: str) -
14691478
pydantic.ValidationError,
14701479
match=f"1 validation error for Project\n Value error, The value "
14711480
f'"{parameter}" in the "{field}" field of '
1472-
f'"{RATapi.project.model_names_used_in[class_list].attribute}" '
1481+
f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" '
14731482
f'must be defined in "{class_list}".',
14741483
):
14751484
test_attribute.clear()

0 commit comments

Comments
 (0)