Enable DPA + TBO Support for Kimi K2.5 Inference #1017

Open: ftyghome wants to merge 5 commits into ROCm:main from RadeonFlow:rf-dpa-tbo

ftyghome commented Jun 1, 2026

Motivation

Kimi K2.5 inference under Data-Parallel Attention (DPA) combined with Two-Batch Overlap (TBO) exposed several gaps that either crashed the engine or left performance on the table. This PR enables the DPA + TBO path end-to-end: it fixes the fused-MoE fallback and tensor lifetimes in the TBO overlap, aligns cross-DP prefill admission with TBO's two-batch requirement, and extends persistent MLA to multi-rank DP.

Technical Details

  • Persistent MLA for DP attention (attention_mla.py): relax use_persistent_mode from "single-rank only" (not (dp_size > 1)) to dp_size <= 8, so persistent MLA also runs in the multi-rank DP configuration used by Kimi K2.5.

  • Fix fused MoE on the DPA fallback path (moe.py, topK.py):

    • In the DP-attn fallback (dp_size > 1, no MORI all2all), MoE runs after all_gather_with_padding, so the token dim can grow to dp_size × the per-rank max. Scale max_num_tokens for the topK / fused-MoE metadata accordingly to avoid undersized buffers.
    • Only select the MORI all2all path when expert parallel is actually enabled (enable_expert_parallel), so DPA-without-EP correctly falls back instead of assuming all2all.
  • Fix TBO tensor live range (moe.py): add a per-(role, ubatch) _TBO_KEEPALIVE holder around the all-gather and reduce-scatter comm/compute switches. Under TBO the source/output tensors of in-flight collectives could be freed before the overlapping ubatch waited on the comm; the keepalive defers release to the next same-role hold, which is past the wait point.

  • Two-batch-aware prefill alignment (scheduler.py, prefill_delayer.py): TBO prefill splitting needs at least two local prefill requests per DP rank. Replace _can_admit_head_prefill (boolean) with _count_admittable_head_prefills(limit) and a _prefill_delayer_readiness() helper that reports both "has any prefill" and "alignment-ready" (>= 2 requests when TBO is on, >= 1 otherwise). PrefillDelayer gains a 4th MAX-reduce slot (local_alignment_ready) so prefill is delayed until every DP rank can launch a full two-batch, not just until one rank has a request.

  • Fix delayer batch padding (model_runner.py): treat a non-uniform-decode batch as eager (runs_eager = enforce_eager or not dp_uniform_decode) when picking the padded graph batch size, so mixed / delayed batches don't pad up to a CUDAGraph batch size they can't actually use.
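
To illustrate the two-batch-aware prefill alignment item above, here is a minimal sketch of how a local readiness check could work. It is not the code in scheduler.py: the function names echo the helpers named in the description, but the signatures, the `Readiness` type, and the token-budget admission rule are assumptions made for this example.

```python
from dataclasses import dataclass


@dataclass
class Readiness:
    has_prefill: bool      # at least one head-of-queue prefill can be admitted
    alignment_ready: bool  # enough admittable prefills for the current mode


def count_admittable_head_prefills(token_costs, token_budget, limit):
    """Count head-of-queue requests that fit in the budget, stopping at `limit`.

    `token_costs` is a hypothetical list of prompt lengths in queue order.
    """
    admitted = 0
    remaining = token_budget
    for cost in token_costs:
        if admitted >= limit or cost > remaining:
            break
        remaining -= cost
        admitted += 1
    return admitted


def prefill_delayer_readiness(token_costs, token_budget, tbo_enabled):
    # TBO splits a prefill into two micro-batches, so it needs two requests.
    required = 2 if tbo_enabled else 1
    count = count_admittable_head_prefills(token_costs, token_budget, required)
    return Readiness(has_prefill=count >= 1, alignment_ready=count >= required)
```

With TBO on, a rank holding a single admittable request reports `has_prefill=True` but `alignment_ready=False`, which is the case the old boolean check could not express.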

Copilot AI review requested due to automatic review settings June 1, 2026 14:11

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR adjusts DP/TBO behaviors across MoE, scheduling, and attention decode to improve correctness under DP-attention fallback and to better coordinate cross-DP prefill alignment.

Changes:

  • Update MoE token-cap sizing for DP-attn fallback and add a “keepalive” mechanism to delay tensor releases across TBO compute/comm switches.
  • Refine scheduler ↔ prefill-delayer contract to include an “alignment-ready” signal (especially for TBO prefill splitting).
  • Relax MLA decode “persistent mode” gating based on DP size and tweak batch-size selection logic for eager vs graph runs.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.

Summary per file:

| File | Description |
| --- | --- |
| atom/model_ops/topK.py | Refines MORI/DP-attn gating logic into a named boolean. |
| atom/model_ops/moe.py | Adjusts MoE max token sizing under DP-attn fallback; adds TBO tensor keepalive. |
| atom/model_ops/attention_mla.py | Changes persistent-mode condition to a DP-size threshold. |
| atom/model_engine/scheduler.py | Computes local prefillability/alignment readiness for PrefillDelayer. |
| atom/model_engine/prefill_delayer.py | Extends delayer state to include alignment readiness in cross-DP reduction. |
| atom/model_engine/model_runner.py | Chooses eager-vs-graph batch sizing based on dp_uniform_decode. |

Comment thread atom/model_ops/moe.py
from atom.quantization.quark.utils import weight_dequant_fp8


_TBO_KEEPALIVE: dict[tuple[str, int], tuple[torch.Tensor, ...]] = {}
Comment thread atom/model_ops/moe.py
Comment on lines +2980 to +2998
    def _tbo_keepalive_slot(self) -> int:
        try:
            from atom.utils.tbo.ubatching import tbo_current_ubatch_id

            return tbo_current_ubatch_id()
        except Exception:
            return 0

    def _hold_tbo_keepalive(self, role: str, *tensors: torch.Tensor) -> None:
        tensors = tuple(tensor for tensor in tensors if tensor is not None)
        if tensors:
            # Keep one previous tensor set per ubatch/role alive globally.
            # The next same-role hold, often in the next MoE layer, happens
            # after this ubatch has waited on the prior comm work, so
            # overwriting here is the delayed safe release point.
            key = (role, self._tbo_keepalive_slot())
            if key in _TBO_KEEPALIVE:
                del _TBO_KEEPALIVE[key]
            _TBO_KEEPALIVE[key] = tensors
Comment thread atom/model_ops/moe.py
Comment on lines +2063 to +2073
# In the DP-attn fallback path (dp>1, no MORI all2all), MoE runs
# after all_gather_with_padding, so the token dim can be dp_size times
# the per-rank max.
moe_max_num_tokens = atom_config.max_num_batched_tokens
if (
    self.moe_parallel_config.dp_size > 1
    and not self.moe_parallel_config.use_all2all_kernels
    and atom_config.enable_dp_attention
):
    moe_max_num_tokens *= self.moe_parallel_config.dp_size
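
As a worked example of the sizing rule above, the sketch below restates the same condition as a standalone function. The function name and the numeric values are illustrative assumptions; only the condition itself comes from the diff.

```python
def moe_max_num_tokens(
    max_num_batched_tokens: int,
    dp_size: int,
    use_all2all_kernels: bool,
    enable_dp_attention: bool,
) -> int:
    """Return the token capacity the MoE metadata buffers should be sized for."""
    if dp_size > 1 and not use_all2all_kernels and enable_dp_attention:
        # After all_gather_with_padding every rank sees all ranks' tokens,
        # each padded to the per-rank maximum.
        return max_num_batched_tokens * dp_size
    return max_num_batched_tokens


# Hypothetical configuration: 8192 tokens per rank, 8 DP ranks, fallback path.
fallback_capacity = moe_max_num_tokens(8192, 8, False, True)  # 8192 * 8 = 65536
all2all_capacity = moe_max_num_tokens(8192, 8, True, True)    # stays 8192
```

Sizing for the padded worst case trades some extra buffer memory for the guarantee that the gathered batch always fits.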

Comment thread atom/model_ops/moe.py
          else 1 / self.routed_scaling_factor
      ),
-     max_num_tokens=atom_config.max_num_batched_tokens,
+     max_num_tokens=moe_max_num_tokens,
Comment thread atom/model_ops/moe.py
      moe_parallel_config=self.moe_parallel_config,
      in_dtype=atom_config.torch_dtype,
-     max_num_tokens=atom_config.max_num_batched_tokens,
+     max_num_tokens=moe_max_num_tokens,

Comment thread atom/model_ops/attention_mla.py
  dp_size = get_dp_group().world_size
- use_persistent_mode = not (dp_size > 1)
+ use_persistent_mode = dp_size <= 8
Comment thread atom/model_engine/prefill_delayer.py
Comment on lines 8 to +10
  Mechanism (per scheduler tick):
  1. Each DP rank reports its local state via cpu all_gather:
-    (local_prefillable, watermark_force_allow)
+    (local_prefillable, local_alignment_ready, watermark_force_allow)
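
To make the role of the new `local_alignment_ready` field concrete, here is a small simulation of combining the gathered per-rank tuples. The excerpt does not show the decision rule in prefill_delayer.py, so the rule below (and the name `should_allow_prefill`) is an assumption chosen to match the stated goal of delaying prefill until every DP rank can launch a full two-batch. A plain list stands in for the all_gather result.

```python
def should_allow_prefill(rank_states):
    """Decide whether to launch prefill this tick.

    rank_states: one tuple per DP rank, in the order
    (local_prefillable, local_alignment_ready, watermark_force_allow).
    """
    any_prefillable = any(state[0] for state in rank_states)
    all_aligned = all(state[1] for state in rank_states)
    any_forced = any(state[2] for state in rank_states)
    # Assumed policy: a watermark override on any rank wins; otherwise wait
    # until every rank is alignment-ready.
    return any_forced or (any_prefillable and all_aligned)


# Rank 1 has only one prefill request under TBO, so prefill is delayed.
delayed = should_allow_prefill([(True, True, False), (True, False, False)])
# Once rank 1 has two requests, all ranks launch together.
launched = should_allow_prefill([(True, True, False), (True, True, False)])
```

Under the previous two-field report, the first case would have been indistinguishable from the second, because both ranks are prefillable.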
Comment thread atom/model_ops/topK.py
Comment on lines +40 to +46
use_mori_all2all = (
    dp_size > 1
    and _has_module("mori")
    and config.enable_dp_attention
    and config.enable_expert_parallel
)
if use_mori_all2all: