Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
516 changes: 177 additions & 339 deletions transformer_engine/common/fused_attn/fused_attn.cpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -1333,4 +1333,138 @@ void fused_attn_arbitrary_seqlen_bwd(
NVTE_ERROR("Unexpected workspace_size.");
}
}

namespace {
// Probe-time defaults for runtime-only quantities the router doesn't see (paged-KV dims,
// ragged max-tokens, bias dims). These produce a graph whose support surface matches the
// real executor's: for non-paged / non-ragged paths these are unused inside the impl;
// for ragged-THD we rebind to worst-case bounds; for paged we use 1 page of full s_kv per
// batch (= same dims as non-paged), so cuDNN-FE applies the paged-attention support rules.
struct ProbeDims {
int64_t max_b;
int64_t max_t_q;
int64_t max_t_kv;
int64_t num_pages_k;
int64_t num_pages_v;
int64_t page_size_k;
int64_t page_size_v;
int64_t max_pages_per_seq_k;
int64_t max_pages_per_seq_v;
int64_t bias_b;
int64_t bias_h;
int64_t bias_sq;
int64_t bias_skv;
};

ProbeDims compute_probe_dims(int64_t batch, int64_t num_attn_heads, int64_t max_seqlen_q,
int64_t max_seqlen_kv, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type) {
const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
const bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
const bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
const bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
const bool has_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);

ProbeDims d{};
d.max_b = (is_ragged_q || is_ragged_kv) ? batch : 0;
d.max_t_q = is_ragged_q ? batch * max_seqlen_q : 0;
d.max_t_kv = is_ragged_kv ? batch * max_seqlen_kv : 0;
d.num_pages_k = is_paged_kv ? batch : 0;
d.num_pages_v = is_paged_kv ? batch : 0;
d.page_size_k = is_paged_kv ? max_seqlen_kv : 0;
d.page_size_v = is_paged_kv ? max_seqlen_kv : 0;
d.max_pages_per_seq_k = is_paged_kv ? 1 : 0;
d.max_pages_per_seq_v = is_paged_kv ? 1 : 0;
d.bias_b = has_bias ? batch : 0;
d.bias_h = has_bias ? num_attn_heads : 0;
d.bias_sq = has_bias ? max_seqlen_q : 0;
d.bias_skv = has_bias ? max_seqlen_kv : 0;
return d;
}
} // namespace

// Probe whether cuDNN-FE can compile an F16/BF16 forward fused-attention graph for this
// configuration. Delegates to fused_attn_arbitrary_seqlen_fwd_impl with all device pointers
// null; any cuDNN-FE / NVTE_CHECK failure thrown by the impl is converted into a typed
// GRAPH_NOT_SUPPORTED rejection carrying the underlying message. On success the built graph
// lands in the impl's thread-local cache (see the header comment), so a later real call
// with matching descriptors cache-hits.
cudnn_frontend::error_t is_supported_f16_fwd(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training,
    bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
    int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle) {
  // Synthesize the runtime-only dims (ragged token bounds, paged-KV page dims, bias dims)
  // the router does not see; zeros mean "feature not in use".
  const ProbeDims d =
      compute_probe_dims(static_cast<int64_t>(batch), static_cast<int64_t>(num_attn_heads),
                         static_cast<int64_t>(max_seqlen_q), static_cast<int64_t>(max_seqlen_kv),
                         qkv_layout, bias_type);
  // O inherits Q's format.
  const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout);

  size_t workspace_size = 0;
  try {
    // All device pointers are null and workspace is null; scaling factor is a dummy 1.0f.
    // NOTE(review): assumes the impl takes a build/size-query-only path (no kernel launch)
    // when workspace == nullptr — confirm against fused_attn_arbitrary_seqlen_fwd_impl.
    fused_attn::fused_attn_arbitrary_seqlen_fwd_impl(
        static_cast<int64_t>(batch), static_cast<int64_t>(num_attn_heads),
        static_cast<int64_t>(num_gqa_groups), static_cast<int64_t>(max_seqlen_q),
        static_cast<int64_t>(max_seqlen_kv), static_cast<int64_t>(head_dim_qk),
        static_cast<int64_t>(head_dim_v), d.max_b, d.max_t_q, d.max_t_kv, d.num_pages_k,
        d.num_pages_v, d.page_size_k, d.page_size_v, d.max_pages_per_seq_k, d.max_pages_per_seq_v,
        d.bias_b, d.bias_h, d.bias_sq, d.bias_skv, is_training, return_max_logit,
        /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, mask_type,
        softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
        /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrBias=*/nullptr,
        /*devPtrSoftmaxOffset=*/nullptr, /*devPtrS1=*/nullptr, /*devPtrS2=*/nullptr,
        /*devPtrO=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr,
        /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr,
        /*devPtrPageTableK=*/nullptr, /*devPtrPageTableV=*/nullptr,
        /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype),
        /*workspace=*/nullptr, &workspace_size,
        /*stream=*/static_cast<cudaStream_t>(0), handle);
    // No throw: the graph compiled end-to-end.
    return {cudnn_frontend::error_code_t::OK, ""};
  } catch (const std::exception &e) {
    // cuDNN-FE / NVTE_CHECK failures arrive as exceptions; surface the message to the caller.
    return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()};
  } catch (...) {
    return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED,
            "is_supported_f16_fwd: unknown failure"};
  }
}

