[PyTorch] Enable head dim 256 for FA4 #2932
Conversation
Force-pushed from bdcc02e to 3b3f7d0
Greptile Summary

This PR enables head_dim=256 support in FlashAttention 4 by delegating head-dimension validation to FA4's own `_validate_head_dims` helper.
Confidence Score: 4/5

Safe to merge once the import in transformer_engine/pytorch/attention/dot_product_attention/backends.py is guarded: if the installed FA4 build does not export `_validate_head_dims`, the resulting ImportError is unhandled and breaks loading of backends.py (see the flowchart below).

Important Files Changed
transformer_engine/pytorch/attention/dot_product_attention/backends.py
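To illustrate the concern, a guarded import along the following lines would keep backends.py loadable even when the installed FA4 build predates `_validate_head_dims`. This is a minimal sketch, not the actual Transformer Engine code: the names mirror the flowchart below, and the FA4 module path is an assumption.

```python
# Minimal sketch of a guarded FA4 import, assuming FA4 exposes
# flash_attn_func, flash_attn_varlen_func, and _validate_head_dims.
# The module path "flash_attn.cute.interface" is an assumption.
try:
    from flash_attn.cute.interface import (
        flash_attn_func as flash_attn_func_v4,
        flash_attn_varlen_func as flash_attn_varlen_func_v4,
        _validate_head_dims as _fa4_validate_head_dims,
    )

    v4_validate_head_dims = _fa4_validate_head_dims
except ImportError:
    # Covers both "FA4 not installed" and "installed FA4 does not export
    # _validate_head_dims", so loading backends.py never fails here.
    flash_attn_func_v4 = None
    flash_attn_varlen_func_v4 = None
    v4_validate_head_dims = None
```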
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[backends.py module load] --> B{FA4 package installed?}
    B -- No --> C[flash_attn_func_v4 = None]
    B -- Yes --> D[import flash_attn_func, flash_attn_varlen_func,\n_validate_head_dims]
    D -- ImportError if symbol missing --> E[Unhandled ImportError breaks backends.py load]
    D -- OK --> F[v4_validate_head_dims = _fa4_validate_head_dims]
    F --> G[get_attention_backend called]
    G --> H{use_flash_attention_4 and v4_validate_head_dims != None?}
    H -- No --> I[Skip FA4 head-dim validation]
    H -- Yes --> J[Call v4_validate_head_dims]
    J -- AssertionError --> K[use_flash_attention_4 = False]
    J -- OK --> L{SM100 MLA workaround needed?}
    L -- Yes misaligned --> M[use_flash_attention_4 = False]
    L -- No --> N[FA4 selected]
```
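The selection-time step of the flowchart (nodes H, J, K) could look roughly like the sketch below. The helper name and the arguments passed to `v4_validate_head_dims` are assumptions; only the control flow (delegate to FA4's validator, disable FA4 on AssertionError) comes from the flowchart.

```python
# Hypothetical helper illustrating the H -> J -> K path of the flowchart.
def _apply_fa4_head_dim_check(
    use_flash_attention_4: bool, head_dim_qk: int, head_dim_v: int
) -> bool:
    """Return the possibly-downgraded use_flash_attention_4 flag."""
    if not use_flash_attention_4 or v4_validate_head_dims is None:
        # FA4 was not requested, or its validator is unavailable:
        # skip FA4-specific head-dim validation (node I).
        return use_flash_attention_4
    try:
        # Delegate the supported-head-dim decision (now including 256) to FA4;
        # the argument list is an assumption, and the validator is expected to
        # raise AssertionError on unsupported head dims.
        v4_validate_head_dims(head_dim_qk, head_dim_v)
    except AssertionError:
        # FA4 rejects these head dims; fall back to other backends (node K).
        return False
    return use_flash_attention_4
```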
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3
@vcherepanov-nv @KshitijLakhani Please review.
Description
Requires FA4 version 4.0.0b11.
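For illustration only, a minimal way to check that an installed FA4 meets this requirement, using `packaging` for PEP 440 comparison; how the installed version string is obtained is left out, since the FA4 distribution name is not given here.

```python
from packaging.version import Version

def fa4_version_ok(installed: str, required: str = "4.0.0b11") -> bool:
    # "installed" would come from the FA4 package metadata; the lookup is
    # omitted because the distribution name is not stated in this PR.
    return Version(installed) >= Version(required)

# e.g. fa4_version_ok("4.0.0b11") -> True, fa4_version_ok("4.0.0b10") -> False
```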
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: