[Feature] Support MLA chunk-prefill & prefix cache for all MLA Architecture Models#7727
Conversation
CI report generated from the code below (refreshed every 30 minutes):
1 Task overview: CI is in progress; no Required task has failed, 6 Required tasks are still running, awaiting results.
2 Task status summary
2.1 Required tasks: 4/10 passed
2.2 Optional tasks: 22/26 passed
3 Failure details (Required only): no failed Required tasks.
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@ Coverage Diff @@
##           develop   #7727   +/-   ##
=========================================
  Coverage         ?   71.55%
=========================================
  Files            ?      396
  Lines            ?    55689
  Branches         ?     8703
=========================================
  Hits             ?    39850
  Misses           ?    13098
  Partials         ?     2741
```

Flags with carried forward coverage won't be shown.
Pull request overview
This PR makes DeepSeek-V3's MLA attention correct under the combined prefix cache + chunked prefill scenario: it fixes the start/offset computation for position_ids, includes cached KV in cu_seqlens_k/max_seqlen_k for the prefill/mixed FlashAttention calls, and adds the path that reads from the paged latent cache and interleaves it with new-token latents.
Changes:
- Fix the position_ids error caused by double-counted offsets in get_position_ids_and_mask_encoder_batch under chunked prefill, and write positions starting from cached_len.
- Add fused read-cache + interleave (naive/Triton) to MLAAttentionBackend and extend MLAAttentionMetadata; prefill/mixed call FlashAttention with the cache-inclusive cu_seqlens_k_with_cache and max_total_kv_len.
- In the DeepSeek-V3 prefill branch, read cached latents and interleave them with new-token latents before the KV projection, and adjust the key tensor shape construction to cover all KV tokens.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| fastdeploy/model_executor/models/deepseek_v3.py | Hook prefix cache into the prefill branch: read from the paged latent cache and interleave with new-token latents before the KV projection; adjust key shape/assignment logic |
| fastdeploy/model_executor/layers/attention/mla_attention_backend.py | Add fused read-cache+interleave (naive/Triton) and prefix-cache metadata; prefill/mixed FlashAttention uses cache-inclusive seqlens/maxlen; replace print with logger |
| custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu | Fix the offset and starting-position computation for position_ids under chunked prefill + prefix cache |
```python
if need_do_prefill:  # max_enc_len_this_time
    key_value = self.kv_b_proj(compressed_kv)
    # Check for prefix cache
    attn_meta = forward_meta.attn_backend.attention_metadata if hasattr(forward_meta, "attn_backend") else None
    has_prefix_cache = False
    total_cached_tokens = 0

    if attn_meta is not None and isinstance(attn_meta, MLAAttentionMetadata):
        has_prefix_cache = attn_meta.has_prefix_cache
        total_cached_tokens = attn_meta.total_cached_kv_tokens
    if has_prefix_cache and total_cached_tokens > 0:
        layer_id = self.mla_attn.layer_id if hasattr(self.mla_attn, "layer_id") else 0
        latent_cache = forward_meta.caches[layer_id] if hasattr(forward_meta, "caches") else None
        if latent_cache is not None:
            block_size = self.mla_attn.block_size if hasattr(self.mla_attn, "block_size") else 64
            full_compressed_kv, full_k_pe = fused_read_cache_and_interleave(
                latent_cache,
                forward_meta.block_tables,
                compressed_kv,
                key_pe.squeeze(1),
                attn_meta.cu_seqlens_cached_kv,
                forward_meta.cu_seqlens_q,
                self.kv_lora_rank,
                self.qk_rope_head_dim,
                block_size,
            )
```
```python
bsz = cu_seqlens_cached_kv.shape[0] - 1
cu_cached = cu_seqlens_cached_kv.tolist()
cu_new = cu_seqlens_q.tolist()
total_cached = int(cu_cached[bsz])
total_new = new_compressed_kv.shape[0]
total_tokens = total_cached + total_new

full_compressed_kv = paddle.empty([total_tokens, kv_lora_rank], dtype=new_compressed_kv.dtype)
full_k_pe = paddle.empty([total_tokens, qk_rope_head_dim], dtype=new_k_pe.dtype)
if total_tokens == 0:
    return full_compressed_kv, full_k_pe

# block_tables.tolist() is a one-shot D2H; acceptable since host-side loop
# already requires CPU iteration over total_tokens.
bt_list = block_tables.tolist()

is_cached_host = [0] * total_tokens
src_off_host = [0] * total_tokens
out_pos = 0
for b in range(bsz):
    nc = int(cu_cached[b + 1]) - int(cu_cached[b])
```
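The per-batch ordering that the naive reference path produces can be illustrated with a small self-contained sketch. Plain Python lists stand in for paddle tensors, and `interleave_cached_and_new` is an illustrative helper (not part of the PR); `cu_cached`/`cu_new` play the same role as `cu_seqlens_cached_kv`/`cu_seqlens_q` above:

```python
def interleave_cached_and_new(cached_tokens, new_tokens, cu_cached, cu_new):
    """Reference semantics: for each batch b, emit its cached tokens
    first, then its new tokens, so the flattened output holds the full
    KV sequence [cached_0, new_0, cached_1, new_1, ...]."""
    bsz = len(cu_cached) - 1
    out = []
    for b in range(bsz):
        out.extend(cached_tokens[cu_cached[b]:cu_cached[b + 1]])
        out.extend(new_tokens[cu_new[b]:cu_new[b + 1]])
    return out

# Two batches: batch 0 has 2 cached + 1 new token, batch 1 has 1 cached + 2 new.
cached = ["c00", "c01", "c10"]
new = ["n00", "n10", "n11"]
print(interleave_cached_and_new(cached, new, [0, 2, 3], [0, 1, 3]))
# → ['c00', 'c01', 'n00', 'c10', 'n10', 'n11']
```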
```python
def fused_read_cache_and_interleave(*args, **kwargs):
    """Unified entry. ``FD_MLA_USE_NAIVE=1`` forces the Python reference path."""
    if os.environ.get("FD_MLA_USE_NAIVE", "0") == "1":
        return fused_read_cache_and_interleave_naive(*args, **kwargs)
    return fused_read_cache_and_interleave_triton(*args, **kwargs)
```
```cpp
// Compute this batch's offset dynamically. Each batch contributes
// either its encoder_len or its seq_lens_this_time, never their sum
// (during chunked prefill a batch can have encoder_len > 0 and
// decoder_len > 0 at the same time, but it still only holds
// encoder_len real tokens).
int offset = 0;
for (int i = 0; i < tid; i++) {
  if (seq_lens_encoder[i] > 0) {
    offset += seq_lens_encoder[i];
  } else if (seq_lens_decoder[i] > 0) {
    offset += seq_lens_this_time[i];
  }
}
```
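A host-side Python transcription of the fixed offset rule may make the kernel's intent easier to check. The function name is illustrative; the array names mirror the kernel's parameters:

```python
def token_offset(tid, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time):
    """Offset of batch `tid` into the flattened token buffer."""
    offset = 0
    for i in range(tid):
        if seq_lens_encoder[i] > 0:
            # chunked-prefill batch: only its encoder tokens are real
            offset += seq_lens_encoder[i]
        elif seq_lens_decoder[i] > 0:
            # pure decode batch: contributes this step's tokens
            offset += seq_lens_this_time[i]
    return offset

# Batch 0 is mid chunked prefill (enc=4, dec=8 both > 0): it contributes
# only its 4 encoder tokens, so batch 1 starts at offset 4, not 12.
print(token_offset(1, [4, 0], [8, 16], [4, 1]))  # → 4
```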
| """MLA attention forward with prefix cache support.""" | ||
|
|
||
| from fastdeploy.model_executor.layers.attention.mla_attention_backend import ( | ||
| MLAAttentionMetadata, | ||
| fused_read_cache_and_interleave, | ||
| ) |
```python
cu_total = [0] * (bsz + 1)
cumsum_cached = 0
cumsum_total = 0
for i in range(bsz):
```
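Assuming the truncated loop above accumulates per-batch totals the way the surrounding metadata fields suggest, the cache-inclusive cu_seqlens can be sketched as follows (hypothetical standalone helper, plain Python lists instead of tensors):

```python
def build_cu_seqlens_with_cache(cu_cached, cu_new):
    """Per-batch cumulative (cached + new) KV lengths, i.e. the
    cu_seqlens the FlashAttention K side should see."""
    bsz = len(cu_cached) - 1
    cu_total = [0] * (bsz + 1)
    for i in range(bsz):
        cached_len = cu_cached[i + 1] - cu_cached[i]
        new_len = cu_new[i + 1] - cu_new[i]
        cu_total[i + 1] = cu_total[i] + cached_len + new_len
    return cu_total

# Batch 0: 2 cached + 1 new; batch 1: 1 cached + 2 new.
print(build_cu_seqlens_with_cache([0, 2, 3], [0, 1, 3]))  # → [0, 3, 6]
```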
Could this be handled by a custom op, removing the D2H and H2D copies? That would also reduce the CPU-side complexity.
PaddlePaddle-bot
left a comment
🤖 Paddle-CI-Agent | pr_review |
2026-05-09 13:26:38
📋 Review summary
PR overview: adds combined chunked prefill + prefix cache support for all MLA-architecture models (DeepSeek V3, etc.), fixing three bugs: incorrect position_ids computation, truncated cu_seqlens_k, and the missing latent KV read path.
Scope of changes: custom_ops/gpu_ops/, fastdeploy/model_executor/layers/attention/mla_attention_backend.py, fastdeploy/model_executor/models/deepseek_v3.py
Impact tags: [Models] [OP] [KVCache]
📝 PR template check
The PR title's [Feature] tag is an official tag, so the title itself needs no change. However, the ## Usage or Command and ## Accuracy Tests sections contain only the HTML comment placeholders (not even filled with N/A), and no Checklist item is ticked, which does not meet the template requirements.
Suggested title (copy-paste ready):
[Feature] Support MLA chunk-prefill & prefix cache for all MLA Architecture Models
Suggested PR description (copy-paste ready; it must reproduce the full structure of the checklist §D2 template):
## Motivation
The existing MLA (Multi-head Latent Attention) implementation does not support the combination of prefix cache and chunked prefill:
1. The offset computation in `get_position_ids_and_mask_encoder_batch.cu` adds encoder_len + decoder_len together during chunked prefill, producing wrong position_ids.
2. The FlashAttention calls in `forward_extend` / `forward_mixed` do not include cached KV tokens in `cu_seqlens_k` and `max_seqlen_k`, so attention tiles are truncated and the output is silently corrupted.
3. There is no mechanism to read cached KV from the paged latent cache and interleave it with new-token KV.
## Modifications
- `custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu`: fix the offset accumulation under chunked prefill so the cached prefix length is correctly used as the position starting point.
- `custom_ops/gpu_ops/get_padding_offset.cu` / `cpp_extensions.cc`: add an optional `seq_lens_decoder` argument so `cu_seqlens_k` can include cached KV lengths, while `cu_seqlens_q` remains the cumulative sum of new tokens only.
- `fastdeploy/model_executor/layers/attention/mla_attention_backend.py`:
  - Add `fused_read_cache_and_interleave_naive` (Python reference implementation) and `fused_read_cache_and_interleave_triton` (Triton-accelerated version), dispatched through the unified `fused_read_cache_and_interleave` entry (switched via `FD_MLA_USE_NAIVE=1`).
  - Add a `max_seqlen_k` field to `MLAAttentionMetadata` for the FlashAttention K-side length.
  - The FlashAttention calls in `forward_extend` / `forward_mixed` use `max_seqlen_k` instead of `max_enc_len_this_time` as the K-side maximum length.
- `fastdeploy/model_executor/models/deepseek_v3.py`: when chunked prefill or prefix cache is enabled, the prefill branch first reads cached latents via `fused_read_cache_and_interleave` and concatenates them before the KV projection; the key tensor shape is adjusted to cover all tokens.
- `fastdeploy/model_executor/pre_and_post_process.py`: the `get_padding_offset` call gains the `seq_lens_decoder` argument.
- Tests: add `tests/model_executor/test_mla_fused_read_interleave.py`, which compares the Triton path against the Python reference path for accuracy; update related test function signatures under `tests/operators/`.
## Usage or Command
N/A
## Accuracy Tests
N/A
## Checklist
- [x] Add at least a tag in the PR title.
- Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
- You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.
Issues
| Level | File | Summary |
|---|---|---|
| 🟡 Suggestion | mla_attention_backend.py:587 | Questionable `max_seqlen_k` semantics: it takes max(max_kv_len_this_time, max_enc_len_this_time), but under prefix cache the correct value is the maximum over batches of (cached_len + new_len) |
| 🟡 Suggestion | deepseek_v3.py:397 | `forward_meta.block_tables` is passed to `fused_read_cache_and_interleave` without a None check; may crash when the whole batch is first-chunk prefill |
| 🟡 Suggestion | deepseek_v3.py | A1 mandatory check: the forward-path changes lack an end-to-end accuracy comparison (e.g. logits alignment against HuggingFace) |
| ❓ Question | mla_attention_backend.py:637 | Leftover comment `# forward_meta.cu_seqlens_k: Optional[...] = None` inside `MLAAttentionMetadata`; suggest cleaning it up |
🟡 mla_attention_backend.py:587 — questionable `max_seqlen_k` semantics

```python
# current code
metadata.max_seqlen_k = max(metadata.max_kv_len_this_time.item(), metadata.max_enc_len_this_time.item())
```

FlashAttention's `max_seqlen_k` parameter must equal max_i(cu_seqlens_k[i+1] - cu_seqlens_k[i]), i.e. the maximum actual per-batch KV length (cached + new). The current code takes max(max_kv, max_enc), the maximum of each quantity separately, not the maximum per-batch sum. If `max_len_tensor_cpu[5]` is precomputed as the maximum total KV length including cached + new tokens, this is correct; if it is only the decode-phase maximum, `max_seqlen_k` is underestimated here and FlashAttention tiles get truncated (the very kind of silent corruption this PR's motivation describes). Suggest documenting the semantics of `max_len_tensor_cpu[5]` in a comment, or deriving the value directly from `cu_seqlens_k`:

```python
metadata.max_seqlen_k = int((forward_meta.cu_seqlens_k[1:] - forward_meta.cu_seqlens_k[:-1]).max())
```

🟡 deepseek_v3.py:397 — missing None guard on block_tables
```python
if self.enable_chunked_prefill or self.enable_prefix_caching:
    full_compressed_kv, full_k_pe = fused_read_cache_and_interleave(
        forward_meta.caches[self.layer_id],
        forward_meta.block_tables,  # ← may be None
        ...
    )
```

When chunked prefill / prefix cache is enabled but the current batch consists entirely of first-round prefill (no cache hits), `forward_meta.block_tables` may be None or shaped [bsz, 0], and passing it into the Triton kernel can trigger a segfault. Suggest adding a None guard:

```python
if (self.enable_chunked_prefill or self.enable_prefix_caching) \
        and forward_meta.block_tables is not None:
    full_compressed_kv, full_k_pe = fused_read_cache_and_interleave(...)
```

❓ mla_attention_backend.py:637 — leftover comment
`# forward_meta.cu_seqlens_k: Optional[paddle.Tensor] = None` is a commented-out pseudo field declaration; it duplicates the comment above it and misleads readers. Suggest deleting it.
Overall assessment
The PR's approach is clear, and the three MLA prefix cache bugs are accurately pinpointed; the Triton kernel's compile-time binary search + BLOCK_M tiling design is sound, and the hand-run benchmarks are thorough. Suggestions: ① confirm the semantics of the `max_seqlen_k` computation, or derive it directly from `cu_seqlens_k`; ② add a None guard for `block_tables`; ③ add end-to-end accuracy results to the Accuracy Tests section of the PR description.
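The reviewer's suggested derivation of the K-side maximum can be sanity-checked in isolation with a tiny sketch (plain Python, illustrative helper name):

```python
def max_seqlen_from_cu(cu_seqlens_k):
    """max_seqlen_k must be the largest per-batch K length,
    i.e. max_i(cu_seqlens_k[i+1] - cu_seqlens_k[i])."""
    return max(b - a for a, b in zip(cu_seqlens_k, cu_seqlens_k[1:]))

# Batch 0: 4 KV tokens (e.g. 3 cached + 1 new); batch 1: 6 KV tokens.
# The per-batch sum of cached + new is what FlashAttention needs, not
# the separate maxima of cached and new lengths across batches.
print(max_seqlen_from_cu([0, 4, 10]))  # → 6
```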
Motivation
The existing MLA (Multi-head Latent Attention) implementation does not support the combination of prefix cache and chunked prefill:
- The offset computation in get_position_ids_and_mask_encoder_batch.cu adds encoder_len + decoder_len together during chunked prefill, producing wrong position_ids.
- The FlashAttention calls in forward_extend / forward_mixed do not include cached KV tokens in cu_seqlens_k and max_seqlen_k, so attention tiles are truncated and the output is silently corrupted.
Modifications
- custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu: fix the offset accumulation under chunked prefill so the cached prefix length is correctly used as the position starting point.
- fastdeploy/model_executor/layers/attention/mla_attention_backend.py:
  - Add fused_read_cache_and_interleave_naive (Python reference implementation) and fused_read_cache_and_interleave_triton (Triton-accelerated version), dispatched through the unified fused_read_cache_and_interleave entry (switched via the environment variable FD_MLA_USE_NAIVE=1).
  - MLAAttentionMetadata gains prefix-cache fields (has_prefix_cache, cu_seqlens_cached_kv, cu_seqlens_k_with_cache, max_total_kv_len, etc.).
  - init_attention_metadata adds the prefix-cache metadata computation.
  - The FlashAttention calls in forward_extend / forward_mixed use the cache-inclusive cu_seqlens_k_with_cache and max_total_kv_len.
- fastdeploy/model_executor/models/deepseek_v3.py: the prefill branch reads cached latents and interleaves them with new-token latents before the KV projection; the key tensor shape becomes [full_tokens, heads, qk_head_dim].
Usage or Command
Accuracy Tests
Checklist
- [ ] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`, `[APIServer]`, `[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.