4 changes: 2 additions & 2 deletions qa/L3_pytorch_FA_versions_test/test.sh
@@ -18,10 +18,10 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri
export FLASH_ATTN_CUDA_ARCHS=$sm_arch
if [ $sm_arch -gt 90 ]
then
FA_versions=(2.8.3 4.0.0b8)
FA_versions=(2.8.3 4.0.0b11)
elif [ $sm_arch -eq 90 ]
then
FA_versions=(2.7.3 2.8.3 3.0.0b1 4.0.0b8)
FA_versions=(2.7.3 2.8.3 3.0.0b1 4.0.0b11)
fi

for fa_version in "${FA_versions[@]}"
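For reference, a minimal sketch (Python) of the arch-to-version mapping the updated script encodes; the sm_arch encoding mirrors the inline python3 -c snippet above, and the version lists are copied from the arrays in this diff:

import torch

# Encode the (major, minor) compute capability as an integer, e.g. (9, 0) -> 90,
# mirroring the inline python3 -c used by the script.
major, minor = torch.cuda.get_device_capability(0)
sm_arch = major * 10 + minor

# Version lists as updated in this diff: the FA4 beta is bumped from 4.0.0b8 to 4.0.0b11.
if sm_arch > 90:
    fa_versions = ["2.8.3", "4.0.0b11"]
elif sm_arch == 90:
    fa_versions = ["2.7.3", "2.8.3", "3.0.0b1", "4.0.0b11"]
else:
    fa_versions = []  # the script only defines version lists for sm90 and newer

print(f"sm{sm_arch}: flash-attn versions under test: {fa_versions}")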
32 changes: 26 additions & 6 deletions tests/pytorch/attention/test_attention.py
@@ -384,12 +384,36 @@ def test_dpa_num_splits(dtype, model_configs, model):
@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_base])
@pytest.mark.parametrize("model", model_configs_fa4_base.keys())
def test_dpa_fa4_base(dtype, model_configs, model):
"""Test DotProductAttention with FA4: base configs, extended head dims, GQA, num_splits"""
"""Test DotProductAttention with FA4: base configs, GQA, num_splits"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


# head_dim=256 is supported only on SM100 via FA4's dedicated kernel
# (flash_attn/cute/sm100_hd256_2cta_fmha_*.py), available in flash-attn-4 > 4.0.0b10.
# On other architectures, _validate_head_dims rejects (256, 256), FA4 is disabled, and
# the test would silently fall back to another backend — defeating the purpose. Gate
# explicitly so the CI signal is unambiguous.
model_configs_fa4_hdim256 = {
"fa4_hdim256": ModelConfig(2, 1024, 8, 256, attn_mask_type="causal"),
}


@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(
get_device_compute_capability() != (10, 0),
reason="FA4 head_dim=256 dedicated kernel is SM100-only.",
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_hdim256])
@pytest.mark.parametrize("model", model_configs_fa4_hdim256.keys())
def test_dpa_fa4_hdim256(dtype, model_configs, model):
"""Test DotProductAttention with FA4: head_dim=256 dedicated kernel on SM100"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


@@ -409,7 +433,6 @@ def test_dpa_fa4_base(dtype, model_configs, model):
@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mla])
@pytest.mark.parametrize("model", model_configs_fa4_mla.keys())
@@ -436,7 +459,6 @@ def test_dpa_fa4_mla(dtype, model_configs, model):
@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_swa])
@pytest.mark.parametrize("model", model_configs_fa4_swa.keys())
@@ -460,7 +482,6 @@ def test_dpa_fa4_sliding_window(dtype, model_configs, model, qkv_layout):
@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_varlen])
@pytest.mark.parametrize("model", model_configs_fa4_varlen.keys())
@@ -486,7 +507,6 @@ def test_dpa_fa4_varlen(dtype, model_configs, model, qkv_layout):
@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mask])
@pytest.mark.parametrize("model", model_configs_fa4_mask.keys())
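The SM100-only gating added above can be reproduced outside pytest by asking FA4's validator whether the (256, 256) shape is accepted on the current device. A sketch, assuming the import path and call shape shown in the backends.py and utils.py hunks below (the signature of _validate_head_dims is inferred from those hunks, not from FA4 documentation):

import torch
# Private flash-attn-4 helper; import path as used in the backends.py hunk below.
from flash_attn.cute.interface import _validate_head_dims

major, _minor = torch.cuda.get_device_capability(0)
# 16-byte alignment expressed in elements of the QKV dtype (8 for bf16/fp16).
alignment = 16 // torch.empty(0, dtype=torch.bfloat16).element_size()

try:
    # Call shape mirrors utils.py: (head_dim_qk, head_dim_v, sm_major, alignment).
    _validate_head_dims(256, 256, major, alignment)
    print("FA4 accepts head_dim=256 here; test_dpa_fa4_hdim256 exercises the dedicated kernel")
except AssertionError:
    print(f"FA4 rejects (256, 256) on sm{major}x; the test is skipped and FA4 would be disabled")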
transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -167,8 +167,10 @@
from flash_attn.cute.interface import ( # pylint: disable=ungrouped-imports,no-name-in-module
flash_attn_func as flash_attn_func_v4,
flash_attn_varlen_func as flash_attn_varlen_func_v4,
_validate_head_dims as _fa4_validate_head_dims,
)

fa_utils.v4_validate_head_dims = _fa4_validate_head_dims
fa_utils.set_flash_attention_4_params()

# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
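A sketch of the contract this hunk relies on, assuming backends.py guards its FA4 imports the way it guards its other optional flash-attn imports (the guard itself sits outside the visible hunk): the hook stays at its None default unless flash-attn-4 imports cleanly, which is why utils.py checks it for None before calling it. The stand-alone form below is hypothetical; in TE both pieces live in backends.py, which refers to FlashAttentionUtils through its local fa_utils name.

from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils

try:
    from flash_attn.cute.interface import _validate_head_dims as _fa4_validate_head_dims
except ImportError:  # flash-attn-4 not installed: leave the hook at its None default
    _fa4_validate_head_dims = None

if _fa4_validate_head_dims is not None:
    FlashAttentionUtils.v4_validate_head_dims = _fa4_validate_head_dims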
56 changes: 32 additions & 24 deletions transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -7,7 +7,7 @@
"""
import math
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
import logging
import functools
@@ -147,8 +147,11 @@ class FlashAttentionUtils:
fa4_version = PkgVersion("0")
use_v4 = False
v4_installation_steps = """\
pip install flash-attn-4==4.0.0b8 nvidia-cutlass-dsl[cu13]"""
pip install flash-attn-4==4.0.0b11 nvidia-cutlass-dsl[cu13]"""
v4_warning_printed = False
# Set by backends.py when flash-attn-4 is installed; points to
# flash_attn.cute.interface._validate_head_dims, which raises AssertionError for
# unsupported (head_dim, head_dim_v) combinations.
v4_validate_head_dims: Optional[Callable] = None

@staticmethod
def set_flash_attention_version():
@@ -792,21 +795,25 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
)
use_flash_attention_3 = False

if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:
# FA4 head dimension support is architecture-dependent
# (matches _validate_head_dims in flash_attn.cute.interface):
# SM90: head_dim <= 256 and head_dim_v <= 256
# SM100/110: head_dim <= 128 and head_dim_v <= 128,
# OR DeepSeek MLA shape (head_dim=192, head_dim_v=128)
# SM80/120: constrained by shared memory (~256 max in practice)
_fa4_hdim_ok = True
if (10, 0) <= device_compute_capability < (12, 0):
_is_standard = head_dim_qk <= 128 and head_dim_v <= 128
_is_deepseek = head_dim_qk == 192 and head_dim_v == 128
_fa4_hdim_ok = _is_standard or _is_deepseek
else:
_fa4_hdim_ok = head_dim_qk <= 256 and head_dim_v <= 256
if not _fa4_hdim_ok:
if (
use_flash_attention_4
and FlashAttentionUtils.v4_is_installed
and FlashAttentionUtils.v4_validate_head_dims is not None
):
# Defer to FA4's own _validate_head_dims to keep TE in sync with FA4 supported shapes
# (e.g., (256, 256) on SM100, (192, 128) DeepSeek, (64, 512) MLA-absorbed).
# The function asserts on unsupported combinations; SM80/SM120 have no validation branch
# in FA4 so the call passes through silently for those archs.
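# 16-byte alignment expressed in elements of qkv_dtype (8 for fp16/bf16, 16 for fp8).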
_fa4_alignment = 16 // torch.empty(0, dtype=qkv_dtype).element_size()
try:
# pylint: disable-next=not-callable
FlashAttentionUtils.v4_validate_head_dims(
head_dim_qk,
head_dim_v,
device_compute_capability[0],
_fa4_alignment,
)
except AssertionError:
logger.debug(
"Disabling FlashAttention 4 due to unsupported head dimensions. "
"Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
@@ -815,13 +822,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
device_compute_capability[0] * 10 + device_compute_capability[1],
)
use_flash_attention_4 = False
# Workaround: SM100 backward kernel bug when MLA + 2CTA (head_dim_qk >= 128).
# FlashAttentionBackwardSm100 computes dK_reduce_ncol = gcd(32, tile_hdim // 2)
# based on Q/K head_dim but reuses it for dV TMEM load atoms. When
# (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are misaligned.
# See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
elif (
_fa4_hdim_ok
# Workaround for an SM100 backward-kernel bug hit when MLA shapes take the 2-CTA path
# (head_dim_qk >= 128) through the standard (non-dedicated) kernel. FlashAttentionBackwardSm100
# computes dK_reduce_ncol = gcd(32, tile_hdim // 2) from the Q/K head_dim but reuses it for
# dV TMEM load atoms, so when (tile_hdimv // 2) % dK_reduce_ncol != 0 the dV reads are
# misaligned. The dedicated (256, 256) kernel uses its own TMEM layout and is not affected.
# See: flash_attn/cute/flash_bwd_sm100.py, around lines 262 and 3890.
if (
use_flash_attention_4
and is_training
and head_dim_qk != head_dim_v
and head_dim_qk >= 128
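The misalignment condition described in the workaround comment can be checked with a few lines of arithmetic. The helper below is illustrative only (it is not code from flash-attn), uses the tile_hdim/tile_hdimv names from that comment, and the (192, 96) shape is a hypothetical example chosen purely to trip the formula:

from math import gcd

def dv_load_is_misaligned(tile_hdim: int, tile_hdimv: int) -> bool:
    """Illustrative check for the SM100 backward-kernel condition described above."""
    # Reduction width derived from the Q/K head dim ...
    dk_reduce_ncol = gcd(32, tile_hdim // 2)
    # ... but reused for dV TMEM load atoms, which are sized by the V head dim.
    return (tile_hdimv // 2) % dk_reduce_ncol != 0

# Equal head dims can never trip the condition: gcd(32, h // 2) always divides h // 2.
assert not dv_load_is_misaligned(128, 128)
# A hypothetical uneven shape does: gcd(32, 96) = 32, but (96 // 2) % 32 == 16.
assert dv_load_is_misaligned(192, 96)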