[example] Linalg to XeGPU fused attention implementation.#153
[example] Linalg to XeGPU fused attention implementation.#153charithaintc wants to merge 7 commits into
Conversation
tkarna
left a comment
There was a problem hiding this comment.
Thanks, this is a great addition. Looks good on high level.
I've understood that the end-to-end execution requires additional changes in upstream MLIR that are still pending. Please ping us when this is ready for final review.
| from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect | ||
|
|
||
|
|
||
| class GenerateFusedAttention( |
There was a problem hiding this comment.
nit: Consider changing the name. We use "generate" for methods that generate payload IR from scratch whereas this is a transform that's applied to an existing payload module.
There was a problem hiding this comment.
makes sense. I renamed it to ReplaceWithFusedAttentionOp. open to other suggestions. keep in mind that this will be deprecated when we have the upstream solution.
| Computes fused attention: | ||
| output = softmax(Q @ K^T / sqrt(n_head)) @ V |
There was a problem hiding this comment.
nit: would be good to mention how the generated version differs from standard flash attention, i.e. what is being fused.
There was a problem hiding this comment.
I renamed this to gpu_attention_payload because at payload level there is no fusion. it's just standard attention.
| q_load_op = q_load_ops[0] | ||
| k_load_op = k_load_ops[0] | ||
| v_load_op = v_load_ops[0] | ||
| scale_op = scale_ops[0] | ||
| output_op = output_ops[0] |
There was a problem hiding this comment.
nit: could use extra checks to ensure these are the expected ops
There was a problem hiding this comment.
added more checks.
This example demonstrate how to optimize standard attention kerel written in
linalglevel into the fused attention kernel that can be run gpu.Main steps involved:
Currently this depends on a fix for : #147