
Add per-rank disk checkpointing for adjoint tape#4891

Open
sghelichkhani wants to merge 20 commits into firedrakeproject:main from sghelichkhani:sghelichkhani/per-rank-disk-checkpointing
Conversation

@sghelichkhani
Contributor

@sghelichkhani sghelichkhani commented Feb 15, 2026

Motivation

We run time-dependent adjoint Stokes simulations with close to a billion degrees of freedom per timestep. Recomputation-based checkpointing schedules (revolve/binomial) are infeasible due to the cost of recomputing the Stokes solve, so disk checkpointing (SingleDiskStorageSchedule) is the only viable option.

Currently, CheckpointFile writes all ranks to a single shared HDF5 file via parallel HDF5 (PETSc.ViewerHDF5 on COMM_WORLD). On HPC systems, this means all checkpoint I/O goes through the shared parallel filesystem (Lustre/GPFS), which becomes a severe bottleneck. Under 24-hour job time limits, simulations that comfortably fit within the wall-time budget under memory checkpointing become infeasible once the I/O overhead of disk checkpointing is added.

HPC nodes typically have fast node-local NVMe/SSD storage that is orders of magnitude faster than the shared filesystem. However, the current collective I/O approach in CheckpointFile cannot use node-local storage because all ranks must access the same file path.

Approach

Following @connorjward's suggestion in #4891 (comment), the implementation uses a general checkpoint_comm parameter rather than a hardcoded per-rank approach. Users pass any MPI communicator to control how function data is checkpointed:

# Per-rank files (each rank writes independently to node-local storage)
enable_disk_checkpointing(checkpoint_comm=MPI.COMM_SELF,
                          checkpoint_dir="/local/scratch")

# Per-node files (ranks on the same node share a file)
node_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
enable_disk_checkpointing(checkpoint_comm=node_comm,
                          checkpoint_dir="/local/scratch")

The function data is written using PETSc.Vec.createMPI + ViewerHDF5 on the supplied communicator, bypassing CheckpointFile and its collective globalVectorView/globalVectorLoad on COMM_WORLD. We tried using CheckpointFile directly with a sub-communicator (see #4891 (comment)), but loading deadlocks because the mesh DM's sectionLoad/globalVectorLoad are collective on COMM_WORLD.
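For illustration, the per-rank pattern (independent write/read per rank, same partition on restore) can be sketched without PETSc. The helpers and file layout below are stand-ins, not the PR's actual code — the real implementation writes one HDF5 file per checkpoint_comm via PETSc.Vec.createMPI + ViewerHDF5:

```python
import os
import tempfile
from array import array

# Hypothetical stand-in: each "rank" writes its local slice of the
# function data to its own file, with no collective coordination.
def save_local(checkpoint_dir, rank, local_data):
    path = os.path.join(checkpoint_dir, f"ckpt_rank{rank}.bin")
    with open(path, "wb") as f:
        array("d", local_data).tofile(f)
    return path

def load_local(checkpoint_dir, rank, n_local):
    # Restore assumes the same number of ranks and the same partition,
    # matching the restriction stated in the PR.
    data = array("d")
    with open(os.path.join(checkpoint_dir, f"ckpt_rank{rank}.bin"), "rb") as f:
        data.fromfile(f, n_local)
    return list(data)

# Simulate two ranks with different local sizes, as after mesh partitioning.
parts = {0: [0.0, 1.0, 2.0, 3.0, 4.0], 1: [10.0, 11.0, 12.0]}
with tempfile.TemporaryDirectory() as tmp:
    for rank, data in parts.items():
        save_local(tmp, rank, data)
    restored = {rank: load_local(tmp, rank, len(parts[rank])) for rank in parts}

assert restored == parts
```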

The mesh checkpoint via checkpointable_mesh still uses shared storage through CheckpointFile since that's a one-time operation and not performance-critical. Fully backwards compatible: without checkpoint_comm, behaviour is unchanged.

Multi-mesh considerations

Functions on different meshes with different partitioning work correctly because Vec dataset names include the mesh name and element info (ckpt_mesh_a_CG2 vs ckpt_mesh_b_DG1), and checkpointable_mesh ensures deterministic partitioning per mesh independently.
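As a sketch, the disambiguation boils down to a pure naming function. The exact format string below is an assumption for illustration; the real helper is the module-level _generate_function_space_name in firedrake/checkpointing.py:

```python
# Illustrative only: encode mesh name and element info into the Vec
# dataset name so functions on different meshes/elements can coexist
# in the same checkpoint file without colliding.
def checkpoint_vec_name(mesh_name: str, family: str, degree: int) -> str:
    return f"ckpt_{mesh_name}_{family}{degree}"

assert checkpoint_vec_name("mesh_a", "CG", 2) == "ckpt_mesh_a_CG2"
assert checkpoint_vec_name("mesh_b", "DG", 1) == "ckpt_mesh_b_DG1"
```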

The supermesh projection across two different meshes still fails in parallel, but that's a pre-existing limitation unrelated to this PR.

Testing

11 tests total covering three checkpointing modes:

Existing shared-mode tests (5): serial and parallel basic checkpointing, successive writes, timestepper with taylor_test, and boundary conditions.

checkpoint_comm with COMM_SELF (3): parallel basic checkpointing, successive writes (serial), and multi-mesh parallel. These exercise the per-rank file path where each rank writes independently.

checkpoint_comm with node communicator (3): parallel basic checkpointing, multi-mesh parallel, and timestepper with taylor_test. These exercise the multi-rank-per-file path using COMM_TYPE_SHARED.

Enable each MPI rank to write its adjoint checkpoint data to its own
HDF5 file using PETSc Vec I/O on COMM_SELF. This avoids parallel HDF5
overhead and enables use of fast node-local storage (NVMe/SSD) on HPC
systems, where shared filesystem I/O is a major bottleneck for
large-scale time-dependent adjoint computations.

New parameter `per_rank_dirname` on `enable_disk_checkpointing()`.
When set, function data is checkpointed per-rank while mesh data
(via `checkpointable_mesh`) remains on shared storage. Requires same
number of ranks on restore (inherent in adjoint workflows).
@connorjward
Contributor

In PyOP2 we have the notion of a 'compilation comm' which is a communicator defined over each node (https://github.com/firedrakeproject/firedrake/blob/main/pyop2/mpi.py#L450). Might something like this be appropriate/more general here?

@sghelichkhani
Contributor Author

Thanks Connor, that's a great idea. I hadn't considered using the compilation comm pattern here.

I did look into what a node-local comm approach would involve. The main challenge is that CheckpointFile.save_function uses topology_dm.globalVectorView, which is collective over the mesh's communicator (COMM_WORLD). So you can't simply pass a node-local sub-comm to CheckpointFile. The mesh DM lives on COMM_WORLD and you'd get a comm mismatch. That means even with a node-local comm, you'd still need to bypass CheckpointFile and use raw PETSc Vec I/O, just with a parallel Vec on the node comm instead of a sequential Vec on COMM_SELF. The overall complexity ends up being similar, with extra overhead for sub-communicator lifecycle management and intra-node coordination. The benefit is fewer files (N_nodes vs N_ranks).

From my perspective on Gadi, I have 48 cores per node writing to node-local SSDs, and per-rank I/O is completely manageable. My thinking here is that this is specifically an adjoint solution. Disk checkpointing for the adjoint tape is extremely I/O heavy, so the less communicator overhead involved, the better. I'd expect the COMM_SELF approach to actually be faster in practice since every rank operates completely independently with zero coordination or collective operations, even within a node.

That said, this is ultimately a decision for the Firedrake folks on what works best for general users. If a node-local comm approach is preferred, it's doable. Happy to refactor if that's the direction you'd like to go.

@sghelichkhani
Contributor Author

@connorjward Following up on your suggestion about using the compilation communicator. Angus and I discussed this and tested whether we could simply do CheckpointFile(fname, mode, comm=compilation_comm) and use the standard save_function/load_function path with a node-level communicator.

Saving works fine, but loading deadlocks. Here's a minimal reproducer:

"""mpiexec -n 4 python test_subcomm_checkpoint.py"""
import os
import tempfile
from firedrake import *
from firedrake.checkpointing import CheckpointFile

comm = COMM_WORLD
mesh = UnitSquareMesh(4, 4)
V = FunctionSpace(mesh, "CG", 1)
f = Function(V, name="f")
f.interpolate(SpatialCoordinate(mesh)[0])

node_comm = comm.Split(color=comm.rank // 2, key=comm.rank)  # two ranks per "node"

if comm.rank == 0:
    tmpdir = tempfile.mkdtemp()
else:
    tmpdir = None
tmpdir = comm.bcast(tmpdir, root=0)
fname = os.path.join(tmpdir, f"node{comm.rank // 2}.h5")

with CheckpointFile(fname, 'w', comm=node_comm) as out:
    out.save_mesh(mesh)
    out.save_function(f)

with CheckpointFile(fname, 'r', comm=node_comm) as inp:
    mesh2 = inp.load_mesh()
    f2 = inp.load_function(mesh2, "f")  # deadlocks here

The issue is that sectionLoad and globalVectorLoad are collective operations on the topology_dm, which lives on COMM_WORLD. When only a subset of ranks (the node sub-comm) enter the load call, the collective never completes. I think this is by design rather than a bug in CheckpointFile, since the contract is that the mesh lives on COMM_WORLD for the parallel solves, and the load path assumes the viewer and the mesh DM share a compatible communicator. Making this work with a sub-communicator would require changes both in PETSc (cross-comm DM operations) and in Firedrake's checkpointing internals.

That said, I might be missing something. Is there a way to make this work that I'm not seeing?

@connorjward
Contributor

That said, I might be missing something. Is there a way to make this work that I'm not seeing?

I don't think so. I'd have been surprised had that worked.

That means even with a node-local comm, you'd still need to bypass CheckpointFile and use raw PETSc Vec I/O, just with a parallel Vec on the node comm instead of a sequential Vec on COMM_SELF.

This is what I'm suggesting. It doesn't seem like a lot of work to change your API from per_rank-type options to a more general filesystem_comm=some_comm type of thing. Your use case could still do filesystem_comm=COMM_SELF if desired.

@angus-g
Contributor

angus-g commented Feb 17, 2026

That said, I might be missing something. Is there a way to make this work that I'm not seeing?

I don't think so. I'd have been surprised had that worked.

Sorry, that was probably my misleading advice. My reasoning was just that it'd be nice to reduce complexity by leveraging the existing code for viewer set up etc. (and allow pyop2 to set up the comm in the first place).

It kind of seems like a CheckpointFile should be able to take a different comm and work, or at least it should be documented in which cases it can be something other than COMM_WORLD... I get that there are complexities around mesh topology so maybe it's a silly way of thinking in the first place.

Refactors the per_rank_dirname parameter into a more general
checkpoint_comm + checkpoint_dir interface, following reviewer
feedback. Instead of hardcoding COMM_SELF, users now pass any
MPI communicator (COMM_SELF for per-rank files, a node-local
comm for per-node files, etc.). The PETSc Vec I/O uses createMPI
on the supplied communicator rather than createSeq on COMM_SELF,
making the approach work for arbitrary communicator topologies.

Removes three serial-only checkpoint_comm tests that are fully
covered by their parallel counterparts and adds node_comm tests
that exercise the multi-rank-per-file path using COMM_TYPE_SHARED.
@sghelichkhani
Contributor Author

Thanks Connor, done. I've refactored the API from per_rank_dirname to a general checkpoint_comm + checkpoint_dir interface. Users can now pass any communicator, so COMM_SELF for per-rank files or a node-local comm via COMM_TYPE_SHARED for per-node files.

The main PETSc-level change this required was switching from Vec.createSeq (which only works on a single process) to Vec.createMPI((local_size, PETSc.DECIDE), comm=checkpoint_comm). The DECIDE for the global size is important here because when checkpoint_comm groups multiple ranks (e.g. a node communicator), each rank contributes a different number of local DOFs depending on how the mesh was partitioned. The global size of the Vec on the checkpoint communicator is not something we know upfront since it's the sum of local DOFs across ranks in that sub-comm, which is a different grouping than the mesh's COMM_WORLD partitioning. Letting PETSc compute the global size from the local sizes avoids us having to gather that information ourselves.
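A plain-Python sketch of that bookkeeping (the rank layout below is hypothetical; in the PR, passing PETSc.DECIDE makes PETSc perform this sum internally):

```python
# Hypothetical layout: 4 COMM_WORLD ranks, 2 per node communicator.
# Each rank owns a different number of local DOFs from the mesh
# partition on COMM_WORLD.
local_dofs = {0: 120, 1: 97, 2: 103, 3: 130}
node_of_rank = {rank: rank // 2 for rank in local_dofs}  # comm.Split(rank // 2)

# PETSc.DECIDE effectively computes, per node communicator, the sum of
# the local sizes contributed by its member ranks -- a grouping that
# differs from the COMM_WORLD partitioning and is not known upfront.
global_size = {}
for rank, n in local_dofs.items():
    node = node_of_rank[rank]
    global_size[node] = global_size.get(node, 0) + n

assert global_size == {0: 217, 1: 233}
```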

I've also added tests with COMM_TYPE_SHARED to exercise the multi-rank-per-file path and trimmed redundant serial tests that were fully covered by their parallel counterparts.

Contributor

@connorjward connorjward left a comment


Seems alright to me. It would definitely be good to get some feedback from @JHopeCollins, who has done similar comm wrangling for ensemble.

@connorjward
Contributor

It kind of seems like a CheckpointFile should be able to take a different comm and work, or at least it should be documented in which cases it can be something other than COMM_WORLD... I get that there are complexities around mesh topology so maybe it's a silly way of thinking in the first place.

It's definitely an interesting question. Conceptually I think it should be possible to checkpoint a DMPlex to multiple files, but it's far from trivial. An added complication is that we would have to preserve the N-to-M checkpointing behaviour (i.e. reading and writing with different numbers of ranks).

Add FutureWarning to deprecated new_checkpoint_file method. Use
isinstance(mesh, ufl.MeshSequence) instead of a hasattr check. Replace
COMM_TYPE_SHARED tests with comm.Split(rank // 2) to guarantee a
communicator with 1 < size < COMM_WORLD size in the 3-rank test.
Extract _generate_function_space_name from CheckpointFile into a
module-level function in firedrake/checkpointing.py and reuse it
for the checkpoint_comm Vec naming instead of maintaining a separate
_generate_checkpoint_vec_name. The free function also handles
MeshSequenceGeometry defensively. CheckpointFile method delegates
to it.
The multi-mesh tests chained two PDE solves via assemble(u_a * dx), a
global reduction whose floating-point result can vary across parallel
runs due to reduction ordering. This made the J == Jnew assertion flaky
at the np.allclose tolerance boundary. Make the mesh_b solve independent
and drop the redundant memory baseline comparison.
@sghelichkhani
Contributor Author

sghelichkhani commented Feb 18, 2026

Simplified the multi-mesh tests to fix intermittent CI failures. The original design chained two solves via a global reduction (assemble(u_a * dx)), which amplified parallel floating-point non-determinism across tape replays. The two solves are now independent while still doing multi-mesh checkpointing.
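For the record, the underlying effect is plain floating-point non-associativity: a reduction combined in a different order can return a slightly different value. A minimal standalone illustration (magnitudes exaggerated):

```python
# Floating-point addition is not associative, so a parallel reduction's
# result can depend on the order in which partial sums are combined.
a = [1e16, 1.0, -1e16, 1.0]

left_to_right = ((a[0] + a[1]) + a[2]) + a[3]  # first 1.0 is absorbed into 1e16
reordered = ((a[0] + a[2]) + a[1]) + a[3]      # cancellation happens first

assert left_to_right == 1.0
assert reordered == 2.0
```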

Member

@JHopeCollins JHopeCollins left a comment


My gut feeling is that this should be part of CheckpointFile rather than hidden in the adjoint utils.

I don't think I fully understand the problem with checkpointing the mesh using the global comm but checkpointing the function data using a subcomm. Why does that mean you have to split the DM into multiple files?

@connorjward
Copy link
Contributor

I don't think I fully understand the problem with checkpointing the mesh using the global comm but checkpointing the function data using a subcomm. Why does that mean you have to split the DM into multiple files?

You can't split the DM into multiple files. That's why this only works for functions.

@JHopeCollins
Member

JHopeCollins commented Feb 19, 2026

The main thing is that I would like to see the local save/load logic abstracted out of the adjoint code somehow, because the adjoint code really shouldn't be thinking about actual concrete data.

I can think of four potential ways to do this. Happy to hear arguments for/against each one.

In each case we'd obviously need to be very explicit about the restrictions of saving/loading locally, i.e. you must have exactly the same partition for saving and loading, so it's basically only for saving/loading during the same programme (and for option 3 you also can't save/load the mesh with this class).

  1. What about dm.localVectorView and dm.localVectorLoad? It looks like they will handle splitting the data apart, which seems to be what we need "through the front door". I don't know whether they can save the local data to local files though.
    https://petsc.org/release/manualpages/DMPlex/DMPlexLocalVectorView/
    https://petsc.org/release/manualpages/DMPlex/DMPlexLocalVectorLoad/

  2. In CheckpointFile, only use dm.globalVector{View,Load} if we are checkpointing to a global file; otherwise ignore those functions and save the Vec data to local files with your new logic. This is essentially the same as the current approach, but with the switches inside CheckpointFile.

  3. A new class, something like LocalFunctionCheckpointFile (happy to bikeshed this) that has the logic for saving/loading Functions to local files on subcomms. If we really don't trust users to use the local saving/loading properly then we could do this one and just don't advertise the new class in the public API.

  4. Writing a python-type viewer context for CheckpointFile.viewer that internally creates the current global HDF5 viewer and delegates to that one except if it's being asked to view/load a Vec, in which case it uses the new logic in the adjoint_utils for saving/viewing locally. So dm.globalVector{View,Load} would end up being diverted to the comm-local implementation.
    Hopefully this would just be implementing __init__, view and load methods in the python context (and maybe a blank setUp). But this option is probably trying to be too clever and overkill for what we need.

@connorjward
Contributor

The main thing is that I would like to see the local save/load logic abstracted out of the adjoint code somehow, because the adjoint code really shouldn't be thinking about actual concrete data.

This is a fair point.

What about dm.localVectorView and dm.localVectorLoad?

I've read the implementation and still confused about the difference between these. In PETSc the term 'local' applies to lots of different things so this may not do quite what you expect.

I wonder if this is basically reimplementing DumbCheckpoint again.

We should discuss this in today's meeting.

@connorjward
Contributor

@sghelichkhani from today's meeting we decided that we want this functionality exposed as a new CheckpointFile-ish type living in checkpointing.py. This is basically point 3 from @JHopeCollins above.

I quite like something like EphemeralFunctionCheckpointFile or similar. It would be cool if by default they could be destroyed at program exit to avoid misuse.

@JHopeCollins
Member

JHopeCollins commented Feb 19, 2026

I quite like something like EphemeralFunctionCheckpointFile or similar.

How about TemporaryCheckpointFile, or TemporaryFunctionCheckpointFile, to mirror the tempfile naming and the common naming of a /tmp working directory.

It would be cool if by default they could be destroyed at program exit to avoid misuse.

@sghelichkhani is it possible to use tempfile.TemporaryFile and tempfile.TemporaryDirectory rather than tempfile.mkstemp and tempfile.mkdtemp so that the cleanup is automated?
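For comparison, the cleanup difference in a minimal sketch (prefixes and file names are stand-ins):

```python
import os
import tempfile

# mkdtemp: the caller owns cleanup; the directory survives unless removed.
leaky = tempfile.mkdtemp(prefix="ckpt_")
assert os.path.isdir(leaky)
os.rmdir(leaky)  # manual cleanup required

# TemporaryDirectory: cleanup is automatic at context exit (or when the
# object is finalized, via a weakref finalizer, if used without "with").
with tempfile.TemporaryDirectory(prefix="ckpt_") as tmpdir:
    ckpt_path = os.path.join(tmpdir, "rank0.h5")  # hypothetical file name
    open(ckpt_path, "wb").close()
    assert os.path.exists(ckpt_path)
assert not os.path.exists(tmpdir)  # directory and contents are gone
```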

@sghelichkhani
Contributor Author

Note on _checkpoint_indices: this class-level dict on CheckpointFunction is never pruned. File entries persist after the HDF5 file is deleted. The local path (TemporaryFunctionCheckpointFile._indices) does not have this issue since remove_file handles cleanup. Not fixing this to avoid the risk of pruning entries before restore() reads them. The memory cost is negligible.

leo-collins and others added 5 commits February 21, 2026 22:37
…ject#4865)

* add log event markers

* build spatial index using CreateWithArray
Move PETSc Vec I/O into TemporaryFunctionCheckpointFile in checkpointing.py.
Rename save/restore methods, fix deprecation warning, remove redundant fixture
and forwarding method, clean up imports.
@sghelichkhani
Contributor Author

All review comments from @connorjward and @JHopeCollins are addressed. The main change is extracting the PETSc Vec I/O into a TemporaryFunctionCheckpointFile class in checkpointing.py as we discussed. The adjoint module no longer imports petsc4py and just delegates to save_function/load_function. Directory cleanup uses tempfile.TemporaryDirectory so no atexit needed on that path.

The CheckpointFile._generate_function_space_name forwarding method is removed entirely; all call sites call the free function directly. The free function adds a MeshSequenceGeometry unwrap, but this is a no-op on the CheckpointFile path since save_function already calls mesh.unique() upstream; it only matters for the new TemporaryFunctionCheckpointFile path. Methods are renamed to _save_local_checkpoint/_restore_local_checkpoint, the deprecation warning no longer references private API, the dead checkpoint_comm attribute is removed from CheckPointFileReference, and there is now a UserWarning when checkpoint_comm is used without checkpoint_dir.

Tests cleaned up: local set_test_tape fixture removed, serial test uses tmp_path, parallel tests use try/finally for cleanup, and taylor_test added to the multi-mesh test.

Two pre-existing items I looked at but decided not to change: _checkpoint_indices is never pruned (slow leak, but pruning risks removing entries before restore() reads them), and TemporaryDirectory cleanup depends on GC timing relative to MPI_Finalize (harmless in practice since node-local scratch is wiped between jobs).

Contributor

@connorjward connorjward left a comment


I think I'm now being very nitpicky and this is basically fine.

Contributor Author

@sghelichkhani sghelichkhani left a comment


Applied.

Remove redundant self.cleanup guard from TemporaryFunctionCheckpointFile.remove_file so the cleanup decision lives solely in CheckPointFileReference.__del__. Remove manual tape setup from the four new parallel tests since the autouse_test_taping fixture handles it. Clarify TemporaryFunctionCheckpointFile.comm docstring per Connor suggestion and document why _broadcast_tmpdir uses COMM_WORLD.
The multi-mesh tests have an independent solve on mesh_b that can give
slightly different results after repartitioning by checkpointable_mesh.
The taylor_test is the proper correctness check for the adjoint.
connorjward previously approved these changes Feb 24, 2026
Contributor

@connorjward connorjward left a comment


I'm happy with this. It's a big change that involves code I am not super familiar with so it would be good to have an approving review from @JHopeCollins too.

@sghelichkhani
Contributor Author

CI failures are unrelated: timeout in nprocs=6 I/O tests and a Gusto smoke test options issue. No files from this PR are involved.
