Fix Stage 0 + Ulysses crash: make bwc_tensor_model_parallel_rank() resilient to MP API absence #7888
nathon-lee wants to merge 8 commits into deepspeedai:master
Hi @nathon-lee, I found that we already have a fallback from |
Hi @tohtana, thanks for looking into this! I double-checked
My understanding is that #7649 added/adjusted APIs and fallbacks around world size (e.g., That’s why I opened #7888: it makes If you have a preferred behavior (e.g., should we fallback to a sequence-parallel rank instead of 0?), I’m happy to adjust the PR. |
Add check for model parallel rank in mpu. Signed-off-by: nathon-lee <leejianwoo@gmail.com>
Hi @nathon-lee, |
Title
Fix Stage 0 + Ulysses crash: make bwc_tensor_model_parallel_rank() resilient to MP API absence

Summary
This PR fixes a hard crash when using Ulysses sequence parallelism with ZeRO Stage 0 (BF16_Optimizer).
In this configuration, DeepSpeed calls deepspeed.utils.bwc.bwc_tensor_model_parallel_rank(mpu=...), and the mpu object passed in can be deepspeed.runtime.sequence_parallel.parallel_state_sp, which does not implement the deprecated get_model_parallel_rank() API. The current fallback path calls mpu.get_model_parallel_rank() unconditionally, raising AttributeError.

The fix adds a defensive capability check before calling the deprecated API. If the provided mpu does not expose any known tensor/model-parallel rank API, we treat it as “no tensor model parallelism” and return rank 0.

Motivation / Context
- AttributeError: ... parallel_state_sp has no attribute get_model_parallel_rank
- bwc_tensor_model_parallel_rank() falls back to a deprecated API without a hasattr() check.

This change keeps the original priority order intact:

1. get_tensor_model_parallel_rank()
2. get_slice_parallel_rank()
3. get_model_parallel_rank() (deprecated)
4. 0 if none exist

Changes
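The priority order listed above, including the final return of 0, could be sketched as follows. This is a minimal illustration rather than the actual patch: the function name resolve_tp_rank, the two stand-in classes, and the handling of mpu=None are assumptions made for the example.

```python
# Hypothetical sketch of the guarded lookup; not the code in deepspeed/utils/bwc.py.
RANK_GETTERS = (
    "get_tensor_model_parallel_rank",  # preferred API
    "get_slice_parallel_rank",         # second choice
    "get_model_parallel_rank",         # deprecated API, now checked before use
)


def resolve_tp_rank(mpu=None):
    """Return the tensor/model-parallel rank, or 0 if mpu exposes no known API."""
    if mpu is None:
        # Assumption for this sketch: no mpu means no tensor parallelism.
        return 0
    for name in RANK_GETTERS:
        getter = getattr(mpu, name, None)
        if callable(getter):
            return getter()
    # None of the known APIs exist: treat as "no tensor model parallelism".
    return 0


class SequenceParallelOnly:
    """Stand-in for an mpu that only knows about sequence parallelism."""

    def get_sequence_parallel_rank(self):
        return 3


class LegacyMpu:
    """Stand-in for an mpu that only offers the deprecated API."""

    def get_model_parallel_rank(self):
        return 2


print(resolve_tp_rank(SequenceParallelOnly()))  # 0
print(resolve_tp_rank(LegacyMpu()))             # 2
```

Looping over the getter names keeps the priority order in one place, so an object that implements several of the APIs still resolves to the first one in the list.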
deepspeed/utils/bwc.py
- bwc_tensor_model_parallel_rank() now checks hasattr(mpu, "get_model_parallel_rank") before calling it.
- If mpu provides none of the expected tensor/model-parallel rank APIs, return 0 (no TP).

Why this is safe
- If mpu provides get_tensor_model_parallel_rank(), get_slice_parallel_rank(), or get_model_parallel_rank(), behavior is unchanged.
- The new return value of 0 applies only when the mpu object does not provide any of these methods.

Reproduction
Using the Ulysses ALST tutorial flow, switching the ZeRO stage from 3 to 0 triggers the crash during the optimizer step, when the gradient norm is computed.
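The failure mode can also be illustrated without a training run. The sketch below does not import DeepSpeed: fake_sp_state is a hypothetical stand-in for parallel_state_sp (its get_sequence_parallel_rank attribute is an assumption), and unguarded_rank mimics the unconditional fallback described in the summary rather than reproducing the library source.

```python
import types

# Hypothetical stand-in for parallel_state_sp: it exposes a sequence-parallel
# helper but none of the tensor/model-parallel rank APIs.
fake_sp_state = types.SimpleNamespace(get_sequence_parallel_rank=lambda: 0)


def unguarded_rank(mpu):
    """Mimics a fallback chain whose last branch has no hasattr() check."""
    if hasattr(mpu, "get_tensor_model_parallel_rank"):
        return mpu.get_tensor_model_parallel_rank()
    elif hasattr(mpu, "get_slice_parallel_rank"):
        return mpu.get_slice_parallel_rank()
    else:
        # Unconditional call: fails if the deprecated API is missing.
        return mpu.get_model_parallel_rank()


try:
    unguarded_rank(fake_sp_state)
    crashed = False
except AttributeError as exc:
    crashed = True
    print(f"reproduced: {exc}")
```

The stand-in lacks get_model_parallel_rank, so the final branch raises AttributeError, which is the same class of error reported for the real mpu object.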
Testing
bwc_tensor_model_parallel_rank(mpu=deepspeed.runtime.sequence_parallel.parallel_state_sp) should no longer raise.

References