
Optimize FlashAttention for M4 Max (20x speedup) #27780

Open

xenova wants to merge 13 commits into microsoft:main from xenova:mha-optimizations

Conversation

@xenova (Contributor) commented Mar 20, 2026

MultiHeadAttention
Before: 58.3s
After: 2.89s
Speedup: 20x

Description

Motivation and Context

Tested with vision_encoder.onnx for https://huggingface.co/onnx-community/LightOnOCR-2-1B-ONNX

MultiHeadAttention
Before: 58.3s
After: 5.4s
Speedup: 10.8x
@xenova xenova marked this pull request as draft March 20, 2026 05:31
@xenova xenova changed the title Optimize FlashAttention for M4 Max (10.8x speedup) Optimize FlashAttention for M4 Max (12x speedup) Mar 20, 2026
@xenova xenova changed the title Optimize FlashAttention for M4 Max (12x speedup) Optimize FlashAttention for M4 Max (20x speedup) Mar 20, 2026
@xenova xenova marked this pull request as ready for review March 20, 2026 06:26
@xenova (Contributor, Author) commented Mar 20, 2026

@guschmue 🙏

@qjia7 (Contributor) commented Mar 20, 2026

Awesome! I made a similar change a few days ago to optimize Whisper locally, but your approach is more comprehensive than mine. I just tested it and observed comparable improvements to those in #27781 for Whisper. I’ll go ahead and close that one, and I’m looking forward to seeing this land soon!


```wgsl
// Private memory per lane.
var<private> q_tile : array<q_value_t, head_size_vec>;
var<private> qk_scores : array<q_element_t, max_k_step>;
```
Contributor:
When max_k_step = 128 (e.g., head_size=32 with f16): this allocates 128 private registers per lane for QK scores. On some GPUs, this may cause register spilling and hurt performance. Have you tested this on less powerful devices, such as Intel Tiger Lake or Qualcomm?

Contributor Author:
> Have you tested this on less powerful devices, such as Intel Tiger Lake or Qualcomm?

Unfortunately not, I've mainly just tested on my device to be honest. Do you have recommendations or a CI that can help with this?

Contributor Author:
Reducing max_k_step to 64 doesn't hurt M4 Max performance (commit 6487515).

