
[Feature] Support MLA chunk-prefill & prefix cache for all MLA Architecture Models#7727

Merged
Jiang-Jia-Jun merged 10 commits into PaddlePaddle:develop from chang-wenbin:CHUNK_MLA
May 9, 2026

Conversation

@chang-wenbin
Collaborator

@chang-wenbin chang-wenbin commented May 7, 2026

Motivation

The original MLA (Multi-head Latent Attention) implementation did not support the combination of prefix cache and chunked prefill:

  1. In get_position_ids_and_mask_encoder_batch.cu, the offset computation added encoder_len + decoder_len together during chunked prefill, producing wrong position_ids.
  2. The FlashAttention calls in forward_extend / forward_mixed did not include cached KV tokens in cu_seqlens_k and max_seqlen_k, so attention tiles were truncated and the output was silently corrupted.
  3. There was no mechanism to read already-cached KV from the paged latent cache and interleave it with new-token KV.
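The second point can be made concrete with a small sketch (NumPy for illustration; the lengths are hypothetical): the Q-side cumulative lengths cover only the new chunk tokens, while the K-side must also cover the cached prefix.

```python
import numpy as np

# Hypothetical per-batch lengths: tokens already in the paged latent cache
# and new tokens arriving in the current prefill chunk.
cached_lens = np.array([128, 0, 64])
new_lens = np.array([32, 50, 32])

# The Q side covers only the new tokens of this chunk.
cu_seqlens_q = np.concatenate([[0], np.cumsum(new_lens)])

# The K side must cover cached + new tokens; otherwise FlashAttention
# truncates its tiles and silently corrupts the output.
cu_seqlens_k = np.concatenate([[0], np.cumsum(cached_lens + new_lens)])
max_seqlen_k = int((cached_lens + new_lens).max())

print(cu_seqlens_q.tolist())  # [0, 32, 82, 114]
print(cu_seqlens_k.tolist())  # [0, 160, 210, 306]
print(max_seqlen_k)           # 160
```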

Modifications

  • custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu: fix the offset accumulation logic under chunked prefill so that the cached prefix length is correctly used as the position starting point.
  • fastdeploy/model_executor/layers/attention/mla_attention_backend.py
    • Add fused_read_cache_and_interleave_naive (Python reference implementation) and fused_read_cache_and_interleave_triton (Triton-accelerated version), dispatched through the unified fused_read_cache_and_interleave entry point (set the environment variable FD_MLA_USE_NAIVE=1 to switch).
    • Add prefix-cache fields to MLAAttentionMetadata (has_prefix_cache, cu_seqlens_cached_kv, cu_seqlens_k_with_cache, max_total_kv_len, etc.).
    • Add prefix-cache metadata computation to init_attention_metadata.
    • FlashAttention calls in forward_extend / forward_mixed now use the cache-aware cu_seqlens_k_with_cache and max_total_kv_len.
  • fastdeploy/model_executor/models/deepseek_v3.py: in the prefill branch, read the cached latent, interleave it with the new-token latent, then apply the KV projection; the key tensor shape is adjusted to [full_tokens, heads, qk_head_dim].
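A simplified Python reference of the read-and-interleave step (names and shapes are illustrative; the actual fused_read_cache_and_interleave_naive additionally resolves paged block tables rather than taking pre-gathered cached tensors):

```python
import numpy as np

def interleave_cached_and_new(cached_kv_per_req, new_kv, cu_seqlens_q):
    """Place each request's cached latent KV ahead of its new-chunk KV,
    yielding one flat [total_tokens, dim] array for the KV projection."""
    out = []
    for b, cached in enumerate(cached_kv_per_req):
        start, end = cu_seqlens_q[b], cu_seqlens_q[b + 1]
        out.append(cached)             # tokens read back from the latent cache
        out.append(new_kv[start:end])  # tokens of the current chunk
    return np.concatenate(out, axis=0)

cached = [np.ones((2, 4)), np.zeros((0, 4))]    # request 0 hit 2 cached tokens
new = np.arange(12, dtype=float).reshape(3, 4)  # 3 new tokens across 2 requests
flat = interleave_cached_and_new(cached, new, [0, 1, 3])
print(flat.shape)  # (5, 4)
```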

Usage or Command

Accuracy Tests

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


chang-wenbin seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

@paddle-bot

paddle-bot Bot commented May 7, 2026

Thanks for your contribution!

PaddlePaddle-bot

This comment was marked as outdated.

@PaddlePaddle-bot

PaddlePaddle-bot commented May 7, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-09 13:57:32

The CI report is generated from the code below (updated every 30 minutes):


1 Task overview

CI is still in progress; no Required task has failed. 6 Required tasks are running, awaiting results.

Total runs (reruns) | Total tasks | ✅ Passed | ❌ Failed | ⏳ Running | ⏸️ Waiting | Skipped
36(0) | 36 | 26 | 2 | 7 | 1 | 0

2 Task status summary

2.1 Required tasks: 4/10 passed

Required tasks block merging; failures must be handled first. 6 required tasks are currently running; please wait for the results.

Status | Task | Duration | Root cause | Suggested fix | Logs | Rerun
run_ce_cases | - | running | - | CI details | -
base_tests | - | running | - | CI details | -
run_tests_with_coverage | - | running | - | CI details | -
run_4_cards_tests | - | running | - | CI details | -
stable_tests | - | running | - | CI details | -
run_xpu_4cards_cases | - | running | - | CI details | -
Remaining 4 required tasks passed | - | - | - | - | -

2.2 Optional tasks — 22/26 passed

Optional tasks do not block merging; failures are informational only.

Status | Task | Duration | Logs | Rerun
Check PR Template | 10s | Job | -
Trigger Jenkins for PR | 9m28s | Job | -
run_iluvatar_cases | - | CI details | -
⏸️ CI_HPU | - | - | -
Remaining 22 optional tasks passed | - | - | -

3 Failure details (required only)

No required task failures.

@codecov-commenter

codecov-commenter commented May 7, 2026

Codecov Report

❌ Patch coverage is 46.53465% with 54 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@9d3bb29). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...executor/layers/attention/mla_attention_backend.py 53.40% 41 Missing ⚠️
fastdeploy/model_executor/models/deepseek_v3.py 0.00% 13 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7727   +/-   ##
==========================================
  Coverage           ?   71.55%           
==========================================
  Files              ?      396           
  Lines              ?    55689           
  Branches           ?     8703           
==========================================
  Hits               ?    39850           
  Misses             ?    13098           
  Partials           ?     2741           
Flag Coverage Δ
GPU 71.55% <46.53%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.


@chang-wenbin changed the title from "support mla chunk-prefill & prefix_cache" to "[Feature][KVCache] Support MLA chunk-prefill & prefix cache for DeepSeek-V3" on May 8, 2026
@Jiang-Jia-Jun Jiang-Jia-Jun requested review from Copilot and removed request for PaddlePaddle-bot May 8, 2026 08:01
Contributor

Copilot AI left a comment


Pull request overview

This PR makes DeepSeek-V3's MLA attention correct under the combination of prefix cache and chunked prefill: it fixes the position_ids start/offset computation, includes cached KV in cu_seqlens_k / max_seqlen_k for the prefill/mixed FlashAttention calls, and adds the path that reads entries from the paged latent cache and interleaves them with new-token latents.

Changes:

  • Fix the position_ids error caused by offset accumulation in get_position_ids_and_mask_encoder_batch under chunked prefill, and introduce the "cached_len as position start" write logic.
  • Add fused read-cache + interleave (naive/Triton) to MLAAttentionBackend and extend MLAAttentionMetadata; prefill/mixed call FlashAttention with the cache-aware cu_seqlens_k_with_cache and max_total_kv_len.
  • The DeepSeek-V3 prefill branch reads the cached latent and interleaves it with the new-token latent before the KV projection, adjusting the key tensor shape construction to cover the full set of KV tokens.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

File | Description
fastdeploy/model_executor/models/deepseek_v3.py | Wire prefix cache into the prefill branch: read from the paged latent cache, interleave with new-token latents, then apply the KV projection; adjust key shape/assignment logic
fastdeploy/model_executor/layers/attention/mla_attention_backend.py | Add fused read-cache+interleave (naive/Triton) and prefix-cache metadata; prefill/mixed FlashAttention use cache-aware seqlens/maxlen; replace print with logger
custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu | Fix the position_ids offset and start-point computation under chunked prefill + prefix cache

Comment on lines +385 to +393
if need_do_prefill:  # max_enc_len_this_time
    key_value = self.kv_b_proj(compressed_kv)
    # Check for prefix cache
    attn_meta = forward_meta.attn_backend.attention_metadata if hasattr(forward_meta, "attn_backend") else None
    has_prefix_cache = False
    total_cached_tokens = 0

    if attn_meta is not None and isinstance(attn_meta, MLAAttentionMetadata):
        has_prefix_cache = attn_meta.has_prefix_cache
        total_cached_tokens = attn_meta.total_cached_kv_tokens
Comment on lines +399 to +414
if has_prefix_cache and total_cached_tokens > 0:
    layer_id = self.mla_attn.layer_id if hasattr(self.mla_attn, "layer_id") else 0
    latent_cache = forward_meta.caches[layer_id] if hasattr(forward_meta, "caches") else None
    if latent_cache is not None:
        block_size = self.mla_attn.block_size if hasattr(self.mla_attn, "block_size") else 64
        full_compressed_kv, full_k_pe = fused_read_cache_and_interleave(
            latent_cache,
            forward_meta.block_tables,
            compressed_kv,
            key_pe.squeeze(1),
            attn_meta.cu_seqlens_cached_kv,
            forward_meta.cu_seqlens_q,
            self.kv_lora_rank,
            self.qk_rope_head_dim,
            block_size,
        )
Comment on lines +185 to +205
bsz = cu_seqlens_cached_kv.shape[0] - 1
cu_cached = cu_seqlens_cached_kv.tolist()
cu_new = cu_seqlens_q.tolist()
total_cached = int(cu_cached[bsz])
total_new = new_compressed_kv.shape[0]
total_tokens = total_cached + total_new

full_compressed_kv = paddle.empty([total_tokens, kv_lora_rank], dtype=new_compressed_kv.dtype)
full_k_pe = paddle.empty([total_tokens, qk_rope_head_dim], dtype=new_k_pe.dtype)
if total_tokens == 0:
    return full_compressed_kv, full_k_pe

# block_tables.tolist() is a one-shot D2H; acceptable since host-side loop
# already requires CPU iteration over total_tokens.
bt_list = block_tables.tolist()

is_cached_host = [0] * total_tokens
src_off_host = [0] * total_tokens
out_pos = 0
for b in range(bsz):
    nc = int(cu_cached[b + 1]) - int(cu_cached[b])
Comment on lines +244 to +246
def fused_read_cache_and_interleave(*args, **kwargs):
    """Unified entry. ``FD_MLA_USE_NAIVE=1`` forces the Python reference path."""
    if os.environ.get("FD_MLA_USE_NAIVE", "0") == "1":
Comment on lines +29 to 41
// Compute the offset for the current batch dynamically.
// Each batch contributes either encoder_len or seq_lens_this_time,
// not the sum of both (under chunked prefill, encoder_len > 0 and
// decoder_len > 0 can hold simultaneously, but the batch only has
// encoder_len real tokens).
int offset = 0;
for (int i = 0; i < tid; i++) {
  if (seq_lens_encoder[i] > 0) {
    offset += seq_lens_encoder[i];
  } else if (seq_lens_decoder[i] > 0) {
    offset += seq_lens_this_time[i];
  }
}
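The offset rule in the kernel above can be checked on the host with a small sketch (the lengths below are illustrative, not from the PR's tests):

```python
def batch_offset(tid, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time):
    """Host-side sketch of the fixed offset rule: each earlier batch
    contributes either its encoder tokens (prefill) or its tokens this
    step (decode), never both."""
    offset = 0
    for i in range(tid):
        if seq_lens_encoder[i] > 0:
            offset += seq_lens_encoder[i]
        elif seq_lens_decoder[i] > 0:
            offset += seq_lens_this_time[i]
    return offset

# batch 0: chunked prefill writing 8 new tokens; batch 1: decode emitting 1.
print(batch_offset(2, [8, 0], [16, 5], [8, 1]))  # 9, not 8 + 16 = 24
```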
Comment on lines +354 to +359
"""MLA attention forward with prefix cache support."""

from fastdeploy.model_executor.layers.attention.mla_attention_backend import (
    MLAAttentionMetadata,
    fused_read_cache_and_interleave,
)

cu_total = [0] * (bsz + 1)
cumsum_cached = 0
cumsum_total = 0
for i in range(bsz):
Collaborator

Could this be handled by a custom operator? Eliminating the D2H and H2D transfers would also simplify the CPU-side work.


@chang-wenbin changed the title from "[Feature][KVCache] Support MLA chunk-prefill & prefix cache for DeepSeek-V3" to "[Feature][KVCache] Support MLA chunk-prefill & prefix cache for all MLA architecture model" on May 8, 2026
@chang-wenbin changed the title from "[Feature][KVCache] Support MLA chunk-prefill & prefix cache for all MLA architecture model" to "[Feature][KVCache] Support MLA chunk-prefill & prefix cache for all MLA Architecture Model" on May 8, 2026

@chang-wenbin changed the title from "[Feature][KVCache] Support MLA chunk-prefill & prefix cache for all MLA Architecture Model" to "[Feature] Support MLA chunk-prefill & prefix cache for all MLA Architecture Models" on May 8, 2026


@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 Paddle-CI-Agent | pr_review | 2026-05-09 13:26:38

📋 Review summary

PR overview: adds joint chunked prefill + prefix cache support for all MLA-architecture models (DeepSeek V3 etc.), fixing three bugs: wrong position_ids computation, truncated cu_seqlens_k, and the missing latent KV read path.
Changed scope: custom_ops/gpu_ops/, fastdeploy/model_executor/layers/attention/mla_attention_backend.py, fastdeploy/model_executor/models/deepseek_v3.py
Impact tags: [Models] [OP] [KVCache]

📝 PR template check

The PR title tag [Feature] is an official tag; the title itself needs no change. However, the ## Usage or Command and ## Accuracy Tests sections contain only the HTML comment placeholders (not even N/A), and no checklist item is checked, which does not meet the template requirements.

Suggested title (copy-paste ready):

  • [Feature] Support MLA chunk-prefill & prefix cache for all MLA Architecture Models

Suggested PR description (copy-paste ready; it must reproduce the full structure of the checklist §D2 template):

## Motivation
The original MLA (Multi-head Latent Attention) implementation did not support the combination of prefix cache and chunked prefill:
1. In `get_position_ids_and_mask_encoder_batch.cu`, the offset computation added encoder_len + decoder_len together during chunked prefill, producing wrong position_ids.
2. The FlashAttention calls in `forward_extend` / `forward_mixed` did not include cached KV tokens in `cu_seqlens_k` and `max_seqlen_k`, so attention tiles were truncated and the output was silently corrupted.
3. There was no mechanism to read already-cached KV from the paged latent cache and interleave it with new-token KV.

## Modifications
- `custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu`: fix the offset accumulation under chunked prefill so the cached prefix length is correctly used as the position starting point.
- `custom_ops/gpu_ops/get_padding_offset.cu` / `cpp_extensions.cc`: add an optional `seq_lens_decoder` parameter so that `cu_seqlens_k` can include the cached KV length while `cu_seqlens_q` remains the cumulative sum of new tokens only.
- `fastdeploy/model_executor/layers/attention/mla_attention_backend.py`:
  - Add `fused_read_cache_and_interleave_naive` (Python reference implementation) and `fused_read_cache_and_interleave_triton` (Triton-accelerated version), dispatched through the unified `fused_read_cache_and_interleave` entry (`FD_MLA_USE_NAIVE=1` to switch).
  - Add a `max_seqlen_k` field to `MLAAttentionMetadata` for the FlashAttention K-side length.
  - FlashAttention calls in `forward_extend` / `forward_mixed` use `max_seqlen_k` instead of `max_enc_len_this_time` as the K-side maximum length.
- `fastdeploy/model_executor/models/deepseek_v3.py`: when chunked prefill or prefix cache is enabled, the prefill branch first reads the cached latent via `fused_read_cache_and_interleave` and concatenates it before the KV projection; the key tensor shape is adjusted to cover all tokens.
- `fastdeploy/model_executor/pre_and_post_process.py`: the `get_padding_offset` call gains a `seq_lens_decoder` argument.
- Tests: add `tests/model_executor/test_mla_fused_read_interleave.py`, comparing the Triton and Python reference paths for accuracy; update the affected test function signatures under `tests/operators/`.

## Usage or Command
N/A

## Accuracy Tests
N/A

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

Issues

Level | File | Summary
🟡 Suggestion | mla_attention_backend.py:587 | The max_seqlen_k computation is questionable: max(max_kv_len_this_time, max_enc_len_this_time) takes the larger of the two maxima, whereas under prefix cache the correct value is the maximum per-batch sum (cached_len + new_len)
🟡 Suggestion | deepseek_v3.py:397 | forward_meta.block_tables is passed to fused_read_cache_and_interleave without a None check; may crash when the whole batch is first-chunk prefill
🟡 Suggestion | deepseek_v3.py | A1 mandatory item: the forward-path change lacks an end-to-end accuracy comparison (e.g. a logits-alignment test against HuggingFace)
❓ Question | mla_attention_backend.py:637 | A leftover comment # forward_meta.cu_seqlens_k: Optional[...] = None remains inside MLAAttentionMetadata; suggest cleaning it up

🟡 mla_attention_backend.py:587 — questionable max_seqlen_k semantics

# current code
metadata.max_seqlen_k = max(metadata.max_kv_len_this_time.item(), metadata.max_enc_len_this_time.item())

FlashAttention's max_seqlen_k parameter must equal max_i(cu_seqlens_k[i+1] - cu_seqlens_k[i]), i.e. the maximum actual KV sequence length (cached + new) across batches. max(max_kv, max_enc) yields the larger of the two separate maxima, not the maximum per-batch sum. If max_len_tensor_cpu[5] is precomputed as the maximum total KV length including cached+new, the current code is correct; if it is only the decode-phase maximum, max_seqlen_k will be underestimated and FlashAttention tiles truncated (the same silent-corruption failure mode this PR is fixing). Suggest documenting the semantics of max_len_tensor_cpu[5] in a comment, or deriving the value directly from cu_seqlens_k:

metadata.max_seqlen_k = int((forward_meta.cu_seqlens_k[1:] - forward_meta.cu_seqlens_k[:-1]).max())
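The concern can be illustrated with a tiny NumPy example (hypothetical lengths): the maximum of the two separate maxima can underestimate the maximum per-batch cached+new length.

```python
import numpy as np

# Hypothetical lengths: batch 0 is cache-heavy, batch 1 is all new tokens.
cached = np.array([100, 0])
new = np.array([10, 60])

cu_seqlens_k = np.concatenate([[0], np.cumsum(cached + new)])
per_batch_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]

correct = int(per_batch_k.max())                  # max of per-batch cached+new
suspect = max(int(cached.max()), int(new.max()))  # max of the two separate maxima

print(correct, suspect)  # 110 100 -> the max-of-maxima underestimates
```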

🟡 deepseek_v3.py:397 — missing None guard for block_tables

if self.enable_chunked_prefill or self.enable_prefix_caching:
    full_compressed_kv, full_k_pe = fused_read_cache_and_interleave(
        forward_meta.caches[self.layer_id],
        forward_meta.block_tables,   # ← may be None
        ...
    )

When chunked prefill / prefix cache is enabled but the current batch is entirely first-round prefill (no cache hit), forward_meta.block_tables may be None or [bsz, 0]; passing it into the Triton kernel would trigger a segfault. Suggest adding a None guard:

if (self.enable_chunked_prefill or self.enable_prefix_caching) \
        and forward_meta.block_tables is not None:
    full_compressed_kv, full_k_pe = fused_read_cache_and_interleave(...)

mla_attention_backend.py:637 — leftover comment

# forward_meta.cu_seqlens_k: Optional[paddle.Tensor] = None is a commented-out pseudo field declaration; it duplicates the comment above it and misleads readers, so it should be deleted.

Overall assessment

The PR is well reasoned and pinpoints the three MLA prefix-cache bugs accurately; the Triton kernel's compile-time binary search + BLOCK_M tiling design is sound, and the manual benchmarks are thorough. Suggestions: ① confirm the semantics of the max_seqlen_k computation or derive it directly from cu_seqlens_k; ② add a None guard for block_tables; ③ add end-to-end accuracy results to the Accuracy Tests section of the PR description.

@Jiang-Jia-Jun Jiang-Jia-Jun merged commit 6248358 into PaddlePaddle:develop May 9, 2026
33 of 38 checks passed
