Fix FP16 overflow in GQA attention and concat_past_present buffer overflow #4677

aditya-dl wants to merge 3 commits into ROCm:develop from
Conversation
Unit tests need to be added and the CI failure fixed.
Pull request overview
This PR fixes two root causes of garbage output in Qwen1.5-architecture models during FP16 inference: (1) FP16 overflow in the dot→softmax attention chain, and (2) a buffer overflow in concat_past_present during prompt processing when the sequence length exceeds the past cache size.
Changes:
- Extends `find_softmax_base_ops` in `rewrite_reduce.cpp` to walk backward from softmax through `mul`/`where`/`broadcast`/`convert` to find a feeding `dot` instruction, upcasting the entire range to FP32 (with bool inputs excluded).
- Fixes `concat_past_present` buffer sizing across the operator definition, GPU lowering, JIT compiler, and GPU kernel so that the output buffer is properly sized when `sequence_length > past_cache_sequence_length`.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `src/rewrite_reduce.cpp` | Adds backward walk from softmax to dot for FP32 upcast range extension; skips bool inputs in conversion |
| `src/include/migraphx/op/concat_past_present.hpp` | Updates `compute_shape` to return the larger shape when needed; uses `std::max` for `present_buffer_sequence_length` |
| `src/targets/gpu/lowering.cpp` | Allocates a properly sized GPU buffer when the output shape exceeds the past cache shape |
| `src/targets/gpu/jit/concat_past_present.cpp` | Adjusts the JIT compiler output shape to match the larger buffer when needed |
| `src/targets/gpu/kernels/include/migraphx/kernels/concat_past_present.hpp` | GPU kernel uses `max(past_seq, seq_len)` for the present buffer sequence length |
I think you should put the softmax change in a separate PR, as I don't think we will merge the concat_past_present change.
Force-pushed 6f03700 to 9da5d9f
@pfultz2 Updated this PR with only the softmax change.
Regarding the CI failures: the fix gates the backward walk on the presence of `where` in the chain. Without `where`, the `dot` stays in FP16 and CK fused attention handles precision internally. With `where` (the pattern in all GQA models using the GroupQueryAttention ONNX op: Qwen, Llama, Phi, DeepSeek), the `dot` is upcast to FP32. The `where` ops block `propagate_precision` from merging converts (as a multi-input op), which prevents `fuse_attention` from matching, so the ops run as separate FP32 kernels, fixing the FP16 overflow. Added a new test (`softmax_dot_no_where_preserves_half`) verifying the `dot` is NOT upcast when `where` is absent.
Motivation
Qwen1.5-architecture models produce garbage output when running FP16 inference through MIGraphX. Two root causes were identified: FP16 overflow in the dot-to-softmax attention chain, and a buffer overflow in `concat_past_present` during prompt processing.
Technical Details
FP16 overflow fix: Extends `find_softmax_base_ops` to walk backwards through the attention chain (`mul`, `where`, `broadcast`, `convert`) to find the feeding `dot` instruction. The entire dot-to-softmax range is upcast to FP32, preventing overflow in attention score computation. Bool-type inputs (`where` conditions) are excluded from conversion.

Changelog Category
Add a `CHANGELOG.md` entry for any option other than `Not Applicable`.