
Add async checkpoint feature#1703

Open
VincentCheungKokomo wants to merge 1 commit into InternLM:main from VincentCheungKokomo:feature/async-checkpoint

Conversation

@VincentCheungKokomo

Add async DCP checkpoint support

This change adds async checkpoint saving for XTuner v1 training. The trainer
now supports an async_checkpoint option, starts merged async DCP saves for model
and optimizer state, and defers checkpoint metadata finalization until the
background staging/upload futures complete.

The async path writes model and optimizer state into a merged weights/
checkpoint format, while resume keeps compatibility with both the new merged
format and the existing model/optimizer DCP format. Checkpoint metadata is only
registered after async save completion, so failed async saves are not exposed as
resumable checkpoints.

The training engine now creates a dedicated process group for async checkpoint
work, supports merged async save/load helpers, and cleans up the async process
group at trainer shutdown.

Tests and benchmark configs are added to cover async checkpoint intervals and
provide reproducible verification runs for 8B and 30B models.

@VincentCheungKokomo VincentCheungKokomo force-pushed the feature/async-checkpoint branch 2 times, most recently from 7a7136b to 302b6ec on April 23, 2026 at 03:47
Comment thread xtuner/v1/engine/train_engine.py Outdated
from xtuner.v1.utils.grad_norm import cal_grad_norm


if BlockingAsyncStager is not None:
Collaborator

In [2]: fw = FileSystemWriter("./")

In [3]: from torch.distributed.checkpoint.staging import AsyncStager, BlockingAsyncStager

In [4]: isinstance(fw, AsyncStager)
Out[4]: True

is _CachingStagingWriter necessary?

Comment thread xtuner/v1/engine/train_engine.py Outdated
options=_set_options,
)

def load_dcp_merged(
Collaborator

The state dict format should be consistent between async_save and save. If merged_state_dict performs better, just replace the current implementation.

Comment thread xtuner/v1/train/trainer.py Outdated
Comment on lines +540 to +543
self._async_checkpoint = async_checkpoint
self._pending_staging_futures: list[Future] | None = None
self._pending_upload_futures: list[Future] | None = None
self._pending_checkpoint_finalize: _CheckpointFinalize | None = None
Collaborator

Following dcp.async_save, the async interface should return an awaitable future. We can assume there is at most one in-flight async save future in the trainer at any time, and the trainer will always wait for the previous async save to finish before issuing a new one.
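The at-most-one-in-flight contract the reviewer describes can be sketched with a plain `concurrent.futures` future. This is an illustrative stand-in, not XTuner's actual interface: the `Trainer`, `async_save`, and `_do_save` names are hypothetical, and `time.sleep` stands in for DCP staging/upload work.

```python
from __future__ import annotations

import time
from concurrent.futures import Future, ThreadPoolExecutor


class Trainer:
    """Illustrative sketch: at most one in-flight async save at a time."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._save_future: Future | None = None  # at most one in flight
        self.saved: list[int] = []

    def async_save(self, step: int) -> Future:
        # Always wait for the previous async save before issuing a new one,
        # matching the invariant the reviewer proposes.
        if self._save_future is not None:
            self._save_future.result()
        self._save_future = self._executor.submit(self._do_save, step)
        return self._save_future

    def _do_save(self, step: int) -> int:
        time.sleep(0.01)  # stand-in for DCP staging/upload
        self.saved.append(step)
        return step


trainer = Trainer()
for step in (1, 2, 3):
    trainer.async_save(step)
trainer._save_future.result()
print(trainer.saved)  # [1, 2, 3]
```

Because each launch first awaits the previous future, saves complete strictly in order and the trainer never holds more than one pending checkpoint.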

Comment thread xtuner/v1/train/trainer.py Outdated
ckpt_saved = self._maybe_save(is_snapshot=False)
if not ckpt_saved:
_ = self._maybe_save(is_snapshot=True)
checkpoint_time = time.time() - time_before_checkpoint
Collaborator

Just log the checkpoint time in train_engine

@VincentCheungKokomo VincentCheungKokomo force-pushed the feature/async-checkpoint branch 6 times, most recently from 695d2b3 to b6701ef on April 30, 2026 at 08:08
self._async_checkpoint_pg: dist.ProcessGroup | None = None
self._async_state_dict_cache: dict[str, Any] | None = None
if async_checkpoint:
self._async_checkpoint_pg = dist.new_group(backend="gloo")
Collaborator

Please leave a comment to describe why we need a gloo process group here.

Comment thread xtuner/v1/engine/train_engine.py Outdated
Comment on lines +347 to +351
if not hasattr(dcp, "async_save"):
raise RuntimeError(
"dcp.async_save is not available in this PyTorch version. "
"Please upgrade PyTorch or set async_checkpoint=False."
)
Collaborator

unnecessary check.

Comment thread xtuner/v1/engine/train_engine.py Outdated
Comment on lines +359 to +361
cached_has_optim = "optimizer" in self._async_state_dict_cache
if cached_has_optim != save_optimizer:
self._async_state_dict_cache = None
Collaborator

when will this branch be triggered?

Comment thread xtuner/v1/engine/train_engine.py Outdated
if cached_has_optim != save_optimizer:
self._async_state_dict_cache = None
storage_writer = FileSystemWriter(weights_dir, cache_staged_state_dict=True)
storage_writer.state_dict_cache = self._async_state_dict_cache
Collaborator

Is this injection necessary?

Author

cache_staged_state_dict keeps pinned staging buffers on the FileSystemWriter instance. XTuner creates one writer per checkpoint path, so the cache is carried across writers to preserve steady-state async_save launch performance.

Comment on lines +376 to +381
def destroy_async_checkpoint_pg(self) -> None:
"""Destroy the dedicated gloo process group used for async checkpoint."""
self._async_state_dict_cache = None
if self._async_checkpoint_pg is not None:
dist.destroy_process_group(self._async_checkpoint_pg)
self._async_checkpoint_pg = None
Collaborator
@HAOCHENYE HAOCHENYE May 7, 2026
call it in __del__
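The suggestion to tie cleanup to object lifetime can be sketched as an idempotent destroy method that `__del__` also invokes. This is a minimal sketch: the `Engine` class is hypothetical, and a plain `object()` plus a counter stand in for the gloo process group and the real `dist.destroy_process_group` call.

```python
class Engine:
    """Illustrative sketch: idempotent async-checkpoint cleanup wired to __del__."""

    def __init__(self, async_checkpoint: bool):
        # Stand-in for the dedicated gloo process group handle.
        self._async_checkpoint_pg = object() if async_checkpoint else None
        self._async_state_dict_cache: dict | None = {}
        self.destroy_calls = 0

    def destroy_async_checkpoint_pg(self) -> None:
        self._async_state_dict_cache = None
        if self._async_checkpoint_pg is not None:
            self.destroy_calls += 1  # dist.destroy_process_group(pg) in real code
            self._async_checkpoint_pg = None  # guard against double destroy

    def __del__(self):
        # Safe even if cleanup already ran: the None guard makes it idempotent.
        self.destroy_async_checkpoint_pg()


engine = Engine(async_checkpoint=True)
engine.destroy_async_checkpoint_pg()
engine.destroy_async_checkpoint_pg()  # second call is a no-op
print(engine.destroy_calls)  # 1
```

Making the method idempotent matters here because `__del__` may fire after an explicit shutdown call, and destroying a process group twice would be an error in real code.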

Comment thread xtuner/v1/train/trainer.py Outdated
future = self._engine.async_save_dcp(weights_dir=weights_path, save_optimizer=save_optimizer)
t_dcp = time.time() - t_dcp
# Defer metadata save until async save completes.
self._pending_checkpoint = _PendingCheckpoint(
Collaborator

Trainer shouldn't need to know about _CheckpointFinalize. Instead, you can call Future.add_done_callback in TrainEngine so that the future tracks timing correctly. The trainer just needs to wait for it (or await it).
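The callback-based timing the reviewer proposes can be sketched with `concurrent.futures`. The helper name `launch_async_save` is hypothetical, and `time.sleep` stands in for DCP staging plus upload; real code would attach the callback to the future returned by the engine's async save.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor


def launch_async_save(executor: ThreadPoolExecutor, timings: dict) -> Future:
    """Submit a fake save and attach a done-callback that records total time."""
    start = time.perf_counter()
    future = executor.submit(time.sleep, 0.05)  # stand-in for staging + upload

    def _record(fut: Future) -> None:
        # Runs when the save completes, so the measured duration covers the
        # full async save rather than just the launch call on the trainer side.
        timings["async_save_s"] = time.perf_counter() - start

    future.add_done_callback(_record)
    return future


timings: dict = {}
with ThreadPoolExecutor(max_workers=1) as pool:
    fut = launch_async_save(pool, timings)
    fut.result()  # the trainer only waits; the engine owns the timing
print(timings["async_save_s"] >= 0.05)  # True
```

With this shape the trainer never touches timing or finalization state; it just holds one future and waits on it, which is exactly the division of responsibility the review asks for.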

@VincentCheungKokomo VincentCheungKokomo force-pushed the feature/async-checkpoint branch from b6701ef to b8c953d on May 7, 2026 at 09:39
Comment thread xtuner/v1/train/trainer.py Outdated
cur_epoch = self._cur_epoch
train_time_offset = self._train_time + self._train_time_offset

def finalize_checkpoint_metadata() -> None:
Collaborator

Consider keeping the original implementation. Even if dcp hasn't finished saving, it should be fine to save the meta information first. Try to avoid increasing code complexity just to introduce the asynchronous save feature.

Comment thread xtuner/v1/train/trainer.py Outdated
Comment on lines +1245 to +1247
dcp_label = "async_save_dcp"
future = self._engine.async_save_dcp(weights_dir=weights_path, save_optimizer=save_optimizer)
t_dcp = time.time() - t_dcp
Collaborator

The time spent on asynchronous saving should be collected and printed by the engine.

@VincentCheungKokomo VincentCheungKokomo force-pushed the feature/async-checkpoint branch from b8c953d to 1eb91e5 on May 7, 2026 at 10:44
@VincentCheungKokomo VincentCheungKokomo force-pushed the feature/async-checkpoint branch from 1eb91e5 to 962cc16 on May 8, 2026 at 03:19