Add engine.coalesce_grad_reduction() for ZeRO 1/2/3 multi-backward #7992

Open
roycho96 wants to merge 1 commit into deepspeedai:master from roycho96:feat/zero-coalesce-grad-reduction

Conversation

Contributor

@roycho96 roycho96 commented May 5, 2026

Summary

Adds engine.coalesce_grad_reduction(), an opt-in context manager that defers ZeRO 1/2/3 gradient reduction across multiple engine.backward() calls inside one optimizer step. On context exit, a single reduction pass populates averaged_gradients for the next engine.step().
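
A minimal usage sketch (the chunk loop and `compute_chunk_loss` helper are illustrative, not part of this PR; the boundary toggling follows the #7665 API described under Motivation):

```python
import deepspeed

# Sketch: N backward() calls inside one optimizer step, with ZeRO gradient
# reduction deferred until the context exits.
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

with engine.coalesce_grad_reduction():
    for i, chunk in enumerate(chunks):
        # Only the last backward marks the accumulation boundary (#7665 API).
        engine.set_gradient_accumulation_boundary(i == len(chunks) - 1)
        loss = compute_chunk_loss(engine, chunk)  # illustrative helper
        engine.backward(loss)
# Exiting the context runs a single reduction pass that populates
# averaged_gradients for the upcoming step().
engine.step()
```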

This is the third step of the multi-backward feature: this PR makes the path efficient. The N reduce-scatters collapse into 1 across all ZeRO stages, removing the communication bottleneck that remained after the correctness fix.

Motivation

engine.no_sync() (engine.py) explicitly asserts that ZeRO 2/3 are incompatible with the no_sync context manager: ZeRO needs per-backward reduction to partition gradients. The assert enforces that, but it also blocks patterns where multiple engine.backward() calls per step are intentional:

  • Cached contrastive learning (GradCache): sentence-transformers CachedMultipleNegativesRankingLoss, CachedGISTEmbedLoss, CachedMultipleNegativesSymmetricRankingLoss all call engine.backward() once per cached chunk.
  • Custom autograd Functions that invoke torch.autograd.backward() inside their forward.

Both rely on the PyTorch-style backward API from #7665. With that API, the user (or a custom autograd Function) issues N engine.backward() calls per engine.step() and toggles set_gradient_accumulation_boundary() to mark the last one. Without this PR, the pattern issues N reduce-scatters per step on ZeRO 2/3 even when the math only needs 1.

What changed

deepspeed/runtime/engine.py

  • New engine.coalesce_grad_reduction() context manager.
  • Stage-aware flush helpers (_flush_coalesced_reduction_zero{12,3}) that iterate params explicitly instead of calling reduce_gradients(), bypassing the overlap_comm short-circuit and the contiguous_gradients setup_buckets dependency (see the sketch after this list).
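
An illustrative sketch of the engine-side flow, assuming a simple flag-plus-flush structure; the real code dispatches to the stage-aware helpers named above and handles nesting and error paths:

```python
from contextlib import contextmanager

@contextmanager
def coalesce_grad_reduction(self):
    # Illustrative only: flag the ZeRO optimizer so its per-param reducer
    # entry point becomes a no-op while the context is active.
    self.optimizer._coalesce_grad_reduction = True
    try:
        yield
    finally:
        self.optimizer._coalesce_grad_reduction = False
        # One reduction pass over all accumulated grads; iterates params
        # explicitly rather than calling reduce_gradients(), per the note above.
        if self.zero_optimization_stage() == 3:
            self._flush_coalesced_reduction_zero3()
        else:
            self._flush_coalesced_reduction_zero12()
```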

deepspeed/runtime/zero/stage_1_and_2.py and deepspeed/runtime/zero/stage3.py

  • A _coalesce_grad_reduction = False init plus a 2-line guard at the top of the per-param reducer entry point (sketched below). No existing function bodies are modified.
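
A sketch of the guard, using the ZeRO-1/2 reducer entry point as the example; the hook and helper names here are illustrative of the pattern, and only the two guard lines are new:

```python
def reduce_ready_partitions_and_remove_grads(self, param, i):
    # 2-line guard (sketch): while coalescing, leave the accumulated grad
    # in place and let the engine's flush helper reduce it once at exit.
    if self._coalesce_grad_reduction:
        return
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
```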

Compatibility matrix

All four (contiguous_gradients, overlap_comm) combinations are bit-exact vs. baseline multi-backward:

| Stage  | (F, F) | (T, F) | (F, T) | (T, T) (default) |
|--------|--------|--------|--------|------------------|
| ZeRO-1 | OK     | OK     | OK     | OK               |
| ZeRO-2 | OK     | OK     | OK     | OK               |
| ZeRO-3 | OK     | OK     | OK     | OK               |
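
For reference, a minimal sketch of the config fields the matrix axes refer to; values shown correspond to the (T, T) column, the stage value is illustrative, and all other fields are omitted:

```python
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 2,                     # matrix rows: 1, 2, or 3
        "contiguous_gradients": True,   # first axis of the matrix
        "overlap_comm": True,           # second axis of the matrix
    },
}
```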

Additionally verified:

  • CPU offload (offload_optimizer Z1/Z2/Z3, offload_param Z3).
  • BF16 with gradient_accumulation_dtype=fp32 (Z2 directly, Z1 with offload via the use_grad_accum_attribute=True path).
  • FP16 with dynamic loss scaling (Z1/Z2/Z3).
  • Multi-bucket flush (small reduce_bucket_size).
  • MoE smoke (ep_size=1, Z1/Z2). MoE ep_size=2 test included but requires world_size=4.
  • Gradient clipping, multi-step state hygiene.
  • N=4 deferred backwards issue strictly fewer cross-rank collectives than baseline (TestCoalesceCollectiveCount patches dist.all_reduce/reduce/reduce_scatter_fn; see the counting sketch after this list).
  • ZeRO-3 optimizer.micro_step_id invariant: it stays 0 at flush across multiple steps, so partition_grads always takes the copy_ branch instead of the stale-buffer add_ branch (TestCoalesceZero3MicroStepInvariant).
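
A rough illustration of how such a collective count can be compared; the patching target follows the test description above, while the counting harness and the `run_step` driver are hypothetical, not the PR's test code:

```python
from unittest import mock
import deepspeed.comm as dist

def count_collective(name, run_fn):
    """Count calls to one deepspeed.comm collective while run_fn executes."""
    counter = {"calls": 0}
    real = getattr(dist, name)

    def counting(*args, **kwargs):
        counter["calls"] += 1
        return real(*args, **kwargs)

    with mock.patch.object(dist, name, counting):
        run_fn()
    return counter["calls"]

# Hypothetical driver that issues N=4 backwards per step.
baseline = count_collective("all_reduce", lambda: run_step(engine, coalesce=False))
deferred = count_collective("all_reduce", lambda: run_step(engine, coalesce=True))
assert deferred < baseline  # the coalesced path issues strictly fewer collectives
```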

Unsupported (NotImplementedError)

  • ZeRO stage 0.
  • BF16_Optimizer / FP16_Optimizer wrappers. The BF16_Optimizer path is dispatched only for ZeRO-1 with bf16, grad_accum_dtype=fp32, and no offload; users on this combo can switch to ZeRO-2.
  • PipelineModule (pipeline parallelism schedules its own reductions).
  • Reentry / nesting with engine.no_sync().

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
roycho96 force-pushed the feat/zero-coalesce-grad-reduction branch from 4ea6fc2 to d4e71b7 on May 5, 2026