Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
reorder_noop = None


class DumbCheckpoint(object):
class DumbCheckpoint:

r"""A very dumb checkpoint object.

Expand Down Expand Up @@ -349,7 +349,7 @@ def __del__(self):
self.close()


class HDF5File(object):
class HDF5File:

r"""An object to facilitate checkpointing.

Expand Down Expand Up @@ -506,7 +506,7 @@ def __del__(self):
self.close()


class CheckpointFile(object):
class CheckpointFile:

r"""Checkpointing meshes and :class:`~.Function` s in an HDF5 file.

Expand All @@ -524,6 +524,18 @@ class CheckpointFile(object):
latest_version = '3.0.0'

def __init__(self, filename, mode, comm=COMM_WORLD):
# parse mode into a string
match mode:
case PETSc.Viewer.FileMode.READ | PETSc.Viewer.FileMode.R:
mode = "r"
case PETSc.Viewer.FileMode.WRITE | PETSc.Viewer.FileMode.W:
mode = "w"
case PETSc.Viewer.FileMode.APPEND | PETSc.Viewer.FileMode.A:
mode = "a"

if mode in {"r", "a"} and not os.path.exists(filename):
raise FileNotFoundError(f"'{filename}' does not exist")

self.viewer = ViewerHDF5()
self.filename = filename
self.comm = comm
Expand All @@ -534,21 +546,24 @@ def __init__(self, filename, mode, comm=COMM_WORLD):
assert self.commkey != MPI.COMM_NULL.py2f()
self._function_spaces = {}
self._function_load_utils = {}
if mode in [PETSc.Viewer.FileMode.WRITE, PETSc.Viewer.FileMode.W, "w"]:
version = CheckpointFile.latest_version
self.set_attr_byte_string("/", "dmplex_storage_version", version)
elif mode in [PETSc.Viewer.FileMode.APPEND, PETSc.Viewer.FileMode.A, "a"]:
if self.has_attr("/", "dmplex_storage_version"):
version = self.get_attr_byte_string("/", "dmplex_storage_version")
else:

match mode:
case "w":
version = CheckpointFile.latest_version
self.set_attr_byte_string("/", "dmplex_storage_version", version)
elif mode in [PETSc.Viewer.FileMode.READ, PETSc.Viewer.FileMode.R, "r"]:
if not self.has_attr("/", "dmplex_storage_version"):
raise RuntimeError(f"Only files generated with CheckpointFile are supported: got an invalid file ({filename})")
version = CheckpointFile.latest_version
else:
raise NotImplementedError(f"Unsupportd file mode: {mode} not in {'w', 'a', 'r'}")
case "a":
if self.has_attr("/", "dmplex_storage_version"):
version = self.get_attr_byte_string("/", "dmplex_storage_version")
else:
version = CheckpointFile.latest_version
self.set_attr_byte_string("/", "dmplex_storage_version", version)
case "r":
if not self.has_attr("/", "dmplex_storage_version"):
raise RuntimeError(f"Only files generated with CheckpointFile are supported: got an invalid file ({filename})")
version = CheckpointFile.latest_version
case _:
raise NotImplementedError(f"Unsupported file mode: {mode} not in {'w', 'a', 'r'}")

self.opts = OptionsManager({"dm_plex_view_hdf5_storage_version": version}, "")
r"""DMPlex HDF5 version options."""

Expand Down
2 changes: 1 addition & 1 deletion pyop2/types/mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ def mult(self, mat, x, y):
else:
x.array_r
with mpi.temp_internal_comm(x.comm) as comm:
comm.bcast(a)
a = comm.bcast(a)
return y.scale(a)
else:
return v.pointwiseMult(x, y)
Expand Down
180 changes: 104 additions & 76 deletions tests/firedrake/output/test_io_backward_compat.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,19 @@
import pytest
from os.path import abspath, dirname, join, exists
from firedrake import *
from firedrake.mesh import make_mesh_from_coordinates
from firedrake.utils import IntType
import shutil
"""Backwards compatibility tests for CheckpointFile.

New save files for the current version of Firedrake can be generated by
running this file as a script.