// Probe whether cuDNN-FE can compile an F16/BF16 backward fused-attention graph for this
// configuration. Mirrors is_supported_f16_fwd: delegates to
// fused_attn_arbitrary_seqlen_bwd_impl with null device pointers and converts any thrown
// failure into a typed GRAPH_NOT_SUPPORTED rejection.
cudnn_frontend::error_t is_supported_f16_bwd(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
    bool bottom_right_diagonal, bool deterministic, DType q_dtype, cudnnHandle_t handle) {
  // Synthesize runtime-only dims (ragged token bounds, bias dims); zeros mean "unused".
  const ProbeDims d =
      compute_probe_dims(static_cast<int64_t>(batch), static_cast<int64_t>(num_attn_heads),
                         static_cast<int64_t>(max_seqlen_q), static_cast<int64_t>(max_seqlen_kv),
                         qkv_layout, bias_type);
  // O and dO follow Q's format; dQKV uses the same layout as QKV.
  const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout);
  const NVTE_QKV_Format do_format = o_format;
  const NVTE_QKV_Layout dqkv_layout = qkv_layout;

  size_t workspace_size = 0;
  try {
    // All device pointers null, workspace null, dummy scaling factor 1.0f.
    // NOTE(review): assumes the impl only builds the graph / reports workspace size when
    // workspace == nullptr — confirm against fused_attn_arbitrary_seqlen_bwd_impl.
    fused_attn::fused_attn_arbitrary_seqlen_bwd_impl(
        static_cast<int64_t>(batch), static_cast<int64_t>(num_attn_heads),
        static_cast<int64_t>(num_gqa_groups), static_cast<int64_t>(max_seqlen_q),
        static_cast<int64_t>(max_seqlen_kv), static_cast<int64_t>(head_dim_qk),
        static_cast<int64_t>(head_dim_v), d.max_b, d.max_t_q, d.max_t_kv, d.bias_b, d.bias_h,
        d.bias_sq, d.bias_skv, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, do_format,
        dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
        bottom_right_diagonal, deterministic, /*devPtrQ=*/nullptr, /*devPtrKTranspose=*/nullptr,
        /*devPtrVTranspose=*/nullptr, /*devPtrO=*/nullptr, /*devPtrSoftmaxStats=*/nullptr,
        /*devPtrBias=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrdQ=*/nullptr,
        /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, /*devPtrdO=*/nullptr,
        /*devPtrdBias=*/nullptr, /*devPtrdSoftmaxOffset=*/nullptr,
        /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr,
        /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr,
        /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype),
        /*workspace=*/nullptr, &workspace_size,
        /*stream=*/static_cast<cudaStream_t>(0), handle);
    // No throw: the graph compiled end-to-end.
    return {cudnn_frontend::error_code_t::OK, ""};
  } catch (const std::exception &e) {
    // cuDNN-FE / NVTE_CHECK failures arrive as exceptions; surface the message to the caller.
    return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()};
  } catch (...) {
    return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED,
            "is_supported_f16_bwd: unknown failure"};
  }
}

} // namespace transformer_engine
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_

#include <cudnn.h>
#include <cudnn_frontend.h>

#include "common/common.h"
#include "transformer_engine/fused_attn.h"
Expand Down Expand Up @@ -47,6 +48,27 @@ void fused_attn_arbitrary_seqlen_bwd(
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);

// Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans ->
// check_support -> build_plans) for an F16/BF16 forward graph with the given configuration.
// Returns the cuDNN-FE status: error_code_t::OK iff the graph compiles end-to-end. On OK,
// the built graph is inserted into the same thread-local cache used by
// fused_attn_arbitrary_seqlen_fwd_impl, so the executor cache-hits on matching descriptors.
// On rejection, err_msg contains the underlying cuDNN-FE / NVTE_CHECK message.
cudnn_frontend::error_t is_supported_f16_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training,
bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle);

