Add async checkpoint feature #1703
Force-pushed from 7a7136b to 302b6ec
from xtuner.v1.utils.grad_norm import cal_grad_norm

if BlockingAsyncStager is not None:
In [1]: from torch.distributed.checkpoint import FileSystemWriter
In [2]: fw = FileSystemWriter("./")
In [3]: from torch.distributed.checkpoint.staging import AsyncStager, BlockingAsyncStager
In [4]: isinstance(fw, AsyncStager)
Out[4]: True
Is _CachingStagingWriter necessary?
    options=_set_options,
)

def load_dcp_merged(
The state dict format should be consistent with async_save and save. If merged_state_dict performs better, just replace the current implementation.
self._async_checkpoint = async_checkpoint
self._pending_staging_futures: list[Future] | None = None
self._pending_upload_futures: list[Future] | None = None
self._pending_checkpoint_finalize: _CheckpointFinalize | None = None
Following dcp.async_save, the async interface should return an awaitable future. We can assume there is at most one in-flight async save future in the trainer at any time, and the trainer will always wait for the previous async save to finish before issuing a new one.
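A minimal sketch of that contract on the trainer side; only async_save_dcp appears in the diff, while _inflight_save, the _maybe_save signature, shutdown, and the _engine attribute are assumptions for illustration:

from concurrent.futures import Future

class Trainer:
    _inflight_save: Future | None = None  # at most one pending async save

    def _maybe_save(self, weights_path: str, save_optimizer: bool) -> None:
        # Wait for the previous async save before issuing a new one.
        if self._inflight_save is not None:
            self._inflight_save.result()
        self._inflight_save = self._engine.async_save_dcp(
            weights_dir=weights_path, save_optimizer=save_optimizer
        )

    def shutdown(self) -> None:
        # Make sure the last checkpoint actually landed before tearing down.
        if self._inflight_save is not None:
            self._inflight_save.result()
            self._inflight_save = None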
ckpt_saved = self._maybe_save(is_snapshot=False)
if not ckpt_saved:
    _ = self._maybe_save(is_snapshot=True)
checkpoint_time = time.time() - time_before_checkpoint
Just log the checkpoint time in train_engine
Force-pushed from 695d2b3 to b6701ef
self._async_checkpoint_pg: dist.ProcessGroup | None = None
self._async_state_dict_cache: dict[str, Any] | None = None
if async_checkpoint:
    self._async_checkpoint_pg = dist.new_group(backend="gloo")
Please leave a comment to describe why we need a gloo process group here.
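One possible phrasing of that comment; the rationale below is the usual one for dcp.async_save and is an assumption about this PR's intent, not a quote from it:

if async_checkpoint:
    # dcp.async_save stages tensors to CPU and coordinates the write from a
    # background thread. A dedicated gloo (CPU) group keeps that traffic off
    # the default NCCL group, which cannot safely be used from another thread
    # while training collectives are still running on it.
    self._async_checkpoint_pg = dist.new_group(backend="gloo")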
if not hasattr(dcp, "async_save"):
    raise RuntimeError(
        "dcp.async_save is not available in this PyTorch version. "
        "Please upgrade PyTorch or set async_checkpoint=False."
    )
cached_has_optim = "optimizer" in self._async_state_dict_cache
if cached_has_optim != save_optimizer:
    self._async_state_dict_cache = None
When will this branch be triggered?
if cached_has_optim != save_optimizer:
    self._async_state_dict_cache = None
storage_writer = FileSystemWriter(weights_dir, cache_staged_state_dict=True)
storage_writer.state_dict_cache = self._async_state_dict_cache
Is this injection necessary?
cache_staged_state_dict keeps the pinned staging buffers on the FileSystemWriter instance. XTuner creates a new writer for every checkpoint path, so the cache has to be carried across writers to keep steady-state async_save launches fast.
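A minimal sketch of carrying that cache across writers, using the attribute names from the diff; note that state_dict_cache is an internal of BlockingAsyncStager/FileSystemWriter, so relying on it (and on staging completing before async_save returns) is an assumption about the PyTorch versions in use:

# Reuse the pinned staging buffers from the previous checkpoint so the next
# async save does not have to re-allocate and re-pin them.
storage_writer = FileSystemWriter(weights_dir, cache_staged_state_dict=True)
if self._async_state_dict_cache is not None:
    storage_writer.state_dict_cache = self._async_state_dict_cache

future = dcp.async_save(
    state_dict,
    storage_writer=storage_writer,
    process_group=self._async_checkpoint_pg,
)

# Staging is blocking, so the cache is populated by the time async_save
# returns; keep it for the writer created at the next checkpoint path.
self._async_state_dict_cache = storage_writer.state_dict_cache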
def destroy_async_checkpoint_pg(self) -> None:
    """Destroy the dedicated gloo process group used for async checkpoint."""
    self._async_state_dict_cache = None
    if self._async_checkpoint_pg is not None:
        dist.destroy_process_group(self._async_checkpoint_pg)
        self._async_checkpoint_pg = None
future = self._engine.async_save_dcp(weights_dir=weights_path, save_optimizer=save_optimizer)
t_dcp = time.time() - t_dcp
# Defer metadata save until async save completes.
self._pending_checkpoint = _PendingCheckpoint(
Trainer shouldn't need to know about _CheckpointFinalize. Instead, you can call Future.add_done_callback in TrainEngine so that the future tracks timing correctly. The trainer just needs to wait for it (or await it).
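A rough sketch of that suggestion, with async_save_dcp from the diff and hypothetical stand-ins (_collect_state_dict, logger) for the real engine internals:

import time
import torch.distributed.checkpoint as dcp

class TrainEngine:
    def async_save_dcp(self, weights_dir, save_optimizer: bool = True):
        state_dict = self._collect_state_dict(save_optimizer)  # hypothetical helper
        start = time.monotonic()
        future = dcp.async_save(
            state_dict,
            checkpoint_id=weights_dir,
            process_group=self._async_checkpoint_pg,
        )

        def _log_duration(fut):
            fut.result()  # surface any background-save error here
            logger.info(f"async_save_dcp finished in {time.monotonic() - start:.1f}s")

        # The callback runs in the checkpoint thread once the save completes,
        # so the reported time covers the full background save, not just launch.
        future.add_done_callback(_log_duration)
        return future  # the trainer only needs to wait on (or await) this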
Force-pushed from b6701ef to b8c953d
cur_epoch = self._cur_epoch
train_time_offset = self._train_time + self._train_time_offset

def finalize_checkpoint_metadata() -> None:
Consider keeping the original implementation: I think even if dcp hasn't finished saving, it should be fine to save the meta information first. Try to avoid increasing code complexity just to introduce the asynchronous save feature.
dcp_label = "async_save_dcp"
future = self._engine.async_save_dcp(weights_dir=weights_path, save_optimizer=save_optimizer)
t_dcp = time.time() - t_dcp
The asynchronous save time should be collected and printed by the engine.
Force-pushed from b8c953d to 1eb91e5
Force-pushed from 1eb91e5 to 962cc16
Add async DCP checkpoint support
This change adds async checkpoint saving for XTuner v1 training. The trainer
now supports an async_checkpoint option, starts merged async DCP saves for model
and optimizer state, and defers checkpoint metadata finalization until the
background staging/upload futures complete.
The async path writes model and optimizer state into a merged weights/
checkpoint format, while resume keeps compatibility with both the new merged
format and the existing model/optimizer DCP format. Checkpoint metadata is only
registered after async save completion, so failed async saves are not exposed as
resumable checkpoints.
The training engine now creates a dedicated process group for async checkpoint
work, supports merged async save/load helpers, and cleans up the async process
group at trainer shutdown.
Tests and benchmark configs are added to cover async checkpoint intervals and
provide reproducible verification runs for 8B and 30B models.