Enable DPA + TBO Support for Kimi K2.5 Inference #1017
Open
ftyghome wants to merge 5 commits into
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR adjusts DP/TBO behaviors across MoE, scheduling, and attention decode to improve correctness under DP-attention fallback and to better coordinate cross-DP prefill alignment.
Changes:
- Update MoE token-cap sizing for DP-attn fallback and add a “keepalive” mechanism to delay tensor releases across TBO compute/comm switches.
- Refine scheduler ↔ prefill-delayer contract to include an “alignment-ready” signal (especially for TBO prefill splitting).
- Relax MLA decode “persistent mode” gating based on DP size and tweak batch-size selection logic for eager vs graph runs.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| atom/model_ops/topK.py | Refines MORI/DP-attn gating logic into a named boolean. |
| atom/model_ops/moe.py | Adjusts MoE max token sizing under DP-attn fallback; adds TBO tensor keepalive. |
| atom/model_ops/attention_mla.py | Changes persistent-mode condition to a DP-size threshold. |
| atom/model_engine/scheduler.py | Computes local prefillability/alignment readiness for PrefillDelayer. |
| atom/model_engine/prefill_delayer.py | Extends delayer state to include alignment readiness in cross-DP reduction. |
| atom/model_engine/model_runner.py | Chooses eager-vs-graph batch sizing based on dp_uniform_decode. |
    from atom.quantization.quark.utils import weight_dequant_fp8

    _TBO_KEEPALIVE: dict[tuple[str, int], tuple[torch.Tensor, ...]] = {}
Comment on lines +2980 to +2998

    def _tbo_keepalive_slot(self) -> int:
        try:
            from atom.utils.tbo.ubatching import tbo_current_ubatch_id

            return tbo_current_ubatch_id()
        except Exception:
            return 0

    def _hold_tbo_keepalive(self, role: str, *tensors: torch.Tensor) -> None:
        tensors = tuple(tensor for tensor in tensors if tensor is not None)
        if tensors:
            # Keep one previous tensor set per ubatch/role alive globally.
            # The next same-role hold, often in the next MoE layer, happens
            # after this ubatch has waited on the prior comm work, so
            # overwriting here is the delayed safe release point.
            key = (role, self._tbo_keepalive_slot())
            if key in _TBO_KEEPALIVE:
                del _TBO_KEEPALIVE[key]
            _TBO_KEEPALIVE[key] = tensors
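The "overwrite is the delayed release point" idea in the comment above can be observed without torch. The following is a minimal sketch, not the PR's code: `FakeTensor`, `hold`, and `KEEPALIVE` are hypothetical stand-ins, and the sketch relies on CPython reference counting so that release is visible through `weakref`.

```python
import weakref


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only its lifetime matters here."""


# One retained tensor set per (role, ubatch slot), mirroring the shape of _TBO_KEEPALIVE.
KEEPALIVE: dict = {}


def hold(role: str, slot: int, *tensors) -> None:
    """Retain the given tensors until the next hold with the same (role, slot)."""
    live = tuple(t for t in tensors if t is not None)
    if live:
        # Overwriting drops the previous tuple, which releases the old tensors.
        KEEPALIVE[(role, slot)] = live


a = FakeTensor()
ref_a = weakref.ref(a)
hold("all_gather", 0, a)
del a  # the caller's reference is gone, but the holder still keeps `a` alive
alive_after_first_hold = ref_a() is not None

hold("all_gather", 1, FakeTensor())  # a different ubatch slot does not release `a`
alive_after_other_slot = ref_a() is not None

hold("all_gather", 0, FakeTensor())  # same role and slot: `a` is released here
released_after_same_key = ref_a() is None
```

In this sketch the first tensor outlives the caller's reference and is only dropped by the next hold with the same key, which is the lifetime the PR wants for buffers of in-flight collectives.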
Comment on lines +2063 to +2073

    # In the DP-attn fallback path (dp>1, no MORI all2all), MoE runs
    # after all_gather_with_padding, so the token dim can be dp_size times
    # the per-rank max.
    moe_max_num_tokens = atom_config.max_num_batched_tokens
    if (
        self.moe_parallel_config.dp_size > 1
        and not self.moe_parallel_config.use_all2all_kernels
        and atom_config.enable_dp_attention
    ):
        moe_max_num_tokens *= self.moe_parallel_config.dp_size
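The sizing rule in this hunk can be restated as a standalone function to make the arithmetic easy to check. This is a sketch under stated assumptions: `scaled_moe_max_num_tokens` is a hypothetical helper name, and the numbers in the usage lines are illustrative, not values from the PR.

```python
def scaled_moe_max_num_tokens(
    max_num_batched_tokens: int,
    dp_size: int,
    use_all2all_kernels: bool,
    enable_dp_attention: bool,
) -> int:
    """Upper bound on tokens the MoE layer may see on one rank.

    On the DP-attention fallback path (dp_size > 1, no all2all kernels),
    tokens from every DP rank are all-gathered with padding before MoE,
    so the bound grows by a factor of dp_size.
    """
    on_fallback_path = (
        dp_size > 1 and not use_all2all_kernels and enable_dp_attention
    )
    if on_fallback_path:
        return max_num_batched_tokens * dp_size
    return max_num_batched_tokens


# Illustrative values only: 8192 tokens per rank, 4 DP ranks.
fallback_bound = scaled_moe_max_num_tokens(8192, 4, False, True)
all2all_bound = scaled_moe_max_num_tokens(8192, 4, True, True)
single_rank_bound = scaled_moe_max_num_tokens(8192, 1, False, True)
```

Only the first call hits the fallback path, so only it is scaled by `dp_size`.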
             else 1 / self.routed_scaling_factor
         ),
    -    max_num_tokens=atom_config.max_num_batched_tokens,
    +    max_num_tokens=moe_max_num_tokens,
     moe_parallel_config=self.moe_parallel_config,
     in_dtype=atom_config.torch_dtype,
    -max_num_tokens=atom_config.max_num_batched_tokens,
    +max_num_tokens=moe_max_num_tokens,
     dp_size = get_dp_group().world_size
    -use_persistent_mode = not (dp_size > 1)
    +use_persistent_mode = dp_size <= 8
Comment on lines 8 to +10

     Mechanism (per scheduler tick):
       1. Each DP rank reports its local state via cpu all_gather:
    -     (local_prefillable, watermark_force_allow)
    +     (local_prefillable, local_alignment_ready, watermark_force_allow)
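The new `local_alignment_ready` signal can be sketched as follows. This is a simplified illustration, not the PR's implementation: `local_readiness` and `all_ranks_aligned` are hypothetical names, the cross-rank gather is simulated with a plain list, and `watermark_force_allow` and the actual reduction used by `PrefillDelayer` are deliberately left out.

```python
def local_readiness(num_admittable_prefills: int, tbo_enabled: bool) -> tuple:
    """Return (local_prefillable, local_alignment_ready) for one DP rank.

    Assumption taken from the PR description: TBO prefill splitting needs
    at least two local prefill requests, otherwise one is enough.
    """
    required = 2 if tbo_enabled else 1
    local_prefillable = num_admittable_prefills >= 1
    local_alignment_ready = num_admittable_prefills >= required
    return local_prefillable, local_alignment_ready


def all_ranks_aligned(per_rank_counts, tbo_enabled: bool) -> bool:
    """Simulated cross-DP check: every rank must be alignment-ready."""
    return all(local_readiness(n, tbo_enabled)[1] for n in per_rank_counts)


# Rank 1 has a single prefill request: enough without TBO, not enough with it.
aligned_with_tbo = all_ranks_aligned([2, 1, 3], tbo_enabled=True)
aligned_without_tbo = all_ranks_aligned([2, 1, 3], tbo_enabled=False)
```

With TBO on, the rank holding a single request keeps the whole group waiting, which is the difference between "has any prefill" and "alignment-ready".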
Comment on lines +40 to +46

    use_mori_all2all = (
        dp_size > 1
        and _has_module("mori")
        and config.enable_dp_attention
        and config.enable_expert_parallel
    )
    if use_mori_all2all:
Motivation
Kimi K2.5 inference under Data-Parallel Attention (DPA) combined with Two-Batch Overlap (TBO) exposed several gaps that either crashed the engine or left performance on the table. This PR enables the DPA + TBO path end-to-end: it fixes the fused-MoE fallback and tensor lifetimes in the TBO overlap, aligns cross-DP prefill admission with TBO's two-batch requirement, and extends persistent MLA to multi-rank DP.
Technical Details
- Persistent MLA for DP attention (`attention_mla.py`): relax `use_persistent_mode` from "single-rank only" (`not (dp_size > 1)`) to `dp_size <= 8`, so persistent MLA also runs in the multi-rank DP configuration used by Kimi K2.5.
- Fix fused MoE on the DPA fallback path (`moe.py`, `topK.py`):
  - In the DP-attn fallback path (`dp_size > 1`, no MORI all2all), MoE runs after `all_gather_with_padding`, so the token dim can grow to `dp_size ×` the per-rank max. Scale `max_num_tokens` for the topK / fused-MoE metadata accordingly to avoid undersized buffers.
  - Gate the MORI all2all path on expert parallelism (`enable_expert_parallel`), so DPA-without-EP correctly falls back instead of assuming all2all.
- Fix TBO tensor live range (`moe.py`): add a per-(role, ubatch) `_TBO_KEEPALIVE` holder around the all-gather and reduce-scatter comm/compute switches. Under TBO, the source/output tensors of in-flight collectives could be freed before the overlapping ubatch waited on the comm; the keepalive defers release to the next same-role hold, which is past the wait point.
- Two-batch-aware prefill alignment (`scheduler.py`, `prefill_delayer.py`): TBO prefill splitting needs at least two local prefill requests per DP rank. Replace `_can_admit_head_prefill` (boolean) with `_count_admittable_head_prefills(limit)` and a `_prefill_delayer_readiness()` helper that reports both "has any prefill" and "alignment-ready" (`>= 2` requests when TBO is on, `>= 1` otherwise). `PrefillDelayer` gains a 4th MAX-reduce slot (`local_alignment_ready`) so prefill is delayed until every DP rank can launch a full two-batch, not just until one rank has a request.
- Fix delayer batch padding (`model_runner.py`): treat a non-uniform-decode batch as eager (`runs_eager = enforce_eager or not dp_uniform_decode`) when picking the padded graph batch size, so mixed / delayed batches don't pad up to a CUDAGraph batch size they can't actually use.
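The batch-padding rule in the last item can be sketched as below. Only the `runs_eager` expression comes from the description; the helper name `padded_batch_size`, the list of captured graph sizes, and the round-up policy are assumptions for illustration, since the actual selection logic in `model_runner.py` is not shown here.

```python
import bisect


def padded_batch_size(
    batch_size: int,
    graph_batch_sizes: list,
    enforce_eager: bool,
    dp_uniform_decode: bool,
) -> int:
    """Pick the batch size to run with.

    graph_batch_sizes is assumed to be a sorted list of captured graph sizes.
    """
    # From the PR description: a non-uniform-decode batch is treated as eager.
    runs_eager = enforce_eager or not dp_uniform_decode
    if runs_eager:
        return batch_size  # no graph replay, so no padding
    # Assumed policy: round up to the smallest captured size that fits.
    idx = bisect.bisect_left(graph_batch_sizes, batch_size)
    if idx == len(graph_batch_sizes):
        return batch_size  # larger than any captured graph, run as-is
    return graph_batch_sizes[idx]


sizes = [1, 2, 4, 8, 16]  # hypothetical captured graph batch sizes
uniform_decode = padded_batch_size(5, sizes, enforce_eager=False, dp_uniform_decode=True)
mixed_batch = padded_batch_size(5, sizes, enforce_eager=False, dp_uniform_decode=False)
```

Under these assumptions, a uniform decode batch of 5 is padded to 8, while a mixed or delayed batch of the same size stays at 5.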