// Probe: same as above for the F16/BF16 backward graph.
cudnn_frontend::error_t is_supported_f16_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, bool deterministic, DType q_dtype, cudnnHandle_t handle);

} // namespace transformer_engine

#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
101 changes: 101 additions & 0 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2991,4 +2991,105 @@ void fused_attn_fp8_bwd(
return;
}
}

// Probe whether cuDNN-FE can compile an FP8 forward fused-attention graph for this
// configuration. Rejects non-BSHD/SBHD/BHSD formats up front (mirroring the impl's
// NVTE_ERROR) and otherwise delegates to fused_attn_fp8_fwd_impl with null device pointers,
// converting any thrown failure into a typed GRAPH_NOT_SUPPORTED rejection.
cudnn_frontend::error_t is_supported_fp8_fwd(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
    bool bottom_right_diagonal, DType q_dtype, DType o_dtype, NVTEScalingMode scaling_mode,
    cudnnHandle_t handle) {
  // FP8 fwd impl rejects any qkv_format other than BSHD/SBHD/BHSD with NVTE_ERROR; mirror that
  // here so the probe returns a typed rejection instead of catching the throw.
  const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
  if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD &&
      qkv_format != NVTE_QKV_Format::NVTE_BHSD) {
    return {cudnn_frontend::error_code_t::INVALID_VALUE,
            "FP8 fused attention only supports BSHD/SBHD/BHSD layouts."};
  }
  size_t workspace_size = 0;
  try {
    // All device pointers null, workspace null, dummy scaling factor 1.0f; O uses the same
    // format as QKV. NOTE(review): assumes the impl only builds the graph / reports workspace
    // size when workspace == nullptr — confirm against fused_attn_fp8_fwd_impl.
    fused_attn::fused_attn_fp8_fwd_impl(
        static_cast<int64_t>(batch), static_cast<int64_t>(num_attn_heads),
        static_cast<int64_t>(num_gqa_groups), static_cast<int64_t>(max_seqlen_q),
        static_cast<int64_t>(max_seqlen_kv), static_cast<int64_t>(head_dim_qk),
        static_cast<int64_t>(head_dim_v), is_training, /*scaling_factor=*/1.0f, p_dropout,
        qkv_layout, /*o_format=*/qkv_format, bias_type, mask_type, softmax_type, window_size_left,
        window_size_right, bottom_right_diagonal,
        /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr,
        /*devPtrSoftmaxOffset=*/nullptr, /*devPtrM=*/nullptr, /*devPtrO=*/nullptr,
        /*devPtrDescaleQ=*/nullptr, /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr,
        /*devPtrDescaleS=*/nullptr, /*devPtrScaleS=*/nullptr, /*devPtrScaleO=*/nullptr,
        /*devPtrAmaxO=*/nullptr, /*devPtrAmaxS=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr,
        /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr,
        /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(q_dtype), get_cudnn_fe_dtype(o_dtype),
        scaling_mode,
        /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET,
        /*workspace=*/nullptr, &workspace_size,
        /*stream=*/static_cast<cudaStream_t>(0), handle);
    // No throw: the graph compiled end-to-end.
    return {cudnn_frontend::error_code_t::OK, ""};
  } catch (const std::exception& e) {
    // cuDNN-FE / NVTE_CHECK failures arrive as exceptions; surface the message to the caller.
    return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()};
  } catch (...) {
    return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED,
            "is_supported_fp8_fwd: unknown failure"};
  }
}

// Probe whether cuDNN-FE can compile an FP8 backward fused-attention graph for this
// configuration. Mirrors is_supported_fp8_fwd: rejects non-BSHD/SBHD/BHSD formats up front,
// then delegates to fused_attn_fp8_bwd_impl with null device pointers, converting any thrown
// failure into a typed GRAPH_NOT_SUPPORTED rejection.
cudnn_frontend::error_t is_supported_fp8_bwd(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
    bool bottom_right_diagonal, bool deterministic, DType q_dtype, DType o_dtype,
    NVTEScalingMode scaling_mode, cudnnHandle_t handle) {
  // Same up-front format check as the fwd probe (the impl NVTE_ERRORs otherwise).
  const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
  if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD &&
      qkv_format != NVTE_QKV_Format::NVTE_BHSD) {
    return {cudnn_frontend::error_code_t::INVALID_VALUE,
            "FP8 fused attention only supports BSHD/SBHD/BHSD layouts."};
  }
  // For FP8 bwd, dO data type matches O data type and dQKV data type matches Q data type
  // (this mirrors the assumption used by callers of fused_attn_fp8_bwd in TE).
  const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(q_dtype);
  const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype);
  const cudnn_frontend::DataType_t do_t = o_t;
  const cudnn_frontend::DataType_t dqkv_t = qkv_t;
  size_t workspace_size = 0;
  try {
    // All device pointers null, workspace null, dummy scaling factor 1.0f; O/dO/dQKV all use
    // the QKV format/layout. NOTE(review): assumes the impl only builds the graph / reports
    // workspace size when workspace == nullptr — confirm against fused_attn_fp8_bwd_impl.
    fused_attn::fused_attn_fp8_bwd_impl(
        static_cast<int64_t>(batch), static_cast<int64_t>(num_attn_heads),
        static_cast<int64_t>(num_gqa_groups), static_cast<int64_t>(max_seqlen_q),
        static_cast<int64_t>(max_seqlen_kv), static_cast<int64_t>(head_dim_qk),
        static_cast<int64_t>(head_dim_v), /*scaling_factor=*/1.0f, p_dropout, qkv_layout,
        /*o_format=*/qkv_format, /*do_format=*/qkv_format, /*dqkv_layout=*/qkv_layout, bias_type,
        mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
        deterministic,
        /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrM=*/nullptr,
        /*devPtrO=*/nullptr, /*devPtrdO=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr,
        /*devPtrdQ=*/nullptr, /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr,
        /*devPtrdSoftmaxOffset=*/nullptr, /*devPtrDescaleQ=*/nullptr,
        /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr, /*devPtrDescaleO=*/nullptr,
        /*devPtrDescaledO=*/nullptr, /*devPtrDescaleS=*/nullptr, /*devPtrDescaledP=*/nullptr,
        /*devPtrScaleS=*/nullptr, /*devPtrScaledP=*/nullptr, /*devPtrScaledQ=*/nullptr,
        /*devPtrScaledK=*/nullptr, /*devPtrScaledV=*/nullptr, /*devPtrAmaxdP=*/nullptr,
        /*devPtrAmaxdQ=*/nullptr, /*devPtrAmaxdK=*/nullptr, /*devPtrAmaxdV=*/nullptr,
        /*devPtrQ_t=*/nullptr, /*devPtrK_t=*/nullptr, /*devPtrdO_f16=*/nullptr,
        /*devPtrdO_t=*/nullptr, /*devPtrDescaleQ_t=*/nullptr, /*devPtrDescaleK_t=*/nullptr,
        /*devPtrDescaledO_t=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr,
        /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr,
        /*devPtrDropoutOffset=*/nullptr, qkv_t, o_t, do_t, dqkv_t, scaling_mode,
        /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET,
        /*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET,
        /*workspace=*/nullptr, &workspace_size,
        /*stream=*/static_cast<cudaStream_t>(0), handle);
    // No throw: the graph compiled end-to-end.
    return {cudnn_frontend::error_code_t::OK, ""};
  } catch (const std::exception& e) {
    // cuDNN-FE / NVTE_CHECK failures arrive as exceptions; surface the message to the caller.
    return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()};
  } catch (...) {
    return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED,
            "is_supported_fp8_bwd: unknown failure"};
  }
}

} // namespace transformer_engine
25 changes: 25 additions & 0 deletions transformer_engine/common/fused_attn/fused_attn_fp8.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
* \brief Functions for fused attention for FP8 with seqlen <= 512
*/

#include <cudnn_frontend.h>

#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"

Expand Down Expand Up @@ -39,4 +41,27 @@ void fused_attn_fp8_bwd(
const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV,
Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

// Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans ->
// check_support -> build_plans) for an FP8 forward graph with the given configuration.
// Returns the cuDNN-FE status: error_code_t::OK iff the graph compiles end-to-end. On OK,
// the built graph is inserted into the same thread-local cache used by fused_attn_fp8_fwd_impl.
// On rejection, err_msg contains the underlying cuDNN-FE / NVTE_CHECK message.
cudnn_frontend::error_t is_supported_fp8_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, DType q_dtype, DType o_dtype, NVTEScalingMode scaling_mode,
cudnnHandle_t handle);

// Probe: same as above for the FP8 backward graph.
cudnn_frontend::error_t is_supported_fp8_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, bool deterministic, DType q_dtype, DType o_dtype,
NVTEScalingMode scaling_mode, cudnnHandle_t handle);
} // namespace transformer_engine
Loading
Loading