[TRTLLM-12589][fix] Reset MoE A2A dispatch state on warmup OOM#14000
Open

Barry-Delaney wants to merge 1 commit into
Conversation
The MoE all-to-all dispatch/combine state machine (`MoeAlltoAll._state.phase` and `NVLinkOneSided._dispatch_state["phase"]`) is left in "dispatched" when a forward raises between `dispatch()` and `combine()`. `_general_warmup_impl`'s `try / except OutOfMemoryError` catches the OOM and clears the CUDA cache but does not reset Python-side stateful objects, so the next warmup config's forward fails the `dispatch()` invariant with "dispatch called twice without an intervening combine".

This converts the "skip this warmup shape and try a smaller one" behavior the framework intended into "die at server startup with a confusing assert that has nothing to do with memory".

Fix:
* Add `reset_state()` to `MoeAlltoAll` and `NVLinkOneSided`.
* In `_general_warmup_impl`'s OOM handler, walk `self.model.modules()` and call `reset_state()` on any `moe_a2a` / comm submodule that exposes it, before retrying the next warmup config.

Non-OOM exceptions are deliberately not covered: they escape the `except OutOfMemoryError` block, propagate through `PyExecutor.__init__`, and kill the server before any further `dispatch()` runs, so the stuck state machine never gets exercised again; guarding that path would violate "don't guard against scenarios that can't happen".

This restores the documented "skip OOMing config" semantics. The underlying 4096-token warmup OOM at `kv_cache_free_gpu_memory_fraction=0.9` is a separate KV pool sizing issue and is not addressed here.

Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
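The mechanism can be sketched in isolation. This is a hypothetical simplification, not the real TensorRT-LLM classes: `MoeAlltoAll` here is a bare two-phase state machine, `warmup` stands in for `_general_warmup_impl`'s skip-on-OOM loop, and plain `MemoryError` stands in for `torch.cuda.OutOfMemoryError` so the sketch runs without torch.

```python
class MoeAlltoAll:
    """Hypothetical simplification of the dispatch/combine state machine."""

    def __init__(self):
        self._phase = "idle"

    def dispatch(self, tokens):
        # The invariant this PR protects: dispatch() must not run twice
        # without an intervening combine().
        assert self._phase == "idle", (
            "dispatch called twice without an intervening combine")
        self._phase = "dispatched"
        return tokens

    def combine(self, tokens):
        assert self._phase == "dispatched"
        self._phase = "idle"
        return tokens

    def reset_state(self):
        # Added by this fix: force the machine back to idle after a
        # forward that aborted between dispatch() and combine().
        self._phase = "idle"


def warmup(modules, configs, run_forward):
    """Stand-in for _general_warmup_impl's skip-on-OOM behavior."""
    for cfg in configs:
        try:
            run_forward(cfg)
        except MemoryError:  # the real handler catches the CUDA OOM type
            # The fix: reset any a2a submodule left mid-dispatch before
            # retrying the next (smaller) warmup config.
            for m in modules:
                reset = getattr(m, "reset_state", None)
                if callable(reset):
                    reset()
```

Without the `reset_state()` calls in the handler, the second config's `dispatch()` trips the assert even though the OOM itself was handled, which is exactly the startup failure described above.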
c0e4eb5 to 2e0d90b (Compare)
Collaborator (Author):
/bot run --disable-fail-fast

Collaborator:
PR_Github #47758 [ run ] triggered by Bot. Commit:

Collaborator:
PR_Github #47758 [ run ] completed with state

Collaborator (Author):
/bot run --disable-fail-fast

Collaborator:
PR_Github #47806 [ run ] triggered by Bot. Commit:

Collaborator:
PR_Github #47806 [ run ] completed with state

Collaborator (Author):
Waiting on #14028 for fixing CI.