Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions autotest/dfns/test_migrate_v1_to_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from modflow_devtools.dfns.schema import (
Double,
FieldBase,
File,
Integer,
Keyword,
List,
Model,
Expand Down Expand Up @@ -497,3 +499,80 @@ def test_model_no_dependent_variable():
result = v1_to_v2(_v1_dfn(name="prt-nam"))
assert isinstance(result, Model)
assert result.dependent_variable is None


def test_nested_record():
"""Nested record subfields (e.g. gwf-oc formatrecord) are recursively mapped."""
dfn = _v1_dfn(
name="gwf-oc",
blocks={
"options": {
"headprintrecord": _v1_field(
name="headprintrecord",
type="record head print_format formatrecord",
optional=True,
),
"head": _v1_field(name="head", type="keyword", in_record=True),
"print_format": _v1_field(name="print_format", type="keyword", in_record=True),
"formatrecord": _v1_field(
name="formatrecord",
type="record columns width digits format",
in_record=True,
),
"columns": _v1_field(
name="columns", type="integer", in_record=True, optional=True, tagged=True
),
"width": _v1_field(
name="width", type="integer", in_record=True, optional=True, tagged=True
),
"digits": _v1_field(
name="digits", type="integer", in_record=True, optional=True, tagged=True
),
"format": _v1_field(
name="format", type="string", in_record=True, optional=False, tagged=False
),
}
},
)
component = v1_to_v2(dfn)
headprint = component.blocks["options"].fields["headprintrecord"]
assert isinstance(headprint, Record)
formatrecord = headprint.fields["formatrecord"]
assert isinstance(formatrecord, Record)
assert isinstance(formatrecord.fields["columns"], Integer)
assert isinstance(formatrecord.fields["width"], Integer)
assert isinstance(formatrecord.fields["digits"], Integer)
assert isinstance(formatrecord.fields["format"], String)


def test_prt_fmi_packagedata():
"""prt-fmi packagedata recarray is replaced with three named optional File fields."""
dfn = _v1_dfn(
name="prt-fmi",
blocks={
"packagedata": {
"packagedata": _v1_field(
name="packagedata",
type="recarray flowtype filein fname",
block="packagedata",
),
"flowtype": _v1_field(
name="flowtype", type="string", block="packagedata", in_record=True
),
"filein": _v1_field(
name="filein", type="keyword", block="packagedata", in_record=True
),
"fname": _v1_field(
name="fname", type="string", block="packagedata", in_record=True
),
}
},
)
component = v1_to_v2(dfn)
fields = component.blocks["packagedata"].fields
assert set(fields) == {"gwfhead", "gwfbudget", "gwfgrid"}
for fname in ("gwfhead", "gwfbudget", "gwfgrid"):
f = fields[fname]
assert isinstance(f, File)
assert f.optional is True
assert f.mode == "filein"
34 changes: 26 additions & 8 deletions modflow_devtools/dfns/migrate_v1_to_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,27 @@ def _patch(fields: dict) -> tuple[dict, bool]:
return result


def _fix_prt_fmi(component: v2.Component) -> v2.Component:
"""
Replace prt-fmi's heterogeneous packagedata recarray with three named
optional File fields — one per flow type (GWFHEAD, GWFBUDGET, GWFSPDIS).
"""
block = (component.blocks or {}).get("packagedata")
if block is None:
return component
new_fields = {
name: v2.File(name=name, longname=longname, optional=True, tagged=True, mode="filein")
for name, longname in (
("gwfhead", "gwf head file"),
("gwfbudget", "gwf budget file"),
("gwfgrid", "gwf grid file"),
)
}
new_blocks = dict(component.blocks or {})
new_blocks["packagedata"] = block.model_copy(update={"fields": new_fields})
return component.model_copy(update={"blocks": new_blocks})


def v1_to_v2(dfn: v1.Dfn) -> v2.Component:
"""Map a component definition from the v1 schema to v2."""

Expand Down Expand Up @@ -635,9 +656,7 @@ def _record_fields() -> dict:
matches = [
fi
for fi in fields.values(multi=True)
if fi["name"] == rname
and fi.get("in_record", False)
and not (fi["type"] or "").startswith("record")
if fi["name"] == rname and fi.get("in_record", False)
]
if matches:
result[rname] = __map_field(matches[0])
Expand Down Expand Up @@ -855,8 +874,7 @@ def _record_fields() -> dict:
else:
is_stress_pkg = bool(any(blocks) and any("period" in k for k in blocks))
subtype = "advanced" if dfn["advanced"] else "stress" if is_stress_pkg else None
return v2.Package(
**d,
subtype=subtype,
multi=dfn["multi"],
)
pkg = v2.Package(**d, subtype=subtype, multi=dfn["multi"])
if name == "prt-fmi":
return _fix_prt_fmi(pkg)
return pkg
Loading