Fix attention for non-standard literal #4877
Conversation
Pull request overview
This PR fixes an attention fusion regression where capturing a constant-folded, non-standard-strided bias literal into the attention group causes MLIR lowering to emit a migraphx.literal with non-standard strides (rejected by rocMLIR). The change keeps such literals outside the group so they enter as regular inputs, while preserving existing inlining behavior for standard-strided literals.
Changes:
- Update attention fusion's constant-capture logic to skip non-standard `@literal` instructions when expanding the attention group.
- Add a regression test covering the failing "transposed/non-standard literal bias" scenario and a companion test ensuring a standard literal bias is still inlined.
- Ensure the expected grouped program structure matches the intended behavior via `sort()` equivalence.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `src/fuse_attention.cpp` | Prevents non-standard `@literal` nodes from being captured into attention groups to avoid MLIR literal lowering failures. |
| `test/fuse_attention.cpp` | Adds regression/behavioral tests for non-standard vs standard bias literal handling in attention fusion. |
```cpp
// Leave non-standard literals (e.g. constant-folded transposed
// bias) outside the group so they enter as a parameter
// instead, where mlir_contiguous + adjust_param_shapes can
// normalise the layout.
```
Just remove this comment. There are so many things incorrect in this comment that it will just cause more confusion.
```cpp
// instead, where mlir_contiguous + adjust_param_shapes can
// normalise the layout.
if(input->name() == "@literal" and not input->get_shape().standard())
    continue;
```
This doesn't seem like the correct fix. We need to skip literals that are not iota literals, but you are not checking whether it is an iota literal.
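One plausible shape for the check the reviewer is asking for. `is_iota` and `read_values` are hypothetical helpers for illustration, not existing MIGraphX functions:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not the actual MIGraphX API): an "iota" literal
// holds the values 0, 1, 2, ... in order.
bool is_iota(const std::vector<std::size_t>& values)
{
    for(std::size_t i = 0; i < values.size(); ++i)
        if(values[i] != i)
            return false;
    return true;
}

// Per the review, the capture loop would then keep only iota literals in
// the group, roughly (value extraction from the literal elided):
//
//   if(input->name() == "@literal" and not is_iota(read_values(input)))
//       continue;
```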
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff            @@
##           develop    #4877   +/-   ##
===========================================
+ Coverage    92.85%   92.87%   +0.02%
===========================================
  Files          584      585       +1
  Lines        30147    30123      -24
===========================================
- Hits         27992    27976      -16
+ Misses        2155     2147       -8
```
Motivation
Technical Details
Changelog Category
Add a `CHANGELOG.md` entry for any option other than `Not Applicable`.