Contributor:
I found regressions on Qualcomm for phi4. It seems that register spilling happens (2s -> 11s for FlashAttention). Maybe we should keep the original path for Qualcomm (I haven't tested Tiger Lake yet). Profiles:

[screenshot: baseline]
[screenshot: with this PR]

Contributor Author:
Thanks so much for testing! I will keep the Qualcomm path based on this feedback.

Contributor Author:
Added back the Qualcomm path. @qjia7 can you do another round of testing? The change doesn't affect current performance on M4.

| Run | Op Name | Count | Total | Avg | Min | Max | % Total | Provider(s) |
|---|---|---|---|---|---|---|---|---|
| main | MultiHeadAttention | 168 | 693.627 ms | 4.129 ms | 10.0 us | 8.340 ms | 80.26% | WebGpu |
| this PR (w/o qualcomm path) | MultiHeadAttention | 168 | 77.239 ms | 459.8 us | 10.0 us | 917.0 us | 31.06% | WebGpu |
| this PR (w/ qualcomm path) | MultiHeadAttention | 168 | 77.270 ms | 459.9 us | 10.0 us | 920.0 us | 31.20% | WebGpu |

Contributor:
Verified on Qualcomm. The perf is back. Thanks.

@xenova (Contributor, Author) commented Mar 20, 2026

> Awesome! I made a similar change a few days ago to optimize Whisper locally, but your approach is more comprehensive than mine. I just tested it and observed comparable improvements to those in #27781 for Whisper. I'll go ahead and close that one, and I'm looking forward to seeing this land soon!

Great! 😄 I must admit that the changes I made were heavily optimized for my M4 Max and this specific vision encoder. But as you mentioned above, it does seem to help with Whisper too.

Also, a lot of the PR diff is removing prefer_subgroupshuffle... which may not be good across other devices. @guschmue lmk what you think!

@xenova (Contributor, Author) commented Mar 20, 2026

Ran some more benchmarks on some other models.

| Model | ONNX file | Op | before | after | speedup |
|---|---|---|---|---|---|
| onnx-community/all-MiniLM-L6-v2-ONNX | model.onnx | MultiHeadAttention | 4.36ms | 2.02ms | 2.16x |
| onnx-community/gemma-3-270m-it-ONNX | model.onnx | GroupQueryAttention | 3.65ms | 1.63ms | 2.24x |
| onnx-community/LightOnOCR-2-1B-ONNX | vision_encoder.onnx | MultiHeadAttention | 192s | 9.3s | 20.71x |
| onnx-community/LightOnOCR-2-1B-ONNX | decoder_model_merged.onnx | GroupQueryAttention | 14.3ms | 7.4ms | 1.96x |

@xenova (Contributor, Author) commented Mar 21, 2026

I've been testing more and more... every model sees a 2-3x performance improvement for the MHA nodes. Hoping we can get some benchmarking done on lower-end devices so we can fast-track the PR!

@guschmue (Contributor):

extra cool!

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Mar 23, 2026
@guschmue (Contributor):

I can give it a run on tiger lake in the afternoon.

@xenova (Contributor, Author) commented Mar 23, 2026

> extra cool!

> I can give it a run on tiger lake in the afternoon.

Great! Hopefully we see good performance 🤞

@xenova (Contributor, Author) commented Mar 23, 2026

Based on Guenther's feedback, I updated the implementation so that we only use my optimized branch for Apple hardware. Everything else falls back to the original implementation. I see the performance increase on my device (M4 Max), and other hardware should produce the same benchmarks as before.

@qjia7 @guschmue PTAL 🙏

```cpp
if (max_k_from_shm >= 64) {
  max_k_step_ = 64;
} else if (max_k_from_shm >= 32) {
  max_k_step_ = 32;
```
Contributor:
Your current method uses more registers to improve performance. Have you measured the perf gap between max_k_step_ = 32 and max_k_step_ = 64 on M4 Max? And how does max_k_step_ = 32 plus subgroupShuffle compare with max_k_step_ = 64 on M4 Max? If they get similar performance, I'd prefer we use max_k_step_ = 32 for Apple and NVIDIA, which can help reduce register pressure (such as on M1). My previous machine was NVIDIA and I saw a very good improvement for Whisper with max_k_step_ = 32.

Contributor Author:
Sure, I can test that.

@xenova (Contributor, Author) commented Mar 24, 2026
Okay, max_k_step_ = 32 has no noticeable performance difference vs. 64.

max_k_step_ = 32 plus subgroupShuffle causes significant issues.

Contributor:
Weird that max_k_step_ = 32 plus subgroupShuffle causes significant issues. Thanks for trying. The latest change looks good to me.

Contributor Author:
Yeah, weird that it only happens for Apple. Maybe an upstream implementation issue in Dawn?

@guschmue I think we're good to merge? 😇

@guschmue (Contributor):

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines:

Azure Pipelines successfully started running 4 pipeline(s).

@sroussey (Contributor):

@xenova might fix the preprocessor directive whitespace issues in case that is holding this up. definitely looking forward to this improvement!

@xenova (Contributor, Author) commented Mar 26, 2026

> @xenova might fix the preprocessor directive whitespace issues in case that is holding this up. definitely looking forward to this improvement!

We're noticing some regressions for older (~M1) apple hardware... so we're still trying to figure out what the optimal setup looks like.

@kokroo commented Apr 5, 2026

> @xenova might fix the preprocessor directive whitespace issues in case that is holding this up. definitely looking forward to this improvement!

> We're noticing some regressions for older (~M1) apple hardware... so we're still trying to figure out what the optimal setup looks like.

+1. The fixes don't really speed anything up on my M1 Pro.

@xenova (Contributor, Author) commented Apr 14, 2026

@qjia7 it would be great to be able to get this working (in such a way that doesn't affect other hardware). Any ideas?

@xenova (Contributor, Author) commented Apr 14, 2026

Noticed again with https://huggingface.co/onnx-community/depth-anything-v2-small-ONNX, where my branch is around 6x faster (460ms -> 75ms).

@qjia7 (Contributor) commented Apr 14, 2026

> @qjia7 it would be great to be able to get this working (in such a way that doesn't affect other hardware). Any ideas?

Could you gather more details from AdapterInfo and further constrain this to M4 Max? If so, we can move forward with landing it. I can also take a deeper look into the regression causes introduced by these changes. Hopefully we can enable this on more devices; I'm seeing significant gains on my NV device as well.

For example, I can retrieve the adapter info as shown below. I expect you should also be able to distinguish M4 Max using the adapter information.

```
vendor="nvidia"
architecture="lovelace"
device="NVIDIA RTX 2000 Ada Generation Laptop GPU"
backend_type=4, vendor_id=4318, device_id=10424
```

@sroussey (Contributor):

I have an M2 Max I can test. What is the quickest way to do so?

@xenova (Contributor, Author) commented Apr 14, 2026

@qjia7 the only useful information I can see is probably architecture: "metal-3" (vendor is apple). Everything else appears blank.

@qjia7 (Contributor) commented Apr 15, 2026

> @qjia7 the only useful information I can see is probably architecture: "metal-3" (vendor is apple). everything else appears blank.

How about providing a WebGPU EP session option, something like

```
ep.webgpuexecutionprovider.experimentalEnableAggressiveFlashAttention = "1"
```

Benefits:

- Safe to land: off by default, zero regression risk on any device
- Device-agnostic: anyone (Apple, NVIDIA, Intel) can opt in and test
- Data-driven follow-up: once we collect enough benchmarks across devices, a future PR can auto-enable it for known-good architectures

Note: I see this as a temporary stepping stone, not a permanent solution. For follow-up work:

1. Root-cause the regressions on unexpected devices. Having the opt-in flag makes it easy to A/B test on affected machines. We can do deeper analysis on why this regresses (register spilling? workgroup size mismatch? shared memory pressure?). There may also still be room to further optimize the current shader.

2. File a Dawn bug for richer GPU info. Currently Dawn's AdapterInfo only exposes architecture: "metal-3" for Apple, which isn't sufficient to distinguish M4 Max from other variants. It would be helpful to ask for more detailed GPU identification. Once Dawn exposes that information, we can follow up with a PR to automatically enable the optimization on the appropriate devices and eventually deprecate the manual opt-in.

What do you think? @xenova @guschmue
