4 changes: 2 additions & 2 deletions qa/L3_pytorch_FA_versions_test/test.sh
@@ -18,10 +18,10 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri
 export FLASH_ATTN_CUDA_ARCHS=$sm_arch
 if [ $sm_arch -gt 90 ]
 then
-FA_versions=(2.8.1)
+FA_versions=(2.8.3)
 elif [ $sm_arch -eq 90 ]
 then
-FA_versions=(2.7.3 2.8.1 3.0.0b1)
+FA_versions=(2.7.3 2.8.3 3.0.0b1)

Collaborator: Do we need to modify this file, as it was for NV upstream QA?

Collaborator Author: We don't really need it, but it is part of the cherry-picked commit.

 fi

 for fa_version in "${FA_versions[@]}"
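The hunk header above truncates the sm_arch probe used by this QA script. A minimal Python sketch of what such a compute-capability check usually looks like follows; the exact print expression in test.sh is cut off in the diff, so the major*10 + minor encoding shown here (e.g. 90 for SM 9.0) is an assumption.

```python
# Hypothetical sketch of the sm_arch probe; the real one-liner in test.sh is
# truncated in the hunk header above, so the major*10 + minor encoding is an
# assumption (e.g. 90 for compute capability 9.0).
import torch

major, minor = torch.cuda.get_device_capability(0)  # requires a visible GPU
sm_arch = major * 10 + minor
print(sm_arch)
```

The script then branches on that number: -gt 90 pins only FA 2.8.3, while -eq 90 (Hopper) also tests 2.7.3 and 3.0.0b1.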
8 changes: 7 additions & 1 deletion tests/pytorch/attention/test_kv_cache.py
@@ -375,6 +375,12 @@ def get_tols(config, module, backend, dtype):
             torch.half: (5e-3, 5e-3),
             torch.bfloat16: (3.5e-2, 3.5e-2),
         }
+    # With FA on ROCm it may not fit default tolerance

Collaborator: Is this relaxation of tolerance due to upgrading to 2.8.3?

Collaborator Author: Yes. Even though I did see some numerical errors on the ROCm 7.2 image with FA 2.8.0 too, 2.8.3 showed a higher delta. It should be due to the new CK, I suppose; I didn't try the Triton backend.

+    if IS_HIP_EXTENSION and backend == "FlashAttention":
+        tols = {
+            torch.half: (5e-3, 5e-3),
+            torch.bfloat16: (4e-2, 4e-2),
+        }
     else:
         if backend == "UnfusedAttention":
             tols = {
@@ -389,7 +395,7 @@ def get_tols(config, module, backend, dtype):
     # With FA on ROCm it may not fit default tolerance
     if IS_HIP_EXTENSION and backend == "FlashAttention":
         tols = {
-            torch.half: (1e-2, 1e-2),
+            torch.half: (1.2e-2, 1.2e-2),
             torch.bfloat16: (1e-1, 1e-1),
         }
     if module == "DotProductAttention":
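For context on the tolerance thread above: get_tols returns per-dtype pairs that the test uses when comparing a backend's output against a reference. A minimal sketch of how such pairs are typically consumed follows; the (atol, rtol) ordering and the assert_close call are assumptions, since the consuming code is not part of this diff.

```python
# Hypothetical usage sketch for the dtype-dependent tolerances relaxed in this
# PR. The (atol, rtol) ordering is an assumption; the actual test may differ.
import torch

tols = {
    torch.half: (1.2e-2, 1.2e-2),
    torch.bfloat16: (1e-1, 1e-1),
}

def check(output: torch.Tensor, reference: torch.Tensor) -> None:
    atol, rtol = tols[output.dtype]
    torch.testing.assert_close(output, reference, atol=atol, rtol=rtol)
```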
@@ -1,5 +1,5 @@
 # This file was modified for portability to AMDGPU
-# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -109,7 +109,7 @@ class FlashAttentionUtils:
     version = PkgVersion("0")
     version_required = PkgVersion("2.1.1")
     version_required_blackwell = PkgVersion("2.7.3")
-    max_version = PkgVersion("2.8.1")
+    max_version = PkgVersion("2.8.3")
     v2_plus = False
     v2_1_plus = False
     v2_3_plus = False
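The max_version bump above is what lets the PyTorch extension accept flash-attn 2.8.3 at import time. A rough sketch of how a PkgVersion-based gate like this typically works follows; it assumes PkgVersion is packaging.version.Version and that the installed distribution is named flash-attn, neither of which is shown in this diff.

```python
# Hypothetical sketch of a flash-attn version gate; assumes PkgVersion is
# packaging.version.Version and the distribution is named "flash-attn".
from importlib.metadata import PackageNotFoundError, version as installed_version
from packaging.version import Version as PkgVersion

version_required = PkgVersion("2.1.1")
max_version = PkgVersion("2.8.3")

def flash_attn_is_usable() -> bool:
    try:
        fa_version = PkgVersion(installed_version("flash-attn"))
    except PackageNotFoundError:
        return False
    return version_required <= fa_version <= max_version

print(flash_attn_is_usable())
```

With the cap at 2.8.3, this runtime check stays in sync with the 2.8.3 pin added to the QA matrix in test.sh above.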