
Add Feature Universal Checkpoint for AutoTP #7908

Open
nathon-lee wants to merge 15 commits into deepspeedai:master from nathon-lee:feat_uc_autotp

Conversation

@nathon-lee
Contributor

Hi DeepSpeed team — thanks for your time reviewing this PR.

Summary

Add Universal Checkpoint (UC) metadata support for DeepSpeed AutoTP to enable saving and resuming from Universal Checkpoints.

Motivation

AutoTP partitions parameters across TP ranks. To make checkpoints portable and restorable, we need a stable UC metadata representation that can be collected at save time and consumed at restore time.

What’s in this PR

  • Collect AutoTP-specific Universal Checkpoint metadata for TP-partitioned parameters.
  • Provide restore/merge helpers that normalize shapes and correctly interpret the saved conversion/partition view.
  • Keep existing (non-AutoTP / non-UC) checkpoint paths unchanged (no behavior change expected for other users).
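
To make the schema discussion concrete, here is a minimal sketch of what the per-parameter metadata could look like. The field names (partition_type, partition_dim, logical_shape, output_shape, original_shape, replicated, is_bias) are taken from the diff hunks quoted in the review comments below; make_uc_meta itself is a hypothetical helper, and the exact schema is defined in the PR diff.

```python
def make_uc_meta(partition_type, partition_dim, logical_shape,
                 output_shape=None, original_shape=None,
                 replicated=False, is_bias=False):
    """Illustrative: the metadata dict attached to a TP-partitioned parameter."""
    return {
        'partition_type': partition_type,   # 'row' or 'column'
        'partition_dim': partition_dim,     # None when the tensor is replicated
        'logical_shape': logical_shape,     # full, unsharded shape
        'output_shape': output_shape,       # e.g. (out_dim,) for a bias
        'original_shape': original_shape,   # shape before partitioning
        'replicated': replicated,           # True for biases copied to all ranks
        'is_bias': is_bias,
    }

# A column-parallel 4096x1024 weight partitioned along dim 0:
meta = make_uc_meta('column', 0, (4096, 1024), original_shape=(4096, 1024))
assert meta['partition_dim'] == 0
```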

Testing

  • pytest -q tests/unit/runtime/tensor_parallel/test_autotp_universal_checkpoint.py
  • pytest -q tests/unit/checkpoint/test_autotp_universal_checkpoint.py

Request for feedback

Could you please take a look at the UC metadata schema and let me know if you’d prefer any changes to naming, field placement, or compatibility expectations? I’m happy to iterate quickly based on your guidance.


Copilot AI and others added 10 commits February 27, 2026 06:30

  • Revert "fix: update 1 file reformatted." (reverts commit ff88670)
  • Revert accidental Muon optimizer code re-introduction from copilot PRs

Co-authored-by: nathon-lee <248585198+nathon-lee@users.noreply.github.com>
Signed-off-by: nathon-lee <leejianwoo@gmail.com>
@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5ab684d950


Comment on lines +462 to +464
conversion_meta = _get_param_uc_conversion_meta(param)
if not conversion_meta:
    continue


P1 Badge Include sharded embeddings in UC metadata collection

AutoTP shards nn.Embedding weights in _slice_embedding() along the hidden dimension (dim=1), but this collector skips any parameter without ds_autotp_universal_checkpoint_meta, so embedding parameters never enter UNIVERSAL_CHECKPOINT_INFO. During conversion they then fall back to the default merge behavior (cat_dim=0 in ds_to_universal.merge_tp_slices), which reconstructs embeddings along the wrong axis and breaks restores when the TP degree changes.

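The axis mismatch described in this finding can be shown with a pure-Python toy. Lists stand in for tensors, and merge_dim0/merge_dim1 are illustrative helpers, not DeepSpeed code.

```python
def merge_dim0(shards):
    # Default merge (cat_dim=0): stacks shards vertically -> wrong vocab size.
    return [row for shard in shards for row in shard]

def merge_dim1(shards):
    # Correct merge for dim=1-sharded embeddings: concatenate each row's slices.
    return [sum((shard[i] for shard in shards), []) for i in range(len(shards[0]))]

# vocab=2, hidden=4, TP=2: each shard holds hidden/2 columns.
shard0 = [[1, 2], [5, 6]]
shard1 = [[3, 4], [7, 8]]

assert merge_dim1([shard0, shard1]) == [[1, 2, 3, 4], [5, 6, 7, 8]]  # (2, 4): correct
assert len(merge_dim0([shard0, shard1])) == 4                        # (4, 2): vocab doubled
```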

Comment on lines +732 to +737
self._set_param_uc_meta(self.weight,
                        partition_type='column',
                        partition_dim=0,
                        logical_shape=original_weight_shape,
                        output_shape=(original_out_dim, ),
                        original_shape=original_weight_shape)


P1 Badge Mark fused-QKV layers with sub-parameter UC metadata

fused_LinearLayer partitions weights with prepare_tp_fused_qkvw() (model-specific Q/K/V reordering), but it inherits this generic LinearLayer metadata path, which records only plain column partitioning and no sub-parameter schema. Converter logic then treats these tensors as simple concat-able shards, producing interleaved QKV layouts that are not portable across different TP sizes.

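A pure-Python toy of the interleaving problem this finding describes. merge_subparams is a hypothetical helper, not DeepSpeed API; 1-D lists stand in for the row dimension of a fused QKV weight.

```python
# Q has 4 rows, K and V have 2 each; TP=2 packs each rank's Q/K/V slices back-to-back.
Q, K, V = list(range(0, 4)), list(range(10, 12)), list(range(20, 22))

rank0 = Q[:2] + K[:1] + V[:1]   # [0, 1, 10, 20]
rank1 = Q[2:] + K[1:] + V[1:]   # [2, 3, 11, 21]

# Naive merge treats the shards as one column-parallel block -> interleaved QKV:
naive = rank0 + rank1
assert naive != Q + K + V

def merge_subparams(shards, sizes_per_rank):
    """Re-join each sub-param (Q, then K, then V) across ranks before concatenating."""
    out = []
    for idx in range(len(sizes_per_rank)):
        start = sum(sizes_per_rank[:idx])
        end = start + sizes_per_rank[idx]
        out += [x for shard in shards for x in shard[start:end]]
    return out

assert merge_subparams([rank0, rank1], [2, 1, 1]) == Q + K + V
```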

@nathon-lee nathon-lee changed the title from "Add Feature Universal Checkpoint autotp" to "Add Feature Universal Checkpoint for AutoTP" on Mar 17, 2026
@PawnOfDelock PawnOfDelock left a comment

📝 Review for PR #7908: AutoTP Universal Checkpoint Support

✅ Overall Impression

This is a well-structured and well-documented PR that adds critical functionality for AutoTP universal checkpoint support. The implementation is clean, tests are comprehensive, and CI passes.


🔍 Code Review

1. Design & Architecture 🎯

  • ✅ Good separation of concerns between restore-time and conversion-time metadata
  • ✅ The dual-view metadata design (top-level for restore, nested conversion for model-level aggregation) is elegant
  • ✅ Backward compatibility is properly preserved (non-AutoTP paths unchanged)

2. Implementation Details 💻

universal_checkpoint.py:

  • _resolve_autotp_partition handles various partition scenarios correctly (row/column, sub-params, replicated)
  • ✅ Clean integration with existing load_hp_checkpoint_state - minimal changes to existing logic
  • ✅ Proper shape normalization and error handling

layers.py:

  • ✅ Consistent _mark_uc_metadata implementation across all TP layer types
  • collect_autotp_universal_checkpoint_info properly aggregates parameter-level metadata into model-level schema
  • ✅ Regex pattern generation is clean and correct

Optimizers (bf16_optimizer.py, stage_1_and_2.py):

  • _enable_universal_checkpoint properly caches UC info from parameters
  • ✅ State dict integration is consistent across both optimizer types

engine.py:

  • ✅ Properly collects and attaches UC info to model after AutoTP partitioning
  • ✅ Checkpoint saving includes UC info

3. Testing 🧪

  • ✅ Test coverage is excellent - both unit and integration tests
  • ✅ Tests cover key scenarios:
    • Row/column parallel weights
    • Subparam partitioning
    • Replicated biases
    • Metadata aggregation
    • Optimizer state handling
  • ✅ Mocking is well-implemented

4. Documentation 📖

  • ✅ Function docstrings are clear and helpful
  • ✅ Code comments explain non-obvious logic
  • ✅ PR description is comprehensive with motivation, implementation details, and testing instructions

🔮 Suggestions & Questions

  1. Edge Cases:

    • Consider adding a test for empty sub_param_sizes scenario in _resolve_autotp_partition
    • Could add a test for mismatched logical_shape vs output_shape expectations
  2. Performance:

    • The metadata collection happens at partition time (good)
    • Consider if collect_autotp_universal_checkpoint_info should be cached if called multiple times
  3. Error Handling:

    • What happens if AUTOTP_UC_META_KEY exists but is malformed? Could add validation
  4. Future Compatibility:

    • The schema design is flexible for future additions - good job

⚠️ Minor Nitpicks

  1. In universal_checkpoint.py, the function signature
     def _resolve_autotp_partition(self, ckpt_dict, full_hp_param, tp_rank, tp_world_size):
     takes self, but _resolve_autotp_partition is not a method; consider making this clearer.
  2. In layers.py, could the regex pattern generation be a utility function to avoid repetition?

These are very minor and don't block approval.
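For the second nitpick, one possible shape for such a utility. uc_param_pattern is a hypothetical name and pattern; the actual regexes generated in layers.py may differ.

```python
import re

def uc_param_pattern(module_name, param_name):
    """Build a regex that matches exactly this module's parameter name,
    escaping dots in the module path so they match literally."""
    return re.compile(rf"^{re.escape(module_name)}\.{re.escape(param_name)}$")

pat = uc_param_pattern("model.layers.0.self_attn.q_proj", "weight")
assert pat.match("model.layers.0.self_attn.q_proj.weight")
assert not pat.match("model.layers.10.self_attn.q_proj.weight")
```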


✅ CI Status

  • ✅ All CI checks pass
  • ✅ Tests run successfully
  • ✅ DCO check passes

🎯 Recommendation

LGTM! 🚀

This PR is ready for merge. The implementation is solid, tests are comprehensive, and it properly integrates with existing checkpoint infrastructure.

Suggested follow-ups (post-merge):

  1. Add e2e integration tests with actual model training/saving/loading
  2. Document the UC metadata schema in user-facing docs
  3. Consider adding a migration guide for existing AutoTP checkpoints

from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS, SUB_PARAM_SHAPE)


AUTOTP_UC_META_KEY = 'ds_autotp_universal_checkpoint_meta'
Collaborator

This is not the only place that references the string 'ds_autotp_universal_checkpoint_meta'. The other place is in layers.py, under a different name, DS_AUTOTP_UC_META; the naming is inconsistent (DS_AUTOTP_UC_META looks better). If the string needs to be the same everywhere, it should have a single definition (i.e. in this file) that the other places import (i.e. from deepspeed.checkpoint.universal_checkpoint import DS_AUTOTP_UC_META). There shouldn't be a second place with the content of the string, including tests.

    return getattr(param, AUTOTP_UC_META_KEY, None)


def _resolve_autotp_partition(self, ckpt_dict, full_hp_param, tp_rank, tp_world_size):
Collaborator

The parameter name self is not appropriate for a function that is not a class member (this is also true for load_hp_checkpoint_state, which is already there). If self means the current parameter, can we change its name to current_param or another proper name? Thanks!

is_bias = meta.get('is_bias', False)
replicated = meta.get('replicated', False)

if replicated or partition_dim is None:
Collaborator

Is there a scenario where one of them is True and the other is False? It would probably be better to write the following instead:

if replicated:
    assert(partition_dim is None)
    ...

if replicated or partition_dim is None:
    slice_tensor = full_hp_param
else:
    target_shape = output_shape if is_bias else logical_shape
Collaborator

This line implies that output_shape can differ from logical_shape when is_bias is True. Will this really happen?

if target_shape is None:
    return None

full_view = full_hp_param.view(target_shape)
Collaborator

full_view is not used in this if statement; better to move it down, just before the next if statement.

    return None

full_view = full_hp_param.view(target_shape)
if sub_param_sizes is None and sub_param_shape is not None:
Collaborator

Is there a case where sub_param_sizes cannot be deduced from sub_param_shape, so that the metadata must save sub_param_sizes rather than sub_param_shape? One of them feels redundant.
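If sub_param_shape packs the per-sub-param extents along the concatenation dimension first, as in the test case discussed later in the thread (sub_param_shape = ((2, 2, 2), 4)), the sizes could indeed be deduced. A sketch under that assumption, which may not match the PR's actual encoding:

```python
def sizes_from_sub_param_shape(sub_param_shape):
    """Deduce per-sub-param sizes along the concat dim from the packed shape.
    Assumes the first entry lists each sub-param's extent on that dim."""
    cat_dim_extents = sub_param_shape[0]
    return tuple(cat_dim_extents)

assert sizes_from_sub_param_shape(((2, 2, 2), 4)) == (2, 2, 2)
assert sizes_from_sub_param_shape(((4096, 1024, 1024), 128)) == (4096, 1024, 1024)
```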


for module_name, module in model.named_modules():
    marker = getattr(module, '_mark_uc_metadata', None)
    if callable(marker):
Collaborator

Is it necessary to check whether marker is callable? I don't see how that check could fail unless there is an implementation error.

Contributor Author

Good point. _mark_uc_metadata is intended to be a callable hook; using callable() could hide an implementation error. I’ll remove the callable() guard and simply check for None so incorrect types fail fast.

@@ -0,0 +1,241 @@
import types
Contributor

Thanks for the PR! The implementation looks good overall. Could you please add some documentation and usage descriptions?

Contributor Author

Thanks @inkcherry! Sure — I’ll add documentation and usage descriptions and update the PR shortly.

Signed-off-by: nathon-lee <leejianwoo@gmail.com>
@nathon-lee
Contributor Author

Hi Delock, thanks for the suggestion. I’ve centralized the AutoTP universal-checkpoint metadata attribute name into deepspeed/checkpoint/constants.py as DS_AUTOTP_UC_META, and updated both universal_checkpoint.py and module_inject/layers.py to import and use it (so we no longer hardcode the string in multiple places).

Please let me know if you’d prefer I split the other _resolve_autotp_partition changes into a separate PR to keep this one focused.
@delock

Signed-off-by: nathon-lee <leejianwoo@gmail.com>


def _write_tp_states(base_dir, param_name, tp_idx, fp32_tensor):
    # merge_tp_slices will try to merge these three states, so the test must write all of them out
Collaborator

This might not matter, but I would suggest keeping comments in plain English for consistency.

partition_dim=0,
name="packed")

weight_meta = getattr(layer.weight, "ds_autotp_universal_checkpoint_meta")
Collaborator

I would suggest importing these strings rather than hard-coding them.

chunks = [sub.chunk(2, dim=0)[0] for sub in full_hp_param.view(3, 2, 4)]
expected = torch.cat(chunks, dim=0).flatten()
assert torch.equal(slice_flat, expected)

Collaborator

Is it possible to add a test where the sub-param sizes are not equal to each other? This would cover the GQA case.

  sub_param_sizes describes the case where multiple sub-parameters have unequal sizes along the concatenation dimension.

  The typical scenario is a fused QKV linear where a model's Q/K/V head counts differ:
  Q: 32 heads × 128 dim = 4096
  K: 8 heads × 128 dim = 1024
  V: 8 heads × 128 dim = 1024
  Concatenated, sub_param_shape = (4096, 1024, 1024), 6144 in total.

  When slicing for TP, naively dividing 6144 into N equal parts would cut through the Q/K/V boundaries. The correct approach is to split each sub-parameter evenly on its own:
  At TP=2:
    Q: 4096/2 = 2048
    K: 1024/2 = 512
    V: 1024/2 = 512
    Each rank gets 3072.

  sub_param_sizes tells the restore logic how large each sub-parameter is, so the slicing can be done per sub-parameter instead of evenly over the whole tensor.

  In _resolve_autotp_partition, when sub_param_sizes is not None, the code chunks/splits each sub-parameter separately, takes the slice for the given tp_rank, and concatenates the pieces back together. When sub_param_sizes is None (all sub-parameters are the same size, or there are no sub-parameters at all), a plain whole-tensor chunk is enough.

  So the missing test case is: when sub-parameter sizes are unequal, does restore split each one by its own size correctly? Existing test 2 uses sub_param_shape = ((2, 2, 2), 4), where the three sub-parameters are equal in size (all 2), so the unequal-size case is not covered.
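The per-sub-param slicing described above can be sketched in a few lines. tp_slice_subparams is an illustrative helper, not the PR's implementation; a 1-D list stands in for the flattened fused parameter.

```python
def tp_slice_subparams(full, sub_param_sizes, tp_rank, tp_world_size):
    """Split each sub-param evenly across ranks, then concat this rank's pieces."""
    out, offset = [], 0
    for size in sub_param_sizes:
        assert size % tp_world_size == 0, "each sub-param must divide evenly"
        per_rank = size // tp_world_size
        start = offset + tp_rank * per_rank
        out += full[start:start + per_rank]
        offset += size
    return out

# Q=4, K=2, V=2 elements (stand-ins for 4096/1024/1024), TP=2:
full = ['q0', 'q1', 'q2', 'q3', 'k0', 'k1', 'v0', 'v1']
assert tp_slice_subparams(full, [4, 2, 2], tp_rank=0, tp_world_size=2) == ['q0', 'q1', 'k0', 'v0']
assert tp_slice_subparams(full, [4, 2, 2], tp_rank=1, tp_world_size=2) == ['q2', 'q3', 'k1', 'v1']
```

Naively chunking the same list into two halves would instead give rank 0 ['q0', 'q1', 'q2', 'q3'], cutting through the Q/K/V boundaries.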

@delock
Collaborator

delock commented Mar 18, 2026

Hi @nathon-lee , thanks for the timely response! No need to split.

I have given my comments. Overall it LGTM. I agree with @inkcherry that documentation better be updated, specifically, universal-checkpointing.md, autotp-training.md, and model-checkpointing.rst.


Signed-off-by: nathon-lee <leejianwoo@gmail.com>
Signed-off-by: nathon-lee <leejianwoo@gmail.com>

fix: update some logic for _resolve_autotp_partition

Signed-off-by: nathon-lee <leejianwoo@gmail.com>

fix: update some logic for test_load_hp_checkpoint_state_prefers_autotp_metadata

Signed-off-by: nathon-lee <leejianwoo@gmail.com>
@nathon-lee
Contributor Author

Hi @delock, thank you for the review and LGTM!

Totally agree — I’ll update universal-checkpointing.md, autotp-training.md, and model-checkpointing.rst accordingly and push the doc changes shortly.
@delock

Signed-off-by: nathon-lee <leejianwoo@gmail.com>

docs: update universal checkpointing and AutoTP checkpoint docs

Signed-off-by: nathon-lee <leejianwoo@gmail.com>