[BugFix][KSM] Fix sampling_mask reordering in recover_batch_index_for…#7773
Pull request overview
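This PR fixes `recover_batch_index_for_sampler_output` not restoring the batch order of `sampling_mask` when `enable_keep_sampling_mask` (KSM) and `pd_reorder` are both enabled. Without the fix, `sampling_mask` and logprobs are misaligned, so logprob normalization uses the wrong candidate set, which in turn causes downstream training anomalies such as KL explosion.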
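To make the failure mode concrete, here is a minimal sketch of how a misaligned mask distorts a renormalized logprob. It assumes the logprob is computed as the token's logit minus the log-sum-exp over the candidates kept by the mask; `masked_logprob` and all values are illustrative and not taken from FastDeploy.

```python
import numpy as np


def masked_logprob(logits, mask, token_id):
    """Logprob of token_id renormalized over the candidates kept by mask (assumed formula)."""
    kept = logits[mask]
    # Numerically stable log-sum-exp over the kept candidates only (the logZ term).
    m = kept.max()
    log_z = m + np.log(np.exp(kept - m).sum())
    return float(logits[token_id] - log_z)


# Two requests over a toy 4-token vocabulary.
logits = np.array([[2.0, 1.0, 0.0, -1.0],
                   [-1.0, 0.0, 1.0, 2.0]])
masks = [np.array([True, True, False, False]),   # request 0 kept tokens 0 and 1
         np.array([False, False, True, True])]   # request 1 kept tokens 2 and 3
sampled = [0, 3]

# Each request paired with its own mask: every logprob is <= 0, as expected.
aligned = [masked_logprob(logits[i], masks[i], sampled[i]) for i in range(2)]

# Simulated bug: logits are back in request order, but the masks are still swapped.
swapped = [masked_logprob(logits[i], masks[1 - i], sampled[i]) for i in range(2)]
```

With the swapped masks the sampled token is no longer inside the candidate set used for logZ, so the "logprob" can become positive, which is the kind of abnormal value that can destabilize a KL estimate downstream.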
Changes:
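- Add the missing `sampling_mask` reordering logic to `recover_batch_index_for_sampler_output` so that it stays aligned, request by request, with fields such as `sampled_token_ids` and `logprobs_tensors`.
- Add a unit test file covering `sampling_mask=None`, normal reordering, identity mapping, disabled `pd_reorder`, and tail elements staying in place.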
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| fastdeploy/worker/input_batch.py | Adds `sampling_mask` reordering logic to `recover_batch_index_for_sampler_output` to fix the misalignment under KSM + pd_reorder. |
| tests/worker/test_recover_batch_index_sampling_mask.py | Adds unit test coverage for the `sampling_mask` reordering behavior. |

    sort_len = len(src_order)
    real_sampling_mask = [None] * len(sampling_mask)
    for i in range(sort_len):
        real_sampling_mask[i] = sampling_mask[src_order[i]]
    for i in range(sort_len, len(sampling_mask)):
        real_sampling_mask[i] = sampling_mask[i]
    sampler_output.sampling_mask = real_sampling_mask
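The snippet above is a gather: slot `i` of the result receives the element at `src_order[i]`, and anything past `len(src_order)` stays where it was. The following standalone sketch reproduces that behavior on toy data and applies the same gather to a NumPy array standing in for a tensor field. The helper name `gather_with_tail` and the sample values are hypothetical, and it is an assumption here that the other fields are restored with the same gather.

```python
import numpy as np


def gather_with_tail(values, src_order):
    """Slot i receives values[src_order[i]]; slots past len(src_order) are unchanged."""
    sort_len = len(src_order)
    return [values[src_order[i]] for i in range(sort_len)] + list(values[sort_len:])


# Hypothetical batch of four requests, of which only the first three were reordered.
src_order = [2, 0, 1]
sampling_mask = [np.array([10]), np.array([11]), np.array([12]), np.array([13])]
logprobs = np.array([[-0.1], [-0.2], [-0.3], [-0.4]])

restored_mask = gather_with_tail(sampling_mask, src_order)

# Same gather on an array field, written with NumPy fancy indexing.
restored_logprobs = logprobs.copy()
restored_logprobs[:len(src_order)] = logprobs[src_order]
```

Because both fields go through the same permutation, entry `i` of `restored_mask` and row `i` of `restored_logprobs` come from the same original position, which is the alignment property the fix is meant to restore.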

    def _make_sampler_output(batch_size, with_sampling_mask=True):
        """Create a minimal mock SamplerOutput for testing reorder logic."""
        so = Mock()
        so.sampled_token_ids = paddle.arange(batch_size, dtype="int64").unsqueeze(1)
        so.logprobs_tensors = Mock()
        so.logprobs_tensors.logprob_token_ids = paddle.arange(batch_size, dtype="int64").unsqueeze(1)
        so.logprobs_tensors.logprobs = paddle.arange(batch_size, dtype="float32").unsqueeze(1)
        so.logprobs_tensors.selected_token_ranks = paddle.zeros([batch_size, 1], dtype="int64")
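The five scenarios the new tests are described as covering (a `None` mask, normal reordering, identity mapping, disabled `pd_reorder`, and tail elements) can be sketched without Paddle. Everything below is a hypothetical stand-in: `recover_sampling_mask` only mimics the `sampling_mask` branch of the fix, the `enable_pd_reorder` flag is an assumed parameter, and `SimpleNamespace` replaces the `Mock`-based sampler output used in the real test file.

```python
from types import SimpleNamespace

import numpy as np


def recover_sampling_mask(sampler_output, src_order, enable_pd_reorder=True):
    """Hypothetical stand-in for the sampling_mask branch of the fix (not FastDeploy code)."""
    mask = sampler_output.sampling_mask
    if not enable_pd_reorder or mask is None:
        return  # nothing to restore
    n = len(src_order)
    # Gather the reordered prefix and keep the tail where it is.
    sampler_output.sampling_mask = [mask[j] for j in src_order] + list(mask[n:])


def make_output(batch_size, with_mask=True):
    """Request i carries the one-element mask [i], so positions are easy to read back."""
    mask = [np.array([i]) for i in range(batch_size)] if with_mask else None
    return SimpleNamespace(sampling_mask=mask)


def run(batch_size, src_order, with_mask=True, enable_pd_reorder=True):
    so = make_output(batch_size, with_mask)
    recover_sampling_mask(so, src_order, enable_pd_reorder)
    if so.sampling_mask is None:
        return None
    return [int(m[0]) for m in so.sampling_mask]


results = {
    "none": run(3, [2, 0, 1], with_mask=False),
    "reorder": run(3, [2, 0, 1]),
    "identity": run(3, [0, 1, 2]),
    "disabled": run(3, [2, 0, 1], enable_pd_reorder=False),
    "tail": run(4, [1, 0]),
}
```

Encoding each request's index inside its own mask keeps the assertions readable: the recovered list directly shows which original request ended up in each slot.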
PaddlePaddle-bot left a comment
🤖 Paddle-CI-Agent | pr_review | 2026-05-11 16:35:43
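📋 Review summary
PR overview: Fixes a bug where `recover_batch_index_for_sampler_output` skipped reordering the `sampling_mask` field when KSM (keep_sampling_mask) and pd_reorder are enabled, avoiding the logprob misalignment that leads to KL explosion.
Scope of changes: fastdeploy/worker/input_batch.py, tests/worker/
Impact tags: [BugFix] [PD Disaggregation]
📝 PR convention check
The following issues were found:
- Unofficial tag: `[KSM]` in the title is not in the official tag list (see checklist §D1).
- Missing Cherry-Pick format: the target branch is `release/2.6` (not `develop`), so the title should follow the format `[Cherry-Pick][BugFix] ... (#original PR number)`.
- Last checklist item unchecked: the cherry-pick confirmation item for PRs submitted to a release branch is not checked.
Suggested title (can be copied directly):
[Cherry-Pick][BugFix] Fix sampling_mask reordering in recover_batch_index_for_sampler_output (#original PR number)
Suggested PR description (can be copied directly; it must reproduce the full structure of the checklist §D2 template):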
## Motivation
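When `enable_keep_sampling_mask` (KSM) and `pd_reorder` are both enabled, `recover_batch_index_for_sampler_output` restores the batch order of fields such as `logprobs_tensors` and `sampled_token_ids` but skips `sampling_mask`. The candidate vocabulary is then paired with the wrong logprobs, logZ normalization uses another request's candidate set, and the resulting abnormal logprob values cause KL explosion in downstream RL training (peaking at 64048).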
## Modifications
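- `fastdeploy/worker/input_batch.py`: at the end of `recover_batch_index_for_sampler_output`, add reordering logic for `sampling_mask` (`List[np.ndarray]`) that uses the same `src_order` permutation as the other fields, so that `sampling_mask[i]` and `logprobs[i]` always belong to the same request.
- `tests/worker/test_recover_batch_index_sampling_mask.py`: add 5 unit tests covering a None value, normal reordering, identity-mapping no-op, disabled pd_reorder, and tail elements staying in place.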
## Usage or Command
N/A
## Accuracy Tests
N/A
## Checklist
- [x] Add at least a tag in the PR title.
- Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
- You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
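- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

Issues
| Severity | File | Summary |
|---|---|---|
| 📝 PR convention | Title/Checklist | Unofficial tag [KSM]; target branch release/2.6 lacks the [Cherry-Pick] format; the cherry-pick source confirmation item in the checklist is unchecked |
Overall assessment
The code logic is correct and the unit tests cover five scenarios. The change fixes the missing sampling_mask reordering in the PD disaggregation + KSM case. The only follow-ups are to adjust the PR title and confirm the cherry-pick source process.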
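CI report, generated from the PR code (updated every 30 minutes):
1 Task overview
There is currently 1 failed required task (Approval: missing FastDeploy RD approval), plus 2 required tasks running and 3 required tasks pending. Please follow up on them.
2 Task status summary
2.1 Required tasks: 4/10 passed
2.2 Optional tasks: 23/26 passed
3 Failure details (required only)
Approval: process issue (approval) (confidence: high)
Fix suggestion summary: ask any one of qingqing01/jiangjiajun/heavengate to approve this PR.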
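Motivation
When `enable_keep_sampling_mask` (KSM) and `pd_reorder` are both enabled, `recover_batch_index_for_sampler_output` restores the batch order of fields such as `logprobs_tensors` and `sampled_token_ids` but skips `sampling_mask`. The candidate vocabulary is then paired with the wrong logprobs, logZ normalization uses another request's candidate set, and the resulting abnormal logprob values cause KL explosion in downstream RL training (peaking at 64048).
Modifications
- `fastdeploy/worker/input_batch.py`: at the end of `recover_batch_index_for_sampler_output`, add reordering logic for `sampling_mask` (`List[np.ndarray]`) that uses the same `src_order` permutation as the other fields, so that `sampling_mask[i]` and `logprobs[i]` always belong to the same request.
- `tests/worker/test_recover_batch_index_sampling_mask.py`: add 5 unit tests covering a None value, normal reordering, identity-mapping no-op, disabled pd_reorder, and tail elements staying in place.
Usage or Command
N/A
Accuracy Tests
N/A
Checklist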
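- Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- Format your code, run `pre-commit` before commit.
- Add unit tests. Please write the reason in this PR if no unit tests.
- Provide accuracy results.
- If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.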