@@ -80,17 +80,17 @@
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.graph import is_graph_capturing

# Global vars for flash attn v2 and v3 imports
flash_attn_cuda_bwd = None
flash_attn_func = None
flash_attn_varlen_func = None
_flash_attn_fwd = None
_flash_attn_bwd = None
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None
# Try to import Flash Attention v2
try:
    fa_utils.version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
    flash_attn_cuda_bwd = None
    flash_attn_func = None
    flash_attn_varlen_func = None
    _flash_attn_fwd = None
    _flash_attn_bwd = None
    _flash_attn_varlen_fwd = None
    _flash_attn_varlen_bwd = None
    pass # only print warning if use_flash_attention_2 = True in get_attention_backend
else:
    if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0):
@@ -130,12 +130,16 @@
),
fa_utils.version,
)

# Try to import Flash Attention v3
try:
    fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
except PackageNotFoundError:
    flash_attn_func_v3 = None
    flash_attn_varlen_func_v3 = None
    flash_attn_with_kvcache_v3 = None
    _flash_attn_fwd_v3 = None
    _flash_attn_bwd_v3 = None
    # pass # only print warning if use_flash_attention_3 = True in get_attention_backend
else:
    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
@@ -150,6 +154,25 @@

    fa_utils.set_flash_attention_3_params()

# Try to import Flash Attention v4
try:
    fa_utils.fa4_version = PkgVersion(get_pkg_version("flash-attn-cute"))
except PackageNotFoundError:
    flash_attn_func_v4 = None
    flash_attn_varlen_func_v4 = None
    flash_attn_with_kvcache_v4 = None
    _flash_attn_fwd_v4 = None
    _flash_attn_bwd_v4 = None
    # pass # only print warning if use_flash_attention_4 = True in get_attention_backend
else:
    from flash_attn.cute.interface import flash_attn_func as flash_attn_func_v4
    from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
    from flash_attn.cute.interface import _flash_attn_fwd as _flash_attn_fwd_v4
    from flash_attn.cute.interface import _flash_attn_bwd as _flash_attn_bwd_v4

    # flash_attn_with_kvcache_v4 = None # FA4 does not support kvcache yet
    fa_utils.set_flash_attention_4_params()

# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1"

@@ -859,6 +882,9 @@ def forward(
        use_flash_attn_3 = False
        if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
            use_flash_attn_3 = True
        use_flash_attn_4 = False
        if flash_attention_backend is not None and str(flash_attention_backend).endswith("cute"):
Member Author: FA4 is released under the package name flash-attn-cute with versions starting from 0.1.0, so get_attention_backend appends a ".cute" suffix to the version number to distinguish it from the FA2/FA3 version numbers.
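A minimal sketch of the idea (hypothetical, not the exact code in get_attention_backend; it assumes flash-attn-cute is installed and queries the version via importlib.metadata):

    from importlib.metadata import version as get_pkg_version

    fa4_version = get_pkg_version("flash-attn-cute")   # e.g. "0.1.0" -- restarts below the FA2/FA3 2.x/3.x numbers
    backend_tag = fa4_version + ".cute"                # tag the version so 0.1.0 cannot be mistaken for an old flash-attn
    assert backend_tag.endswith("cute")                # the same check used in forward() above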

            use_flash_attn_4 = True
        if context_parallel and all(
            not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
        ):
@@ -913,9 +939,13 @@ def forward(
                # | | bshd/sbhd/thd + padding
                fa_optional_forward_args_thd = []
                if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
                    func = (
                        flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3
                    ) # pylint: disable=possibly-used-before-assignment
                    func = None
                    if use_flash_attn_4:
                        func = flash_attn_func_v4
                    elif use_flash_attn_3:
                        func = flash_attn_func_v3
                    else:
                        func = flash_attn_func
                else:
                    if not use_flash_attn_3:
                        func = flash_attn_varlen_func
@@ -928,7 +958,24 @@
                    fa_optional_forward_args_thd.append(cu_seqlens_kv)
                    fa_optional_forward_args_thd.append(max_seqlen_q)
                    fa_optional_forward_args_thd.append(max_seqlen_kv)
                if not use_flash_attn_3:
                if use_flash_attn_4:
                    fa_4_optional_forward_kwargs = {
                        # "window_size": window_size,
Member Author: The default window_size = (-1, 0) doesn't mean "no sliding window" for FA4, which is why window_size is not forwarded here.
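A hedged sketch of how the argument could be forwarded later, assuming flash_attn_func_v4 accepts a window_size keyword (as the commented-out line above suggests) and that only an explicitly requested sliding window should reach FA4:

    # hypothetical guard, pending confirmation of FA4's window semantics
    if window_size is not None and window_size not in ((-1, -1), (-1, 0)):
        fa_4_optional_forward_kwargs["window_size"] = window_size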

"num_splits": num_splits,
}
if inference_params is None:
fa_4_optional_forward_kwargs["deterministic"] = self.deterministic
output = func(
query_layer,
key_layer,
value_layer,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_4_optional_forward_kwargs,
)
if isinstance(output, (List, Tuple)):
output = output[0]
elif not use_flash_attn_3:
fa_optional_forward_kwargs = {}
if fa_utils.v2_3_plus:
fa_optional_forward_kwargs["window_size"] = window_size