test_version = "2024_01_27"
"""
import importlib.metadata
import os
from os.path import abspath, dirname, join, exists
import shutil

import pytest

"""
2024_01_27:
---------------------------------------------------------------------------
|Package |Branch |Revision |Modified |
---------------------------------------------------------------------------
|COFFEE |master |70c1e66 |False |
|FInAT |master |e2805c4 |False |
|PyOP2 |master |e0a4d3a9 |False |
|fiat |master |e7b2909 |False |
|firedrake |master |393f82f85 |False |
|h5py |firedrake |4c01efa9 |False |
|libspatialindex |master |4768bf3 |True |
|libsupermesh |master |84becef |False |
|loopy |main |8158afdb |False |
|petsc |firedrake |09f36907a6e|False |
|pyadjoint |master |2c6614d |False |
|pytest-mpi |main |a478bc8 |False |
|slepc |firedrake |a3f39c853 |False |
|tsfc |master |799191d |False |
|ufl |master |054b0617 |False |
---------------------------------------------------------------------------
"""
from firedrake import *
from firedrake.mesh import make_mesh_from_coordinates
from firedrake.utils import IntType, complex_mode


cwd = abspath(dirname(__file__))
Expand All @@ -42,6 +25,20 @@
stokes_control_mesh_file = join(cwd, "..", "..", "..", "docs", "notebooks/stokes-control.msh")


SAVED_VERSIONS = ["2024_01_27", "2025.10.3.dev0"]
"""Firedrake versions used to generate historic checkpoint files."""


@pytest.fixture(params=SAVED_VERSIONS)
def version(request):
return request.param


def _skip_if_missing(filename, version):
if not exists(filename):
pytest.skip(reason=f"Checkpoint file does not exist for version '{version}'")


def _initialise_function(f, _f):
f.project(_f, solver_parameters={"ksp_type": "cg", "pc_type": "jacobi", "ksp_rtol": 1.e-16})

Expand Down Expand Up @@ -119,7 +116,10 @@ def _get_mesh_and_V(params):
elem = TensorProductElement(helem, velem)
V = FunctionSpace(mesh, elem)
else:
raise NotImplementedError
assert cell_type == "quadrilateral"
base = UnitSquareMesh(10, 10, name=f"{mesh_name}_base")
mesh = ExtrudedMesh(base, layers=5, layer_height=1.0, name=mesh_name)
V = FunctionSpace(mesh, "P", 3)
elif periodic:
if cell_type == "triangle":
mesh = PeriodicUnitSquareMesh(20, 20, name=mesh_name)
Expand Down Expand Up @@ -175,6 +175,7 @@ def _get_expr(V):
("quadrilateral", False, False, False, False, False, False),
("hexahedron", False, False, False, False, False, False),
("triangle", False, True, False, False, False, False), # extruded (variable layer)
("quadrilateral", False, True, False, False, False, False), # extruded (constant layer)
("triangle", True, False, False, False, False, False), # periodic
("tetrahedron", True, False, False, False, False, False), # periodic
("interval", False, True, True, False, False, False), # extruded_periodic
Expand Down Expand Up @@ -209,28 +210,35 @@ def _test_io_backward_compat_base_idfunc(params):
return "-".join([f"{p_str}={p}" for p_str, p in zip(param_str, params)])


@pytest.mark.skipcomplex
@pytest.mark.parallel(nprocs=3)
@pytest.mark.parametrize('version', [test_version])
@pytest.mark.parametrize('params', test_io_backward_compat_base_params, ids=_test_io_backward_compat_base_idfunc)
@pytest.mark.skip(reason="Only run these tests to create test files.")
def test_io_backward_compat_base_save(version, params):
filename = join(filedir, "_".join([basename, version, _make_name(params) + ".h5"]))
if exists(filename):
raise RuntimeError(f"path {filename} already exists.")
mesh, V = _get_mesh_and_V(params)
f = Function(V, name=func_name)
_initialise_function(f, _get_expr(V))
with CheckpointFile(filename, "w") as afile:
afile.save_function(f)
def test_all_files_used():
unused_checkpoint_files = set(os.listdir(filedir))
assert len(unused_checkpoint_files) > 0, "Did not find any checkpoint files, cannot check if they are used"
for version in SAVED_VERSIONS:
for params in test_io_backward_compat_base_params:
filename = "_".join([basename, version, _make_name(params) + ".h5"])
unused_checkpoint_files.discard(filename)

# also the timestepping file
filename = "_".join([basename, version, "timestepping" + ".h5"])
unused_checkpoint_files.discard(filename)
assert not unused_checkpoint_files, f"Checkpoint files {unused_checkpoint_files} are not tested"


def test_version_has_corresponding_files(version):
for params in test_io_backward_compat_base_params:
filename = join(filedir, "_".join([basename, version, _make_name(params) + ".h5"]))
if exists(filename):
# at least one match, this is sufficient
return
raise AssertionError(f"Version '{version}' does not have any associated checkpoint files")


@pytest.mark.skipcomplex
@pytest.mark.parallel(nprocs=4)
@pytest.mark.parametrize('version', ["2024_01_27"])
@pytest.mark.parallel(4)
@pytest.mark.parametrize('params', test_io_backward_compat_base_params, ids=_test_io_backward_compat_base_idfunc)
def test_io_backward_compat_base_load(version, params):
filename = join(filedir, "_".join([basename, version, _make_name(params) + ".h5"]))
_skip_if_missing(filename, version)
with CheckpointFile(filename, "r") as afile:
mesh = afile.load_mesh(mesh_name)
f = afile.load_function(mesh, func_name)
Expand All @@ -251,30 +259,10 @@ def _get_expr_timestepping(V, i):


@pytest.mark.skipcomplex
@pytest.mark.parallel(nprocs=3)
@pytest.mark.parametrize('version', [test_version])
@pytest.mark.skip(reason="Only run these tests to create test files.")
def test_io_backward_compat_timestepping_save(version):
filename = join(filedir, "_".join([basename, version, "timestepping" + ".h5"]))
if exists(filename):
raise RuntimeError(f"path {filename} already exists.")
mesh = UnitSquareMesh(8, 8, name=mesh_name)
BDM = FunctionSpace(mesh, "BDM", 1)
DG = FunctionSpace(mesh, "DG", 0)
R = FunctionSpace(mesh, "Real", 0)
V = BDM * DG * R
f = Function(V, name=func_name)
with CheckpointFile(filename, 'w') as afile:
for i in range(5):
_initialise_function(f, _get_expr_timestepping(V, i))
afile.save_function(f, idx=i)


@pytest.mark.skipcomplex
@pytest.mark.parallel(nprocs=4)
@pytest.mark.parametrize('version', ["2024_01_27"])
@pytest.mark.parallel(4)
def test_io_backward_compat_timestepping_load(version):
filename = join(filedir, "_".join([basename, version, "timestepping" + ".h5"]))
_skip_if_missing(filename, version)
with CheckpointFile(filename, "r") as afile:
mesh = afile.load_mesh(mesh_name)
for i in range(5):
Expand All @@ -286,31 +274,71 @@ def test_io_backward_compat_timestepping_load(version):


@pytest.mark.skipcomplex
@pytest.mark.parallel(nprocs=3)
@pytest.mark.parametrize('version', ["2024_01_27"])
@pytest.mark.parallel(3)
def test_io_backward_compat_timestepping_append(version, tmpdir):
filename = join(filedir, "_".join([basename, version, "timestepping" + ".h5"]))
_skip_if_missing(filename, version)
copyname = join(str(tmpdir), "test_io_backward_compat_timestepping_append_dump.h5")
copyname = COMM_WORLD.bcast(copyname, root=0)
shutil.copyfile(filename, copyname)
with CheckpointFile(copyname, "r") as afile:
version = afile.opts.parameters['dm_plex_view_hdf5_storage_version']
assert version == CheckpointFile.latest_version
mesh = afile.load_mesh(mesh_name)
f = afile.load_function(mesh, func_name, idx=0)
V = f.function_space()
with CheckpointFile(copyname, 'a') as afile:
version = afile.opts.parameters['dm_plex_view_hdf5_storage_version']
assert version == '2.1.0'
for i in range(5, 10):
_initialise_function(f, _get_expr_timestepping(V, i))
afile.save_function(f, idx=i)
with CheckpointFile(copyname, "r") as afile:
version = afile.opts.parameters['dm_plex_view_hdf5_storage_version']
assert version == CheckpointFile.latest_version
for i in range(0, 10):
f = afile.load_function(mesh, func_name, idx=i)
V_ = f.function_space()
f_ = Function(V_)
_initialise_function(f_, _get_expr_timestepping(V_, i))
assert assemble(inner(f - f_, f - f_) * dx) < 1.e-16


if __name__ == "__main__":
if complex_mode:
print("Backwards compat checkpoint files should be generated in real mode")
exit(1)

version = importlib.metadata.version("firedrake") # noqa: F811
if version in SAVED_VERSIONS:
print(f"Version {version} already has historic checkpoint file data saved")
exit(1)
print(f"Saving checkpoint files for Firedrake version {version}")

print("Saving files for 'test_io_backward_compat_base_load'")
for params in test_io_backward_compat_base_params:
test_id = _test_io_backward_compat_base_idfunc(params)
print(f"Saving function for '{test_id}'")
filename = join(filedir, "_".join([basename, version, _make_name(params) + ".h5"]))
if exists(filename):
raise RuntimeError(f"path {filename} already exists.")
mesh, V = _get_mesh_and_V(params)
f = Function(V, name=func_name)
_initialise_function(f, _get_expr(V))
with CheckpointFile(filename, "w") as afile:
afile.save_function(f)
print(f"Function for '{test_id}' saved")
print("Files for 'test_io_backward_compat_base_load' saved")

print("Saving files for 'test_io_backward_compat_timestepping_load'")
filename = join(filedir, "_".join([basename, version, "timestepping" + ".h5"]))
if exists(filename):
raise RuntimeError(f"path {filename} already exists.")
mesh = UnitSquareMesh(8, 8, name=mesh_name)
BDM = FunctionSpace(mesh, "BDM", 1)
DG = FunctionSpace(mesh, "DG", 0)
R = FunctionSpace(mesh, "Real", 0)
V = BDM * DG * R
f = Function(V, name=func_name)
with CheckpointFile(filename, 'w') as afile:
for i in range(5):
_initialise_function(f, _get_expr_timestepping(V, i))
afile.save_function(f, idx=i)
print("Files for 'test_io_backward_compat_timestepping_load' saved")

print("All checkpoint files saved, please add them to the Firedrake repository, "
f"making sure to include the new version '{version}' in 'SAVED_VERSIONS'")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading