[None][fix] PyExecutor Hang in Disagg TP Prefill #14020

Open

jthomson04 wants to merge 1 commit into NVIDIA:main from jthomson04:jthomson04/tp-disagg-race

Conversation

@jthomson04 (Collaborator) commented May 12, 2026

In disaggregated serving on TP ≥ 2, the CTX executor can deadlock because the entry into _check_disagg_ctx_cache_transfer_status (which performs a TP-collective allgather inside CacheTransceiver::checkContextTransferStatus) is gated on rank-local values from the local scheduler. When per-rank KV-block free counts drift even by a single block (which happens under load because UCX send completion + CUDA event sync timing varies per rank), num_fitting_reqs flips from >0 to 0 on a subset of ranks. The "0 fits" ranks enter the conditional and call the allgather; the "1+ fits" ranks skip the conditional and proceed to the next phase, which has its own TP collective (MLA attention's MNNVL allreduce on the GPU; with a KV connector, an mpi_broadcast in prepare_resources). With half the ranks blocked at one collective and the other half at another, neither collective completes, and hang_detector trips at 300s.
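The failure mode can be reproduced in miniature. Below is a minimal, runnable sketch in which threads stand in for TP ranks and threading.Barrier stands in for the two collectives; the per-rank num_fitting_reqs values are hard-coded to the divergent state described above, and none of the names are actual py_executor.py code.

```python
import threading

WORLD = 4
# Stand-ins for the two collectives the ranks split across:
allgather = threading.Barrier(WORLD)   # checkContextTransferStatus's allgather
next_phase = threading.Barrier(WORLD)  # the next phase's collective

def rank(rank_id: int, num_fitting_reqs: int) -> None:
    try:
        if num_fitting_reqs == 0:       # rank-local gate, as in the bug
            allgather.wait(timeout=2)
        next_phase.wait(timeout=2)
        print(f"rank {rank_id}: finished")
    except threading.BrokenBarrierError:
        print(f"rank {rank_id}: stuck at a collective (hang_detector would fire)")

# Per-rank KV free counts drift by one block: ranks 0-1 see 0 fits, ranks 2-3 see 1.
threads = [threading.Thread(target=rank, args=(i, n))
           for i, n in enumerate((0, 0, 1, 1))]
for t in threads:
    t.start()
for t in threads:
    t.join()
# Neither barrier ever reaches 4 waiters, so every "rank" times out; without
# the timeout they would block forever, which is the hang this PR fixes.
```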

In a reproduction the per-iteration [batchmgr] Capacity scheduler allows N requests log shows ranks agreeing for hundreds of iterations and then diverging on a single iteration (e.g. rank0 = 1, rank4 = 0). In the hang dump from that iteration, ranks with 1 are stuck in attention.mla_custom_op_inplace and ranks with 0 are stuck in kv_cache_transceiver.check_context_transfer_status.

This change OR-reduces the gating condition across TP ranks: if any rank locally wants the call, all ranks call it. Ranks that locally don't need it use the non-blocking variant (atLeastRequestNum=0) so the internal allgather still has full quorum without making them wait on futures they don't have. The same OR pattern is applied to the mirror site in _pp_schedule_and_propagate defensively (has_any_inflight_requests there is rank-local even though num_fitting_reqs is broadcast).
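In sketch form, the new gating looks roughly like this (a minimal sketch, assuming hypothetical helper names: tp_allreduce and the surrounding structure are illustrative, while the ReduceOp.MAX consensus and the atLeastRequestNum 1-vs-0 split are what the change actually introduces):

```python
# Each rank still evaluates its local condition...
local_need_check = num_fitting_reqs == 0 and not all_gen_first

# ...but MAX over {0, 1} across the TP group acts as a logical OR, so all
# ranks agree on whether the collective call happens this iteration.
any_need_check = bool(tp_allreduce(int(local_need_check), op=ReduceOp.MAX))

if any_need_check:
    if local_need_check:
        # This rank actually has to wait for at least one transfer to finish.
        self._check_disagg_ctx_cache_transfer_status(1)
    else:
        # Join the internal allgather for quorum only; atLeastRequestNum=0
        # keeps the call non-blocking for ranks with nothing to wait on.
        self._check_disagg_ctx_cache_transfer_status(0)
```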

check_context_transfer_status is idempotent on a rank with no pending sender futures: the allgather contributes empty, the "complete on every rank" intersection is empty, and the future-iteration loop is a no-op. markComplete=true semantics are unchanged — a request is still only marked kDISAGG_CONTEXT_COMPLETE when every rank reports it complete. Added cost is one extra empty allgather per iteration.
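A toy model of that intersection semantics (plain Python, not the actual C++ CacheTransceiver logic):

```python
from functools import reduce

def globally_complete(per_rank_reports: list[set[str]]) -> set[str]:
    """A request counts as complete only if every rank reports it complete."""
    return reduce(set.intersection, per_rank_reports)

# markComplete=true semantics: all ranks report the transfer done.
print(globally_complete([{"req7"}, {"req7"}, {"req7"}]))  # {'req7'}

# When no rank has pending sender futures, every rank contributes an empty
# set, the intersection is empty, and the future-iteration loop is a no-op,
# so the extra call costs one empty allgather and nothing else.
print(globally_complete([set(), set(), set()]))           # set()
```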

Summary by CodeRabbit

  • Bug Fixes
    • Synchronized context-cache status checking across tensor-parallel ranks to prevent deadlocks when per-rank request distributions diverge across GPUs.


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
@jthomson04 jthomson04 requested a review from a team as a code owner May 12, 2026 02:41
@jthomson04 jthomson04 requested a review from achartier May 12, 2026 02:41
@jthomson04 jthomson04 changed the title from [None][fix] to [None][fix] PyExecutor Hang in Disagg TP Prefill May 12, 2026
@coderabbitai (Bot, Contributor) commented May 12, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 3fce062e-8b53-4706-a9ab-3c816d63c5e4

📥 Commits

Reviewing files that changed from the base of the PR and between 64260ba and 841a05e.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py

📝 Walkthrough

The PR adds Tensor Parallelism-wide allreduce consensus to disaggregated context-cache status checking, replacing rank-local conditions in the PP executor loop and batch preparation. Each rank computes whether it needs the check locally, then all ranks synchronize via ReduceOp.MAX; if any rank needs it, ranks conditionally call blocking or non-blocking variants based on both the consensus and local state.

Changes

Disaggregated context-cache allreduce consensus

Layer: ReduceOp import for allreduce operations
File(s): tensorrt_llm/_torch/pyexecutor/py_executor.py
Summary: ReduceOp is imported from the distributed communicator module to enable TP-wide MAX reduction.

Layer: PP executor loop allreduce-gated status check
File(s): tensorrt_llm/_torch/pyexecutor/py_executor.py
Summary: In the PP executor disaggregated KV-cache path, the rank-local num_fitting_reqs == 0 gating is replaced with allreduce consensus: local_need_check is reduced via ReduceOp.MAX to form any_need_check, and ranks call _check_disagg_ctx_cache_transfer_status(1) only when that rank needs it and all_gen_first is false; otherwise they call the non-blocking variant to avoid deadlock.

Layer: Batch preparation allreduce-gated status check
File(s): tensorrt_llm/_torch/pyexecutor/py_executor.py
Summary: In batch preparation's disaggregated KV-cache path, the same TP-wide allreduce consensus pattern is applied; ranks synchronize on whether to enter the status collective and conditionally call the blocking or non-blocking variant.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check (❓ Inconclusive): The PR description provides a detailed explanation of the problem and solution but is missing test coverage details and the checklist section required by the template. Resolution: add a 'Test Coverage' section listing relevant tests, complete the PR Checklist, and follow the description template structure.

✅ Passed checks (3 passed)

  • Title check (✅ Passed): The title directly addresses the main issue: a hang/deadlock bug in PyExecutor for disaggregated tensor-parallel prefill operations.
  • Linked Issues check (✅ Passed): Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): Check skipped because no linked issues were found for this pull request.


@jthomson04 (Collaborator, Author) commented:

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator) commented:

PR_Github #47868 [ run ] triggered by Bot. Commit: 841a05e
