Add engine.coalesce_grad_reduction() for ZeRO 1/2/3 multi-backward #7992
Open
roycho96 wants to merge 1 commit into deepspeedai:master
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Summary
Adds engine.coalesce_grad_reduction(), an opt-in context manager that defers ZeRO 1/2/3 gradient reduction across multiple engine.backward() calls inside one optimizer step. On context exit, a single reduction pass populates averaged_gradients for the next engine.step(); see the usage sketch after the list below.

This is the third step of the multi-backward feature:

- #7665 made set_gradient_accumulation_boundary() plus manual engine.backward() (PyTorch-style backward) a first-class API.
- A correctness fix for multi-backward under cpu_offload: chunks 1 through N-1 were dropped with ga_steps=1 and N>1 because of an outer gate in copy_grads_in_partition.
- This PR makes the path efficient: N reduce-scatters collapse into 1 across all ZeRO stages, removing the communication bottleneck that remained after the correctness fix.
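For illustration, a minimal usage sketch, assuming engine and chunks already exist and that set_gradient_accumulation_boundary() is still toggled as in the #7665 pattern (this PR's description does not spell out that interaction):

```python
# One optimizer step with N chunked backward passes; reduction is deferred
# until the context manager exits.
with engine.coalesce_grad_reduction():
    for i, chunk in enumerate(chunks):
        # Mark only the last backward as the accumulation boundary (#7665 API).
        engine.set_gradient_accumulation_boundary(i == len(chunks) - 1)
        loss = engine(chunk)       # forward on one cached chunk
        engine.backward(loss)      # no reduce-scatter issued while coalescing
# Context exit: a single reduction pass populates averaged_gradients.
engine.step()
```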
Motivation
engine.no_sync() (engine.py) explicitly asserts that ZeRO 2/3 are incompatible with the no_sync context manager: ZeRO needs per-backward reduction to partition gradients. The assert enforces that, but it blocks patterns where multiple engine.backward() calls per step are intentional:

- The cached losses CachedMultipleNegativesRankingLoss, CachedGISTEmbedLoss, and CachedMultipleNegativesSymmetricRankingLoss all call engine.backward() once per cached chunk.
- Custom autograd Functions that call torch.autograd.backward() inside their forward.

Both rely on the PyTorch-style backward API from #7665. With that API, the user (or a custom autograd Function) issues N engine.backward() calls per engine.step() and toggles set_gradient_accumulation_boundary() to mark the last one. Without this PR, the pattern issues N reduce-scatters per step on ZeRO 2/3 even when the math only needs 1.
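For reference, the pattern this enables today looks roughly like the sketch below (engine and chunk_losses assumed to exist); on ZeRO 2/3 each backward in this loop currently triggers its own reduce-scatter:

```python
# PyTorch-style multi-backward from #7665, without coalescing (sketch).
for i, loss in enumerate(chunk_losses):
    engine.set_gradient_accumulation_boundary(i == len(chunk_losses) - 1)
    engine.backward(loss)   # ZeRO 2/3: one reduce-scatter per call today
engine.step()
```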
What changed

- deepspeed/runtime/engine.py: the engine.coalesce_grad_reduction() context manager plus flush helpers (_flush_coalesced_reduction_zero{12,3}). The helpers iterate params explicitly instead of calling reduce_gradients(), to bypass the overlap_comm short-circuit and the contiguous_gradients setup_buckets dependency.
- deepspeed/runtime/zero/stage_1_and_2.py and deepspeed/runtime/zero/stage3.py: a _coalesce_grad_reduction = False init plus a 2-line guard at the top of the per-param reducer entry point (see the sketch after this list). No existing function bodies modified.
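A rough sketch of the stage-side change described above; only the _coalesce_grad_reduction flag name comes from this PR, while the class and method names here are stand-ins for the real reducer entry points:

```python
class _ZeroOptimizerSketch:
    def __init__(self):
        # New flag, False by default; the engine's coalesce_grad_reduction()
        # context manager flips it on entry and flushes + clears it on exit.
        self._coalesce_grad_reduction = False

    def _reduce_ready_partition(self, param):  # stand-in for the real entry point
        # The 2-line guard: while coalescing, skip the per-param reduction
        # so nothing is reduce-scattered until the coalesced flush.
        if self._coalesce_grad_reduction:
            return
        self._reduce_now(param)  # stands in for the existing, unchanged path

    def _reduce_now(self, param):
        pass  # placeholder for the existing reduction logic
```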
Compatibility matrix (all bit-exact vs. baseline multi-backward)

All four (contiguous_gradients, overlap_comm) combinations are bit-exact vs. baseline multi-backward.
Additionally verified:
- gradient_accumulation_dtype=fp32 (Z2 directly, Z1 with offload via the use_grad_accum_attribute=True path).
- … (reduce_bucket_size).
- Collective counts (TestCoalesceCollectiveCount, which patches dist.all_reduce / reduce / reduce_scatter_fn; see the sketch after this list).
- The optimizer.micro_step_id invariant: it stays 0 at flush across multiple steps, so partition_grads always takes the copy_ branch instead of the stale-buffer add_ branch (TestCoalesceZero3MicroStepInvariant).
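A sketch of how the collective-count check could be shaped; the test name and patched functions above come from this PR, but the body below is illustrative (engine and chunk_losses assumed to exist):

```python
from unittest import mock

import deepspeed.comm as dist

# Wrap one of the patched collectives and count calls across a coalesced step.
with mock.patch.object(dist, "reduce_scatter_fn", wraps=dist.reduce_scatter_fn) as rs:
    with engine.coalesce_grad_reduction():
        for loss in chunk_losses:
            engine.backward(loss)
    # The flush happens on context exit, so its reduction is counted here too.
print("reduce-scatter calls this step:", rs.call_count)  # expected: 1, not N
```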
Unsupported (NotImplementedError)

- ZeRO-1 with grad_accum_dtype=fp32 and no offload. Users on this combo can switch to ZeRO-2.
- Combining with engine.no_sync().