TP refactor for FSDP + TP integration (#45028)
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed fcea5ce to f98e208
- DtensorShardOperation for range-math shard-on-read
- spawn_materialize() enhancements
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests
Force-pushed 607cc11 to 739332c
- Replace hook-based TP with DTensor-based TPStyle API
- TPStyle dataclass with dense kinds: colwise, rowwise, vocab
- apply_tensor_parallel() using PyTorch parallelize_module
- verify_tp_plan() for plan validation
- Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle
- DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3
- Extended DistributedConfig with tp/fsdp size and plan fields
- DistributedConfig serialization in configuration_utils
- MXFP4 NotImplementedError for DTensor TP
- Dense TP tests
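To make the commit message above concrete, here is a minimal sketch of what a TPStyle-based plan might look like. This is purely illustrative: the field names (`kind`, `comm`) and the stand-in dataclass are assumptions, not the actual transformers API.

```python
from dataclasses import dataclass


# Hypothetical stand-in for the TPStyle dataclass described in the
# commit message; field names are assumptions, not the real API.
@dataclass(frozen=True)
class TPStyle:
    kind: str          # e.g. "colwise", "rowwise", "vocab"
    comm: str = "none" # e.g. "allgather", "reduce_scatter", "loss_parallel"


# A dense-model TP plan maps module-path patterns to styles.
tp_plan = {
    "layers.*.self_attn.q_proj": TPStyle("colwise"),
    "layers.*.self_attn.o_proj": TPStyle("rowwise"),
    "embed_tokens": TPStyle("vocab"),
    "lm_head": TPStyle("colwise", "allgather"),
}

print(tp_plan["lm_head"].comm)  # allgather
```

A plan like this would then be handed to something like `apply_tensor_parallel()`, which dispatches on `kind` and `comm` per module.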
Force-pushed 1aa7f5f to 11b55a2
Force-pushed dbc9619 to c567240
Force-pushed 34a5085 to eb428cc
Force-pushed c567240 to c1dab9e
Force-pushed eb428cc to e0c4e06
* MoE expert parallelism + sequence parallelism
  - Add PackedColwiseParallel for fused gate_up_proj weights
  - Add MoEExpertsParallel with per-expert DTensor sharding
  - Add PrepareModuleInputOutput for SP allgather/split hooks
  - Add _AllReduceBackward for MoE routing weight gradients
  - Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
  - _StridedShard handling in core_model_loading for interleaved weights
  - MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
  - DTensor rotary_pos_emb guard for mixtral
* Fix ruff linting and formatting
* Fix ruff formatting in core_model_loading.py
* Restore _IdentityOp accidentally removed in 25a1f48
  The _IdentityOp class (added by PR #44983) was accidentally deleted during the MoE expert parallelism work. It is needed by finegrained_fp8.py and metal_quantization.py as a pass-through reverse_op for dequantize operations.
  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Backport new TP/FSDP API + fix DTensor imports in Copied-from models
* from_pretrained orchestration + distributed save/load (#45409)
  * from_pretrained orchestration + save/load
    - Add gather_full_state_dict() for DTensor→full tensor saving
    - Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
    - Add _redistribute_dtensor() helper
    - Full distributed_config integration in from_pretrained/save_pretrained
    - Rename apply_fsdp2 → apply_fully_shard_data_parallel
    - save_optimizer() / load_optimizer() in distributed/utils
    - Trainer integration with distributed_config
    - Updated FSDP and TP tests for new orchestration API
    - DTensor shard-on-read test updates
  * revert distributed utils
  * eaaea
  * all tests for core modeling are passing
  * populate import from init for tp
  * ruff
  * ruff
  ---------
  Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
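The commits above mention `gather_full_state_dict()` for DTensor→full tensor saving. A torch-free sketch of the idea follows: DTensor exposes a `full_tensor()` method that all-gathers the shards into a full tensor, so the gather step can duck-type on it. The `FakeDTensor` stub below is purely illustrative and only imitates that surface.

```python
# Torch-free sketch of gathering a state dict of (possibly sharded)
# tensors into full tensors before saving. Real DTensors expose
# full_tensor(); FakeDTensor below only imitates that method.
def gather_full_state_dict(state_dict):
    out = {}
    for name, value in state_dict.items():
        if hasattr(value, "full_tensor"):  # DTensor-like: materialize it
            out[name] = value.full_tensor()
        else:                              # already a plain tensor
            out[name] = value
    return out


class FakeDTensor:
    """Illustrative stand-in: pretends a 2-way shard exists."""

    def __init__(self, local_shard):
        self.local_shard = local_shard

    def full_tensor(self):
        # Pretend the all-gather concatenates two identical shards.
        return self.local_shard + self.local_shard


sd = {"w": FakeDTensor([1, 2]), "b": [0, 0]}
full = gather_full_state_dict(sd)
print(full)  # {'w': [1, 2, 1, 2], 'b': [0, 0]}
```

On a real mesh this runs a collective per parameter, so the actual implementation presumably streams or frees shards as it goes.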
```python
    return


# Filter out module-level comm hooks — they don't shard weights
_NON_WEIGHT_KINDS = {"activation", "module"}
```
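The `_NON_WEIGHT_KINDS` filter above suggests that plan entries of kind `"activation"` or `"module"` only install communication hooks and never shard parameters. A small sketch of how such a filter might be applied when walking a plan; the `TPStyle` dataclass here is an illustrative stand-in, not the real class.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TPStyle:  # illustrative stand-in only
    kind: str
    comm: str = "none"


# Module-level comm hooks don't shard weights, so skip them when
# deciding which parameters to distribute.
_NON_WEIGHT_KINDS = {"activation", "module"}

plan = {
    "layers.*.input_layernorm": TPStyle("activation"),
    "layers.*.self_attn": TPStyle("module", "allgather"),
    "layers.*.self_attn.q_proj": TPStyle("colwise"),
}

weight_entries = {k: v for k, v in plan.items() if v.kind not in _NON_WEIGHT_KINDS}
print(sorted(weight_entries))  # ['layers.*.self_attn.q_proj']
```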
maybe separate TP and SP styles?
Restores modeling files to their base branch versions so the PR diff only shows the distributed/patches.py monkey-patch approach instead of noisy function moves in modeling files. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Convert string plan values ("colwise", "rowwise", etc.) to TPStyle
objects across 66+ model configs and modular files
- Consolidate MoE expert sub-entries into TPStyle("moe_experts", ...)
with shard_plan
- Remove "replicated_with_grad_allreduce" entries (not needed for
DTensor TP)
- Migrate _tp_plan class attributes in modeling files from
"colwise_gather_output" to TPStyle("colwise", "allgather")
- Add TypeError in apply_tensor_parallel for unsupported plan values
- Remove old TensorParallelLayer tests (API removed in DTensor refactor)
- Regenerate auto-generated files via modular converter
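The migration above converts legacy string plan values into TPStyle objects and raises a TypeError for anything unsupported. A sketch of that normalization step; the mapping, the stand-in `TPStyle` dataclass, and the helper name are assumptions for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TPStyle:  # illustrative stand-in only
    kind: str
    comm: str = "none"


# How legacy string plan values might map onto TPStyle objects;
# "colwise_gather_output" becomes an explicit allgather comm.
_LEGACY_STRINGS = {
    "colwise": TPStyle("colwise"),
    "rowwise": TPStyle("rowwise"),
    "vocab": TPStyle("vocab"),
    "colwise_gather_output": TPStyle("colwise", "allgather"),
}


def normalize_plan_value(value):
    if isinstance(value, TPStyle):
        return value
    if isinstance(value, str) and value in _LEGACY_STRINGS:
        return _LEGACY_STRINGS[value]
    # Mirrors the TypeError described in the commit message.
    raise TypeError(f"Unsupported tp_plan value: {value!r}")


print(normalize_plan_value("colwise_gather_output"))  # TPStyle(kind='colwise', comm='allgather')
```

Entries like `"replicated_with_grad_allreduce"` have no TPStyle equivalent, which is why the commit removes them rather than mapping them.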
```python
    return model


if isinstance(fsdp_plan, str):
    if fsdp_plan == "auto":
```
Define fsdp_plan in every model and remove the auto code path (will be done in the next PR).
These string plan values have no TPStyle equivalent in the DTensor system. Remove them to avoid TypeError at apply_tensor_parallel time. Affected models: deepseek_v2, glm4_moe_lite, glm_moe_dsa, longcat_flash.
```python
# loss_parallel patches F.cross_entropy to work with Shard(-1) logits.
# It must be active during both forward and backward, so we enable it
# once rather than as a context manager.
has_loss_parallel = any(isinstance(v, TPStyle) and v.comm == "loss_parallel" for v in tp_plan.values())
```
```python
def _inject_sp_metadata(mod, args, kwargs):
    input_ids = kwargs.get("input_ids", args[0] if args else None)
    if input_ids is None:
        return args, kwargs
    if "position_ids" not in kwargs or kwargs["position_ids"] is None:
        seq_len = input_ids.shape[1]
        kwargs["position_ids"] = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
    return args, kwargs
```
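The hook above fills in default position_ids when the caller didn't pass any. A torch-free analogue of the same control flow, with nested lists standing in for `(batch, seq)` tensors, so the default-injection logic can be exercised without a GPU or torch install:

```python
# Torch-free analogue of _inject_sp_metadata: if the caller didn't
# pass position_ids, synthesize the default [0, 1, ..., seq_len - 1].
def inject_sp_metadata(args, kwargs):
    input_ids = kwargs.get("input_ids", args[0] if args else None)
    if input_ids is None:
        return args, kwargs
    if kwargs.get("position_ids") is None:
        seq_len = len(input_ids[0])  # input_ids is (batch, seq)
        kwargs["position_ids"] = [list(range(seq_len))]
    return args, kwargs


_, kw = inject_sp_metadata((), {"input_ids": [[5, 6, 7]]})
print(kw["position_ids"])  # [[0, 1, 2]]
```

Caller-supplied position_ids pass through untouched; only the `None`/missing case is filled in.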
ParallelStyle with just prepare inputs
SequenceParallelInput(Parallel) -> I need to make this explicit, and it already exists
Sure, but that would mean specifying an empty-string key in the base_model_sp_plan, since it's not at module level and more at the base model:

```python
base_model_sp_plan = {
    "": SequenceParallelInput(),
    "embed_tokens": TPStyle("vocab", "reduce_scatter"),
    "layers.*.input_layernorm": TPStyle("activation", "none"),
    "layers.*.self_attn": TPStyle("module", "allgather", input_key="hidden_states"),
    "layers.*.self_attn.q_proj": TPStyle("colwise", "none"),
    # ...
}
```
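Plan keys like `"layers.*.self_attn.q_proj"` above are glob-style patterns over module paths. A sketch of how such keys might be resolved against a concrete module path; that the real matcher uses `fnmatch` semantics is an assumption for illustration.

```python
from fnmatch import fnmatch

# Sketch: resolve which plan pattern applies to a concrete module path.
# fnmatch's "*" matches any characters, including dots, which is enough
# for layer-index wildcards like "layers.*.self_attn.q_proj".
plan_keys = [
    "embed_tokens",
    "layers.*.input_layernorm",
    "layers.*.self_attn.q_proj",
]


def find_plan_key(module_path):
    for pattern in plan_keys:
        if fnmatch(module_path, pattern):
            return pattern
    return None


print(find_plan_key("layers.3.self_attn.q_proj"))  # layers.*.self_attn.q_proj
```

An empty-string key like the `""` entry above would need special-casing (it names the base model itself, not a glob).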
```python
import sys
from functools import wraps

from torch.distributed.tensor import DTensor, Replicate
```
let's discuss: https://github.com/huggingface/transformers/blob/7f49ecc51cacc6b9b60151ebb6a32e66eb71d163/src/transformers/models/llama/modeling_llama.py#L135
```diff
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, output_layouts=None):
```
Check, for TP (on the expert dim), that FSDP when slicing expert ids does gather hidden_dim, since that's what it's expected to do.
```python
def is_dtensor_like(value: Any) -> bool:
    return all(hasattr(value, attr) for attr in ("device_mesh", "placements", "to_local"))


class DtensorShardOperation:
```
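`is_dtensor_like()` above checks for DTensor's surface via duck typing rather than importing the class. A quick demonstration with a stub, no torch required; `StubDTensor` is purely illustrative:

```python
from typing import Any


# Duck-typed check for DTensor-ness, as in the snippet above: a DTensor
# exposes device_mesh, placements, and to_local().
def is_dtensor_like(value: Any) -> bool:
    return all(hasattr(value, attr) for attr in ("device_mesh", "placements", "to_local"))


class StubDTensor:  # illustrative stand-in only
    device_mesh = None
    placements = ()

    def to_local(self):
        return "local shard"


print(is_dtensor_like(StubDTensor()))  # True
print(is_dtensor_like([1, 2, 3]))      # False
```

Duck typing keeps the helper importable when torch.distributed isn't available, at the cost of accepting anything with those three attributes.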
Drop the shard-on-read approach. We should rely on the DTensor distribute_tensor API for loading; that will avoid us doing any slicing because of shard-on-read.
…rsions (e.g. _StridedShard↔Shard). We force replicate beforehand.
[For maintainers] Suggested jobs to run (before merge):
run-slow: afmoe, apertus, arcee, aria, audioflamingo3, bamba, bitnet, cohere
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45028&sha=f710f0
verify_all_loss -> training with saving + loading back (for generate?)
Verify loading
Verify training