Optimize FlashAttention for M4 Max (20x speedup)#27780
xenova wants to merge 13 commits into microsoft:main
Conversation
MultiHeadAttention
Before: 58.3s
After: 5.4s
Speedup: 10.8x
Awesome! I made a similar change a few days ago to optimize Whisper locally, but your approach is more comprehensive than mine. I just tested it and observed comparable improvements to those in #27781 for Whisper. I'll go ahead and close that one, and I'm looking forward to seeing this land soon!
```wgsl
// Private memory per lane.
var<private> q_tile : array<q_value_t, head_size_vec>;
var<private> qk_scores : array<q_element_t, max_k_step>;
```
When max_k_step = 128 (e.g., head_size=32 with f16), this allocates 128 private values per lane for the QK scores. On some GPUs, this may cause register spilling and hurt performance. Have you tested this on less powerful devices, such as Intel Tiger Lake or Qualcomm?
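For intuition, the per-lane private footprint of the `qk_scores` array can be roughly modeled from the element size and the K-step; this is a sketch only (the function name is hypothetical, and actual register allocation is up to the driver/compiler):

```javascript
// Rough estimate of per-lane private storage used by qk_scores.
// bytesPerElement is 2 for f16, 4 for f32; maxKStep is the K-tile length.
function qkScoresBytes(maxKStep, bytesPerElement) {
  return maxKStep * bytesPerElement;
}

// head_size=32 with f16 yields max_k_step = 128 in the original code:
console.log(qkScoresBytes(128, 2)); // 256 bytes per lane
console.log(qkScoresBytes(64, 2));  // halving max_k_step halves the footprint
```

Halving `max_k_step` directly halves this per-lane pressure, which is why the later discussion about dropping to 64 (and then 32) matters on register-constrained GPUs.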
> Have you tested this on less powerful devices, such as Intel Tiger Lake or Qualcomm?
Unfortunately not, I've mainly just tested on my device to be honest. Do you have recommendations or a CI that can help with this?
Reducing max_k_step to 64 doesn't hurt M4 Max performance.
Thanks so much for testing! I will keep the Qualcomm path based on this feedback.
Added back the Qualcomm path. @qjia7 can you do another round of testing? The change doesn't affect current performance on M4.
| Run | Op Name | Count | Total | Avg | Min | Max | % Total | Provider(s) |
|---|---|---|---|---|---|---|---|---|
| main | MultiHeadAttention | 168 | 693.627 ms | 4.129 ms | 10.0 us | 8.340 ms | 80.26% | WebGpu |
| this PR (w/o qualcomm path) | MultiHeadAttention | 168 | 77.239 ms | 459.8 us | 10.0 us | 917.0 us | 31.06% | WebGpu |
| this PR (w/ qualcomm path) | MultiHeadAttention | 168 | 77.270 ms | 459.9 us | 10.0 us | 920.0 us | 31.20% | WebGpu |
Verified on Qualcomm. The perf is back. Thanks.
Great! 😄 I must admit that the changes I made were very optimized for my M4 Max and this specific vision encoder. But as you mentioned above, it does seem to help with Whisper too. Also, a lot of the PR diff is removing …
Ran some more benchmarks on some other models.
Been testing more and more... every model sees a 2-3x performance improvement for the MHA nodes. Hoping we can get some benchmarking done on lower-end devices so we can fast-track the PR!
Extra cool!
I can give it a run on Tiger Lake in the afternoon.
Great! Hopefully we see good performance 🤞
Based on Guenther's feedback, I updated the implementation so that we only use my optimized branch for Apple hardware; everything else falls back to the original implementation. I see the performance increase on my device (M4 Max), and other hardware should produce the same benchmarks as before.
```cpp
if (max_k_from_shm >= 64) {
  max_k_step_ = 64;
} else if (max_k_from_shm >= 32) {
  max_k_step_ = 32;
}
```
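The tiering ladder above can be sketched as a standalone helper. This is illustrative only: the helper name is hypothetical, the real code assigns the member `max_k_step_` directly, and the fallback tier below 32 is an assumption since the excerpt doesn't show it:

```javascript
// Pick the largest power-of-two K-step tier that fits within the
// shared-memory-derived bound maxKFromShm.
function chooseMaxKStep(maxKFromShm) {
  if (maxKFromShm >= 64) return 64;
  if (maxKFromShm >= 32) return 32;
  return 16; // assumed fallback tier; not shown in the excerpt
}

console.log(chooseMaxKStep(100)); // 64
console.log(chooseMaxKStep(40));  // 32
```

Capping the tier at 64 (rather than 128) is what keeps per-lane register usage bounded on devices like Qualcomm and older Apple GPUs.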
Your current method uses more registers to improve performance. Have you measured how much of a perf gap there is between max_k_step_ = 32 and max_k_step_ = 64 on M4 Max? And how about max_k_step_ = 32 plus subgroupShuffle compared with max_k_step_ = 64 on M4 Max? If they get similar performance, I'd prefer we use max_k_step_ = 32 for Apple and NVIDIA, which can help reduce register pressure (such as on M1). My previous machine was NVIDIA and I saw a very good improvement for Whisper with max_k_step_ = 32.
Sure, I can test that.
Okay, max_k_step_ = 32 has no noticeable performance difference vs. 64.
However, max_k_step_ = 32 plus subgroupShuffle causes significant issues.
Weird that max_k_step_ = 32 plus subgroupShuffle causes significant issues. Thanks for trying. The latest change looks good to me.
Yeah, weird that it only happens for Apple. Maybe an upstream implementation issue in Dawn?
@guschmue I think we're good to merge? 😇
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline
Azure Pipelines successfully started running 4 pipeline(s).
@xenova might fix the preprocessor directive whitespace issues in case that is holding this up. Definitely looking forward to this improvement!
We're noticing some regressions for older (~M1) Apple hardware... so we're still trying to figure out what the optimal setup looks like.
+1. The fixes don't really speed anything up on my M1 Pro.
@qjia7 it would be great to be able to get this working (in such a way that doesn't affect other hardware). Any ideas?
Noticed again with https://huggingface.co/onnx-community/depth-anything-v2-small-ONNX where my branch is around 6x faster (460 ms -> 75 ms).
Could you gather more details from AdapterInfo and further constrain this to M4 Max? If so, we can move forward with landing it. I can also take a deeper look into the causes of the regressions introduced by these changes. Hopefully we can enable this on more devices; I'm seeing significant gains on my NVIDIA device as well. For example, I can retrieve the adapter info as shown below. I expect you should also be able to distinguish M4 Max using the adapter information.
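For reference, a minimal sketch of the gating that standard WebGPU adapter info allows today, assuming only the spec'd `GPUAdapterInfo` fields (`vendor`, `architecture`, `device`, `description`); the helper name and the exact vendor string are assumptions:

```javascript
// In the browser, adapter info comes from:
//   const adapter = await navigator.gpu.requestAdapter();
//   const info = adapter.info; // { vendor, architecture, device, description }
// This pure helper only inspects such an info object, so it is testable offline.
function isAppleMetal(info) {
  // Dawn currently reports architecture strings like "metal-3" for Apple GPUs,
  // which cannot distinguish M4 Max from, say, M1.
  return info.vendor === "apple" && /^metal/.test(info.architecture);
}

console.log(isAppleMetal({ vendor: "apple", architecture: "metal-3" })); // true
console.log(isAppleMetal({ vendor: "nvidia", architecture: "ampere" })); // false
```

This illustrates the limitation discussed below: the coarse architecture string can gate "Apple vs. not Apple" but not "M4 Max vs. M1".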
I have an M2 Max I can test. What is the quickest way to do so?
@qjia7 the only useful information I can see is probably …
How about providing a WebGPU EP session option, something like …

Benefits: …

Note: I see this as a temporary stepping stone, not a permanent solution. For follow-up work:

- **Root-cause the regressions on some unexpected devices.** Having the opt-in flag makes it easy to A/B test on affected machines. We can do deeper analysis on why this regresses (register spilling? workgroup size mismatch? shared memory pressure?). There may also still be room to further optimize the current shader.
- **File a Dawn bug for richer GPU info.** Currently Dawn's AdapterInfo only exposes `architecture: "metal-3"` for Apple, which isn't sufficient to distinguish M4 Max from other variants. It would be helpful to ask for more detailed GPU identification. Once Dawn exposes that information, we can follow up with a PR to automatically enable the optimization on the appropriate devices and eventually deprecate the manual opt-in.
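A sketch of how such an opt-in might look from the caller's side. The flag name is invented for illustration; only the `executionProviders` shape mirrors onnxruntime-web's actual session options:

```javascript
// Build session options with a hypothetical opt-in flag for the optimized
// FlashAttention path. The `executionProviders` array-of-objects shape
// matches onnxruntime-web; `flashAttentionOpt` is NOT a real option.
function makeSessionOptions(enableOpt) {
  return {
    executionProviders: [
      { name: "webgpu", /* hypothetical flag: */ flashAttentionOpt: enableOpt },
    ],
  };
}

// In a real app this would be passed to the session factory, e.g.:
//   ort.InferenceSession.create(modelUrl, makeSessionOptions(true));
console.log(makeSessionOptions(true).executionProviders[0].name); // "webgpu"
```

An explicit flag like this makes A/B testing on regressed hardware (M1, Tiger Lake) a one-line change for users while the root cause is investigated.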


MultiHeadAttention
Before: 58.3s
After: 2.89s
Speedup: 20x
Description
Motivation and Context
Tested with vision_encoder.onnx for https://huggingface.co/onnx-community/LightOnOCR-2-1B-ONNX