26 changes: 18 additions & 8 deletions tests/pytorch/attention/test_attention.py
@@ -1817,11 +1817,19 @@ def get_model(dtype, config):
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
@pytest.mark.parametrize("deterministic", [True, False])
def test_mha_fp8_vs_f16(
dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
dtype,
model,
qkv_format,
input_layernorm,
fp8_dpa_bwd,
RoPE,
is_training,
scaling_mode,
deterministic,
):
"""Test MultiHeadAttention module in FP8"""
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]

@@ -1850,7 +1858,7 @@ def test_mha_fp8_vs_f16(
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
deterministic=_deterministic,
deterministic=deterministic,
)
flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported_fp8 < 1:
@@ -1862,7 +1870,7 @@ def test_mha_fp8_vs_f16(
qkv_dtype=dtype,
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
deterministic=_deterministic,
deterministic=deterministic,
)
_, fused_attn_supported_f16, _ = available_backends
if not fused_attn_supported_f16:
@@ -2063,7 +2071,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
@pytest.mark.parametrize("deterministic", [True, False])
def test_dpa_fp8_vs_f16(
dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode, deterministic
):
"""Test DotProductAttention module in FP8"""
config = model_configs_fp8_vs_f16[model]

@@ -2078,7 +2089,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
# config.dropout_p = 0.1

os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"

# Test backend availability
@@ -2104,7 +2114,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
deterministic=_deterministic,
deterministic=deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported < 1:
@@ -2115,7 +2125,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
deterministic=_deterministic,
deterministic=deterministic,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
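The test hunks above swap the hard-coded `NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` override for a `deterministic` pytest parameter that is forwarded to the backend-availability query. A minimal, self-contained sketch of that pattern; the `query_available_backends` helper below is a stub for illustration, not the Transformer Engine API:

```python
import pytest


def query_available_backends(fp8: bool, is_training: bool, deterministic: bool) -> bool:
    """Stub standing in for the real backend-availability query used in the tests above."""
    # Assume, for illustration only, that a deterministic FP8 backward pass may be unsupported.
    return not (fp8 and is_training and deterministic)


@pytest.mark.parametrize("deterministic", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
def test_fp8_vs_f16_sketch(is_training, deterministic):
    # Determinism is an explicit parameter instead of an environment override,
    # so both code paths run and unsupported combinations are skipped, not hidden.
    if not query_available_backends(fp8=True, is_training=is_training, deterministic=deterministic):
        pytest.skip("deterministic FP8 backward unsupported in this stub")
    assert isinstance(deterministic, bool)
```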
22 changes: 11 additions & 11 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -770,10 +770,10 @@ void nvte_fused_attn_bwd_qkvpacked(
Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride);

fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO,
input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view,
input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream,
handle);
bias_type, attn_mask_type, deterministic, &Q_view, &K_view, &V_view, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view,
&dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace,
stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
@@ -1087,10 +1087,10 @@ void nvte_fused_attn_bwd_kvpacked(
Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride);

fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view,
&dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace,
stream, handle);
qkv_layout, bias_type, attn_mask_type, deterministic, input_Q, &K_view,
&V_view, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP,
output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
@@ -1323,9 +1323,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv,
qkv_layout, bias_type, attn_mask_type, deterministic, input_Q, input_K,
input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP,
output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
41 changes: 23 additions & 18 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1978,13 +1978,13 @@ void fused_attn_fp8_fwd_impl_v1(
void fused_attn_fp8_bwd_impl_v1(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor,
float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM,
void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV,
void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO,
void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS,
void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV,
void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV,
void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed,
NVTE_Mask_Type mask_type, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK,
void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV,
void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP,
void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK,
void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK,
void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed,
void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type,
cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type,
cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size,
@@ -1999,6 +1999,7 @@ void fused_attn_fp8_bwd_impl_v1(
bool is_dropout = (dropout_probability != 0.0f);
auto bias_b = b;
auto bias_h = h;
const auto cudnn_runtime_version = cudnnGetVersion();
NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF ||
@@ -2037,7 +2038,7 @@
0,
0,
true,
false,
deterministic,
qkv_tensor_type,
o_tensor_type,
do_tensor_type,
@@ -2209,6 +2210,10 @@
// }
// }

if (cudnn_runtime_version >= 91900) {
sdpa_backward_options.set_deterministic_algorithm(deterministic);
}
Comment on lines +2213 to +2215
logic: The version check uses 91900 (cuDNN 9.19.0), but the related PR #2584 and the description mention a 9.18.1+ requirement. Should this be 91810 instead?

Suggested change:
-    if (cudnn_runtime_version >= 91900) {
+    if (cudnn_runtime_version >= 91810) {
       sdpa_backward_options.set_deterministic_algorithm(deterministic);
     }

Is there a specific reason FP8 requires cuDNN 9.19.0+ while FP16/BF16 only needs 9.18.1+?


if (is_padding) {
seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_q")
@@ -2512,11 +2517,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q,
const Tensor* input_K, const Tensor* input_V, const Tensor* input_O,
const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv,
const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ,
const Tensor* output_dK, const Tensor* output_dV,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic,
const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V,
const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M,
const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP,
const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV,
const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv,
const Tensor* rng_state, Tensor* workspace, cudaStream_t stream,
cudnnHandle_t handle) {
@@ -2574,11 +2579,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK,
devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
p_dropout, qkv_layout, bias_type, mask_type, deterministic, devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ,
devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS,
devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
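For reference on the review thread above: `cudnnGetVersion()` returns the runtime version as a single integer, and per the comment 91900 maps to cuDNN 9.19.0 while 91810 maps to 9.18.1, i.e. an assumed major*10000 + minor*100 + patch encoding. A small sketch of that mapping under that assumption:

```python
def encode_cudnn_version(major: int, minor: int, patch: int) -> int:
    """Encode a cuDNN version the way the runtime gate in the hunk above compares it (assumed)."""
    return major * 10000 + minor * 100 + patch


# Gate currently in the diff: cudnn_runtime_version >= 91900 (cuDNN 9.19.0).
# Threshold suggested in the review: 91810 (cuDNN 9.18.1).
assert encode_cudnn_version(9, 19, 0) == 91900
assert encode_cudnn_version(9, 18, 1) == 91810
# As written, a 9.18.1 runtime would not reach set_deterministic_algorithm(deterministic).
assert encode_cudnn_version(9, 18, 1) < 91900
```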
10 changes: 5 additions & 5 deletions transformer_engine/common/fused_attn/fused_attn_fp8.h
@@ -28,11 +28,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv,
const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ,
const Tensor *output_dK, const Tensor *output_dV,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M,
const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP,
const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
@@ -1064,8 +1064,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
)
use_fused_attention = False
fused_attention_backend = None
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
logger.debug("Disabling FusedAttention for determinism reasons with FP8")
if (
fused_attention_backend == FusedAttnBackend["FP8"]
and is_training
and device_compute_capability < (10, 0)
):
logger.debug(
"Disabling FusedAttention for determinism reasons with FP8 on arch < sm100"
)
use_fused_attention = False
fused_attention_backend = None
if (
Expand Down