
[TRTLLM-12589][fix] Reset MoE A2A dispatch state on warmup OOM#14000

Open
Barry-Delaney wants to merge 1 commit into
NVIDIA:feat/deepseek_v4from
Barry-Delaney:user/jinshik/fix-moe-a2a-warmup-oom

Conversation

@Barry-Delaney
Collaborator

The MoE all-to-all dispatch/combine state machine (MoeAlltoAll._state.phase and NVLinkOneSided._dispatch_state["phase"]) is left in "dispatched" when a forward raises between dispatch() and combine(). _general_warmup_impl's "try / except OutOfMemoryError" catches the OOM and clears CUDA cache but does not reset Python-side stateful objects, so the next warmup config's forward fails the dispatch() invariant with

"dispatch called twice without an intervening combine"

This converts the "skip this warmup shape and try a smaller one" behavior the framework intended into "die at server startup with a confusing assert that has nothing to do with memory".
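The failure mode above can be reproduced with a minimal sketch. The class below is a hypothetical stand-in (not the real `MoeAlltoAll` implementation) that models only the phase invariant the PR describes: `dispatch()` asserts that the previous `combine()` ran, so an exception raised between the two leaves the phase stuck at `"dispatched"` and the next `dispatch()` trips the assert.

```python
class MoeA2ASketch:
    """Simplified stand-in for the dispatch/combine state machine."""

    def __init__(self):
        self._phase = "idle"

    def dispatch(self):
        # Invariant from the PR description: dispatch must not be called
        # twice without an intervening combine.
        assert self._phase != "dispatched", \
            "dispatch called twice without an intervening combine"
        self._phase = "dispatched"

    def combine(self):
        self._phase = "idle"


a2a = MoeA2ASketch()
a2a.dispatch()
# Simulate an OOM raised between dispatch() and combine(): combine() never
# runs, so the next warmup config's forward hits the assert.
try:
    a2a.dispatch()
except AssertionError as e:
    print(e)  # dispatch called twice without an intervening combine
```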

Fix:

  • Add reset_state() to MoeAlltoAll and NVLinkOneSided.
  • In _general_warmup_impl's OOM handler, walk self.model.modules() and call reset_state() on any moe_a2a / comm submodule that exposes it, before retrying the next warmup config.
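A hedged sketch of the two steps above, using simplified stand-in classes (the real `NVLinkOneSided`, `MoeAlltoAll`, and model types differ; only the `reset_state()` name and the `modules()` walk follow the PR description):

```python
class NVLinkOneSidedSketch:
    """Stand-in comm module holding Python-side dispatch state."""

    def __init__(self):
        self._dispatch_state = {"phase": "idle"}

    def reset_state(self):
        # Clear the stuck phase so the next warmup config can dispatch again.
        self._dispatch_state["phase"] = "idle"


class FakeModel:
    """Stand-in exposing a torch.nn.Module-style modules() iterator."""

    def __init__(self, submodules):
        self._submodules = submodules

    def modules(self):
        return iter(self._submodules)


def reset_moe_a2a_state(model):
    """What the OOM handler does: reset any submodule that exposes it."""
    for module in model.modules():
        reset = getattr(module, "reset_state", None)
        if callable(reset):
            reset()


comm = NVLinkOneSidedSketch()
comm._dispatch_state["phase"] = "dispatched"  # stuck after a warmup OOM
reset_moe_a2a_state(FakeModel([comm]))
print(comm._dispatch_state["phase"])  # idle
```

The `getattr`/`callable` check keeps the handler safe to run over modules that have no such state, which is why it can walk all of `self.model.modules()` indiscriminately.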

Non-OOM exceptions are deliberately not covered: they escape the except OutOfMemoryError block, propagate through PyExecutor.__init__, and kill the server before any further dispatch() runs, so the stuck state machine never gets exercised again -- guarding that path would violate "don't guard against scenarios that can't happen".

This restores the documented "skip OOMing config" semantics. The underlying 4096-token warmup OOM at kv_cache_free_gpu_memory_fraction=0.9 is a separate KV pool sizing issue and is not addressed here.

@Barry-Delaney Barry-Delaney requested review from a team as code owners May 11, 2026 14:43
@Barry-Delaney Barry-Delaney requested review from QiJune, joyang-nv and leslie-fang25 and removed request for a team May 11, 2026 14:43
@Barry-Delaney Barry-Delaney requested a review from lfr-0531 May 11, 2026 14:43
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
@Barry-Delaney Barry-Delaney force-pushed the user/jinshik/fix-moe-a2a-warmup-oom branch from c0e4eb5 to 2e0d90b Compare May 11, 2026 14:43
@Barry-Delaney
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47758 [ run ] triggered by Bot. Commit: 2e0d90b Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47758 [ run ] completed with state SUCCESS. Commit: 2e0d90b
/LLM/main/L0_MergeRequest_PR pipeline #37650 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@Barry-Delaney
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47806 [ run ] triggered by Bot. Commit: 2e0d90b Link to invocation

@Barry-Delaney Barry-Delaney changed the title [None][fix] Reset MoE A2A dispatch state on warmup OOM [TRTLLM-12589][fix] Reset MoE A2A dispatch state on warmup OOM May 12, 2026
@tensorrt-cicd
Collaborator

PR_Github #47806 [ run ] completed with state SUCCESS. Commit: 2e0d90b
/LLM/main/L0_MergeRequest_PR pipeline #37697 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@Barry-Delaney
Collaborator Author

Waiting on #14028 for fixing CI.
