[example] Linalg to XeGPU fused attention implementation.#153

Open
charithaintc wants to merge 7 commits into
llvm:mainfrom
charithaintc:xegpu_fused_attention_example

@charithaintc
Contributor

This example demonstrates how to optimize a standard attention kernel written at the linalg level into a fused attention kernel that can run on a GPU.

Main steps involved:

  1. Generate standard attention payload on 4d tensors (batch x head x ctx_len x d_head)
  2. Tile and fuse the outer parallel dims (batch and head)
  3. Vectorize/Bufferize
  4. Use transform extensions to generate the inner tiled reduction loop (until we have a better solution).
  5. Distribute to GPU workgroups.
  6. Set xegpu layouts and lower to binary.
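
As a point of reference for step 1, the payload computes plain softmax attention independently for every (batch, head) pair, which is what makes those two dims safe to tile and fuse as outer parallel loops in step 2. A minimal NumPy sketch of that computation follows; the function name `attention_reference` and the `1/sqrt(d_head)` scale used in the example call are assumptions for illustration and are not taken from this PR.

```python
import numpy as np


def attention_reference(q, k, v, scale):
    """Plain softmax(Q @ K^T * scale) @ V on (batch, head, ctx_len, d_head) arrays.

    Hypothetical helper for illustration only; not part of the PR.
    """
    # Scores for every (batch, head) slice: (batch, head, ctx_len, ctx_len).
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) * scale
    # Subtract the row max before exponentiating for numerical stability.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weighted sum of V rows: (batch, head, ctx_len, d_head).
    return np.matmul(weights, v)


rng = np.random.default_rng(0)
batch, n_head, ctx_len, d_head = 2, 3, 8, 4
q, k, v = (rng.standard_normal((batch, n_head, ctx_len, d_head)) for _ in range(3))

# Assumed scale: the conventional 1/sqrt(d_head).
out = attention_reference(q, k, v, scale=1.0 / np.sqrt(d_head))
```

Because no value crosses a (batch, head) boundary, computing one slice on its own gives the same result as the corresponding slice of the full output.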

Currently this depends on a fix for: #147

Contributor

@tkarna left a comment

Thanks, this is a great addition. Looks good at a high level.

I understand that end-to-end execution requires additional changes in upstream MLIR that are still pending. Please ping us when this is ready for final review.

from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect


class GenerateFusedAttention(
Contributor

nit: Consider changing the name. We use "generate" for methods that generate payload IR from scratch whereas this is a transform that's applied to an existing payload module.

Contributor Author

Makes sense. I renamed it to ReplaceWithFusedAttentionOp; open to other suggestions. Keep in mind that this will be deprecated once we have the upstream solution.

Comment on lines +24 to +25
Computes fused attention:
output = softmax(Q @ K^T / sqrt(n_head)) @ V
Contributor

nit: would be good to mention how the generated version differs from standard flash attention, i.e. what is being fused.

Contributor Author

I renamed this to gpu_attention_payload because at the payload level there is no fusion; it's just standard attention.
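
To make the payload-versus-fused distinction concrete: a fused kernel avoids materializing the full ctx_len x ctx_len score matrix by walking K and V in tiles while keeping a running max, a running sum, and a rescaled accumulator. The sketch below shows that idea in plain Python for a single query row. It uses the usual online-softmax recurrence, which is an assumption about the shape of the generated loop rather than something confirmed in this thread, and the function names and tile size are hypothetical.

```python
import math


def attention_row(q, K, V, scale):
    """Unfused reference: all scores are computed before the softmax."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, V)) / total
            for j in range(len(V[0]))]


def fused_attention_row(q, K, V, scale, tile=2):
    """Same result, computed one K/V tile at a time (online softmax)."""
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), tile):
        k_tile = K[start:start + tile]
        v_tile = V[start:start + tile]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new max.
        # On the first tile this is exp(-inf) == 0.0.
        correction = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(weights)
        acc = [a * correction + sum(w * v[j] for w, v in zip(weights, v_tile))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]


# Small made-up inputs: one query row, ctx_len = 5, d_head = 3, V width = 2.
q = [0.5, -1.0, 2.0]
K = [[1.0, 0.0, 0.5], [0.2, -0.3, 1.5], [-1.0, 2.0, 0.0],
     [0.7, 0.7, -0.7], [3.0, 0.1, 0.2]]
V = [[1.0, 2.0], [0.0, -1.0], [4.0, 0.5], [-2.0, 3.0], [0.3, 0.3]]
scale = 1.0 / math.sqrt(len(q))

reference = attention_row(q, K, V, scale)
fused = fused_attention_row(q, K, V, scale, tile=2)
```

The two functions agree up to floating-point rounding for any tile size, including one that does not divide the context length evenly; only the order of the reduction changes.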

Comment thread lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py Outdated
Comment thread lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py Outdated
Comment on lines +72 to +76
q_load_op = q_load_ops[0]
k_load_op = k_load_ops[0]
v_load_op = v_load_ops[0]
scale_op = scale_ops[0]
output_op = output_ops[0]
Member

nit: could use extra checks to ensure these are the expected ops

Contributor Author

Added more checks.
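
For illustration, one possible shape for such checks is a small helper that rejects empty or ambiguous matches and verifies the op kind before indexing with `[0]`. Everything here is hypothetical: `FakeOp` stands in for a matched payload op, `expect_single` is not a function from this PR, and the op name `vector.transfer_read` is only an example of what a load might look like after vectorization.

```python
from dataclasses import dataclass


@dataclass
class FakeOp:
    """Stand-in for a matched payload op; only the `name` field is used."""
    name: str


def expect_single(ops, expected_name, role):
    """Return the only op in `ops`, or raise if the match is not what we expect."""
    if len(ops) != 1:
        raise ValueError(f"expected exactly one {role} op, found {len(ops)}")
    op = ops[0]
    if op.name != expected_name:
        raise ValueError(f"{role} op should be '{expected_name}', got '{op.name}'")
    return op


# Hypothetical match result for the Q load.
q_load_ops = [FakeOp("vector.transfer_read")]
q_load_op = expect_single(q_load_ops, "vector.transfer_read", role="Q load")
```

Failing early with a message that names the role tends to be easier to debug than an `IndexError` or a mismatched op surfacing later in the rewrite.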


3 participants