Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713
cspades wants to merge 14 commits into NVIDIA:main from cspades:cye/fsdp2-tp-dcp
Conversation
Greptile Summary: This PR adds Torch DCP (Distributed Checkpoint) compatibility for FSDP2 × TP strided sharding across all `TransformerEngineBaseModule`s. Key changes and observations:
Confidence Score: 4/5
/te-ci L1 pytorch
/te-ci L1 pytorch
For some reason, after ~2.3k training steps I start to get NaNs: https://wandb.ai/nvidia/bionemo-recipes/runs/nmzugu0a?nw=nwusercye_nv. Restarting from this checkpoint, the same thing happens around 500 steps later.
/te-ci L1 pytorch
> module.reset_parameters()
> # Run a training step to initialize FSDP2 lazy state and update quantization
> # scales before testing the allgather. Block-scaling formats (Float8BlockScaling,
I believe `Float8BlockScaling` allgather should work now, right?
> input_data = torch.randn(inp_shape, device=device)
> target = torch.randn(inp_shape, device=device)
> nvfp4_ctx = (
>     torch.autocast(device_type="cuda", dtype=torch.bfloat16)
Why a separate nvfp4 context? In general, adding multiple context managers adds CPU overhead to the training loop.
> @dataclass
> class AppState(Stateful):
This seems like a useful class, with some things like extra state specific to TE. Might make sense to move it to the TE distributed module. Thoughts @cspades?
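For reference, a minimal stand-alone sketch of the contract such an `AppState` would satisfy (the method names match `torch.distributed.checkpoint.stateful.Stateful`; the `_Dummy` stand-in is hypothetical and torch-free so the shape of the wrapper is clear):

```python
from dataclasses import dataclass


@dataclass
class AppState:
    """Sketch of a DCP-style application-state wrapper: DCP calls
    state_dict()/load_state_dict() on any object exposing them (the
    Stateful protocol), letting one object bundle model + optimizer."""

    model: object
    optimizer: object

    def state_dict(self):
        return {
            "model": self.model.state_dict(),
            "optim": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.optimizer.load_state_dict(state_dict["optim"])


class _Dummy:
    """Hypothetical stand-in implementing the state_dict protocol."""

    def __init__(self):
        self.state = {"w": 0}

    def state_dict(self):
        return dict(self.state)

    def load_state_dict(self, sd):
        self.state = dict(sd)


app = AppState(model=_Dummy(), optimizer=_Dummy())
app.model.state["w"] = 3
snapshot = app.state_dict()
restored = AppState(model=_Dummy(), optimizer=_Dummy())
restored.load_state_dict(snapshot)
```

A real version would hold the TE module and optimizer plus any TE-specific extra state, which is what motivates moving it into the TE distributed module.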
> # TransformerEngine uses empty Tensors for dummy Parameters.
> optimizer_state_dict["state"][fqn] = {}
> if fqn.endswith("_extra_state"):
>     # Evict `_extra_state` quantization data from model checkpoint.
If this is evicted, how do we make sure it is updated correctly after loading from a checkpoint?
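A minimal pure-Python sketch of the eviction being discussed (the helper name is hypothetical): walk a flat state dict and drop keys ending in `_extra_state`, on the assumption that the quantizer metadata they hold is regenerated from the recipe after load rather than restored from the checkpoint.

```python
def evict_extra_state(model_state_dict):
    """Hypothetical helper mirroring the checkpoint filtering above: drop
    TE `_extra_state` entries (serialized quantizer metadata) from a flat
    state dict, keeping only real parameters/buffers."""
    return {
        fqn: value
        for fqn, value in model_state_dict.items()
        if not fqn.endswith("_extra_state")
    }


state = {
    "layer.weight": "weight-bytes",
    "layer._extra_state": "fp8-scale-bytes",
}
filtered = evict_extra_state(state)
```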
> (
>     # FSDP
>     [NUM_PROCS],
>     # HSDP
Does this work OK if NUM_PROCS < 4? I.e., let's say NUM_PROCS = 2; the TP dimension will be 0. Curious what happens in that case?
> if tp_mesh is not None or weight_mesh is not None:
>     # Apply DeviceMesh and DTensor-related modifications.
>     self.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh)
In the set_device_mesh function, it says weight_mesh is not necessary, but we call it if either tp_mesh or weight_mesh is not None. So the condition should not need to include weight_mesh is not None, right?
>     else device_mesh.get_group()
> )
> quantizer.amax_reduction_group = amax_reduction_group
> quantizer.amax_reduction_group = device_mesh.get_group()
Which group will it return in the case of multiple dimensions? For instance, if the weight is both FSDP- and TP-sharded, will this give the FSDP dim or the TP dim?
> if isinstance(weight, DTensor):
>     weight = weight.to_local()
Shouldn't we use _extract_trainable_tensor_from_dtensor here?
Also applicable to a couple of other places in the ops folder.
> instance._quantizer = quantizer.copy() if quantizer is not None else None
> instance._fp8_dtype = fp8_dtype
> instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales
> instance._default_storage = torch.UntypedStorage(1, device=torch.cuda.current_device())
I am leaning towards creating this default_storage on CPU instead, for a couple of reasons:
- Since the idea of default_storage here is to signal a unique identity, keeping it on CPU vs. GPU shouldn't matter.
- Calling torch.cuda.current_device() on every single Tensor creation adds Python overhead.
> integration with Torch DCP checkpointing. This method should only be invoked when
> using DTensor parameters, e.g. when using FSDP2 or DCP.
>
> When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically
Why is it that we haven't added tp_mesh and weight_mesh to the constructors of rmsnorm and layer_norm, but we have for every other layer?
> param = getattr(self, bias)
> placements = (Replicate(),)
> if self.parallel_mode == "column":
>     placements = (Shard(dim=0),)
I am wondering if we can make all of the set_device_mesh functions share a helper set_tp_mesh defined in base.py that takes in a dictionary of {parameter name: parallel_mode} and tp_mesh, converts the parameters to DTensors, and then use that in set_device_mesh of each module. Something like this:
def set_tp_mesh(self, param_mode_dict: dict, tp_mesh: Optional[DeviceMesh]):
And put the big docstring that we have over there in base.py. The docstring seems to be repeated in all places.
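A rough, torch-free illustration of the proposed helper (names and the placement table are hypothetical; strings stand in for the real `Shard`/`Replicate` placement objects, and the actual DTensor conversion via `_convert_param_to_dtensor_param` is omitted):

```python
# Hypothetical table from a TE parallel_mode to DTensor-style weight
# placements, mirroring the per-module logic quoted above.
WEIGHT_PLACEMENTS = {
    "column": ("Shard(dim=0)",),
    "row": ("Shard(dim=1)",),
    None: ("Replicate()",),
}


def set_tp_mesh(param_mode_dict, tp_mesh):
    """Sketch of the shared helper proposed for base.py: given
    {parameter name: parallel_mode} and a TP mesh, return the placements
    each parameter should be converted to. With no TP mesh, nothing is
    converted."""
    if tp_mesh is None:
        return {}
    return {
        name: WEIGHT_PLACEMENTS[mode]
        for name, mode in param_mode_dict.items()
    }


plan = set_tp_mesh({"weight": "column", "bias": None}, tp_mesh="tp2")
```

The appeal of the dict-driven shape is that each module only declares its parameter-to-parallel-mode mapping, and the conversion logic plus the big docstring live once in base.py.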
Generally LGTM @cspades. Let's get it merged after the comments are addressed.
Summary
- Implements (H/F)SDP2 x TP strided sharding, and `DTensor` FP8 parameters for Torch DCP checkpointing, across all `TransformerEngineBaseModule`(s).
- `GroupedLinear` is not yet supported, pending FSDP2 standalone pipe-cleaning. All other modules under `transformer_engine.pytorch.modules` are supported. `FusibleOperation` support is also a WIP, except for `LayerNorm` or `RMSNorm`, which are TE modules.
- Compatible with `DTensor`-based TP when unified by Torch DCP! In the Llama3 recipe, we use `DTensor`-based TP on the `torch.nn.Embedding`, TransformerEngine-based TP on the LM head, and weight-tie the LM head to the `torch.nn.Embedding`, which is why we do not need to call `set_device_mesh` for the LM head!

Usage / Documentation
(`tp_mesh` and `weight_mesh` can also be passed in `TEModule.__init__`.)

Details
DTensor Lifecycle in TransformerEngine
1. `__init__` on the `meta` device with the appropriate `tp_size` and TP sharding strategy, e.g. `parallel_mode` and `sequence_parallel`.
2. `TransformerEngineModule.set_device_mesh(tp_mesh, weight_mesh)` converts the parameters to `DTensor` with the appropriate TP `placement`(s) based on the TP sharding strategy specified in `__init__`, using `transformer_engine.pytorch.distributed._convert_param_to_dtensor_param`.
   - `tp_mesh` is a 1-D `DeviceMesh` containing the TP `ProcessGroup` that will be registered with the TransformerEngine module.
   - `weight_mesh` is the 1-D `DeviceMesh` containing the `ProcessGroup` that shards TransformerEngine module weights, i.e. the flattened combination of groups such as FSDP and TP. Specifically, it excludes non-weight groups such as DP-Replicate when using HSDP or HSDP-TP, and is mainly required for per-tensor scaling recipes like `Float8CurrentScaling`.
   - Must be called prior to `fully_shard` (which responds to the TP placements) and prior to `reset_parameters(defer_init=False)`, which quantizes parameters.
   - Alternatively, pass `__init__(tp_mesh, weight_mesh)` for supported TransformerEngine modules.
3. `fully_shard` shards the TransformerEngine model with FSDP2.
   - When `fully_shard` encounters TP sharding on `dim=0`, it will use a `_StridedShard` for DP. Put simply, this "pre-shards" the data prior to sharding on the current placement, followed by concatenating the pre-shards to get strided shards that will be re-sharded by the next placement. This effectively reverses the sharding order when processing the placements from left to right, and distributes shards as if we sharded on TP first, then FSDP, as required, even though DP appears before TP in the `DeviceMesh` and `DTensor.placements`. (See the Appendix for a visualization of this sharding strategy.)
4. `reset_parameters` is called if using meta device initialization, after `fully_shard`.
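This lifecycle ordering can be sketched as non-runnable pseudocode (mesh sizes, module choice, and the flattening call are illustrative; `set_device_mesh` is the API added by this PR, and `_flatten` is a private, version-dependent DeviceMesh method):

```python
# Pseudocode sketch: requires an initialized torch.distributed environment.
import torch
import transformer_engine.pytorch as te
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

# 1. Initialize on the meta device with the TP sharding strategy.
with torch.device("meta"):
    layer = te.Linear(1024, 1024, parallel_mode="column", tp_size=2)

# 2. Convert parameters to DTensor BEFORE fully_shard / reset_parameters.
#    weight_mesh is the flattened FSDP x TP mesh (illustrative call).
layer.set_device_mesh(tp_mesh=mesh["tp"], weight_mesh=mesh._flatten())

# 3. FSDP2 sharding (emits _StridedShard for dim=0 TP-sharded params).
fully_shard(layer, mesh=mesh)

# 4. Materialize and quantize parameters from the meta device.
layer.reset_parameters(defer_init=False)
```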
- (Note that this essentially shares the same properties as the compute weight besides shape, and supporting tools such as `FusedAdam` must be used to properly handle high-precision main weights.)
- The all-gathered `Tensor` is actually a TP-sharded `DTensor`, which deviates from the original FSDP2 paradigm where the all-gathered `Tensor` is fully unsharded and the `DTensor` wrapping is discarded. To support these `DTensor` compute weights in TransformerEngine modules, we utilize `transformer_engine.pytorch.distributed._extract_trainable_tensor_from_dtensor` to localize the `DTensor` and also inherit the `requires_grad` attribute from the `DTensor` parameter, as the local `Tensor` has this un-set during `DTensor.from_local(Tensor)` for FP8 parameters specifically!
- The `Tensor` gradient is converted to a `DTensor` and attached to the `DTensor.grad` attribute. This is handled by DTensor <> Tensor autograd conversion functions and, in the case of `FusibleOperation`, casted during the backward implementation.
- `QuantizedTensorStorage`: when the quantized data is `None`, we set `untyped_storage()` to a default 1-byte storage that unblocks DCP checkpoint loading assertions using this as a definition for "emptiness". This is because a storage of 0 bytes has `data_ptr() = nullptr` and breaks DCP.
- While `untyped_storage` is not used anywhere in TransformerEngine, it may break code that uses `Storage` to figure out whether a Tensor is empty or not, as now `QuantizedTensor` storage will always be a 1-byte storage even when both row and column data are not set. Those checks would instead need to compare the storage size against 1 byte instead of 0 bytes.

Bugs
- Previously, `"shard"` was the presumed weight-sharding sub-mesh in the `DTensor.device_mesh`. Now, users can precisely specify their own custom weight-sharding `DeviceMesh` for the per-tensor `amax_reduction_group` via the `set_device_mesh(weight_mesh)` API.
- `TransformerEngineBaseModule`: `self.quantizers = {"scaling_fwd": [], "scaling_bwd": []}`

Testing
- Tested `main` vs. `cspades:cye/fsdp2-tp-dcp` with Megatron-LM `main` on PyTorch `25.11`.
- `DelayedScaling` has DCP save/load disparity issues, i.e. on the scale of `+/-1` to the `uint8` parameter checkpoint!

Appendix
`_StridedShard` - Using FSDP2 x TP Strided-Sharding

When `redistribute`-ing a global DTensor to `(_StridedShard(dim=0, sf=2), Shard(dim=0))`, `DTensor` will perform the following steps:

1. Compute the pre-shard split factor from the `Shard` placements to the right of the `_StridedShard`. (In the above example, since TP=2, the factor is 2.)
2. Pre-shard the data: `[0 1 2 3 4 5 6 7] -> [0 1 2 3] and [4 5 6 7]`.
   - In `fully_shard`, this has already been done via initializing the TransformerEngine module with TP and calling `_convert_param_to_dtensor_param`!
3. Process the `_StridedShard`:
   - Shard each pre-shard: `[0] [1] [2] [3]` and `[4] [5] [6] [7]`.
   - Concatenate element-wise into `[0 4] [1 5] [2 6] [3 7]`, which are assigned to the `_StridedShard` ranks.
   - (Without striding, this placement would have produced `[0 1] [2 3] [4 5] [6 7]`!)
4. Process the `Shard` placement:
   - Split each strided shard: `[0] [4]` / `[1] [5]` / `[2] [6]` / `[3] [7]`, which are assigned to the `Shard` ranks.
   - (Without striding: `[0] [1]` / `[2] [3]` / `[4] [5]` / `[6] [7]`!)

PyTorch also supports the inverse / un-sharding of this `redistribute`, which is literally the inverse of these simple operations! (Though things get a bit more complicated with uneven shards from odd-numbered dimension sizes.)
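The steps above can be sketched as a pure-Python simulation (even divisibility assumed; plain lists stand in for tensor rows, and the function name is illustrative):

```python
def strided_shard(data, dp, tp):
    """Simulate redistributing a 1-D sequence to
    (_StridedShard(dim=0, sf=tp), Shard(dim=0)).
    Returns {(dp_rank, tp_rank): local_shard}."""
    # Steps 1-2: pre-shard by the split factor of the Shard placements to
    # the right of _StridedShard (here, the TP degree).
    per_tp = len(data) // tp
    pre = [data[i * per_tp:(i + 1) * per_tp] for i in range(tp)]

    shards = {}
    per_dp = per_tp // dp
    for r in range(dp):
        # Step 3: shard each pre-shard across DP ranks, then concatenate
        # the pieces so DP rank r holds a *strided* shard.
        strided = []
        for p in pre:
            strided.extend(p[r * per_dp:(r + 1) * per_dp])
        # Step 4: the rightmost Shard(dim=0) splits the strided shard
        # across the TP ranks.
        per = len(strided) // tp
        for t in range(tp):
            shards[(r, t)] = strided[t * per:(t + 1) * per]
    return shards


# Matches the worked example: [0..7] with DP=4, TP=2.
shards = strided_shard(list(range(8)), dp=4, tp=2)
```

Running this reproduces the assignment in the text: DP rank 0 holds the strided shard `[0, 4]`, split into `[0]` and `[4]` across its two TP ranks, exactly as if we had sharded on TP first and FSDP second.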