21 changes: 18 additions & 3 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
@@ -210,10 +210,12 @@
((0 < args.window_size_left) or (0 < args.window_size_right));
const bool can_dispatch_v3 =
(device_name.compare(0, 6, "gfx950") == 0) and
(traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
(((traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
(traits.qscale_type == quant_scale_enum::no_scale)) or
((traits.data_type.compare("fp8bf16") == 0) and
(traits.qscale_type == quant_scale_enum::pertensor))) and
traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
(not traits.has_lse) and (not traits.has_dropout) and
(traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
(not traits.has_lse) and (not traits.has_dropout) and (not is_swa) and
(args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
if ({F_is_v3_enabled} and can_dispatch_v3) {{
return fmha_fwd_v3(traits, args, config);
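For readability, here is a sketch of the generated can_dispatch_v3 predicate as it should read after this change, reconstructed from the interleaved old/new lines above; every identifier is taken verbatim from the diff:

    // Reconstruction of the generated dispatch condition (sketch, not literal codegen output).
    // fp16/bf16 keep requiring no_scale; the new fp8bf16 path requires per-tensor scales.
    const bool can_dispatch_v3 =
        (device_name.compare(0, 6, "gfx950") == 0) and
        (((traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
          (traits.qscale_type == quant_scale_enum::no_scale)) or
         ((traits.data_type.compare("fp8bf16") == 0) and
          (traits.qscale_type == quant_scale_enum::pertensor))) and
        traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
        (not traits.has_lse) and (not traits.has_dropout) and (not is_swa) and
        (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);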
@@ -1048,6 +1050,10 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
if (128, 128) in result.keys():
result[(128, 128)].append(
FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
elif dtype in cls._DT_FP8BF16:
if (128, 128) in result.keys():
result[(128, 128)].append(
FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1)) # fmt: skip
return result

@classmethod
Expand Down Expand Up @@ -1085,6 +1091,15 @@ def get_pipelines(
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip

elif dtype in cls._DT_FP8BF16:
# no need lse/dropout kernels
# qr_async_trload_v3 only supports (generic) causal mask
for logits, qscale, mask in itertools.product(
["t", "f"],
["no", "pertensor"],
["no", "causal"],
):
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
return pipelines


6 changes: 6 additions & 0 deletions example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -731,6 +731,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.q_descale_ptr,
args.k_descale_ptr,
args.v_descale_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqstart_q_ptr,
@@ -764,6 +767,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.q_descale_ptr,
args.k_descale_ptr,
args.v_descale_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqlen_q,
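Both MakeKargs call sites now forward the three de-scale pointers from fmha_fwd_args. A minimal caller-side sketch of what the host is expected to provide, under the assumption that each de-scale is a single device-resident float (the kernel dereferences exactly one float per pointer); the _dev variable names are illustrative, only the q/k/v_descale_ptr fields come from this patch:

    // Illustrative host-side setup for the fp8bf16 + per-tensor path (not part of this patch).
    fmha_fwd_args args{}; // q/k/v/o pointers, strides, seqlens, ... filled as before
    args.q_descale_ptr = q_descale_dev; // assumed: device pointer to one float, de-scale for Q
    args.k_descale_ptr = k_descale_dev; // assumed: device pointer to one float, de-scale for K
    args.v_descale_ptr = v_descale_dev; // assumed: device pointer to one float, de-scale for V
    // fmha_fwd_v3_create_kargs_and_grids(args) then passes them straight into MakeKargs above.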
111 changes: 97 additions & 14 deletions include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp
@@ -27,6 +27,7 @@ struct FmhaFwdV3Kernel
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using PDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::PDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
@@ -38,6 +39,7 @@ struct FmhaFwdV3Kernel
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;

using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -118,11 +120,21 @@ struct FmhaFwdV3Kernel
float logits_soft_cap_rcp;
};

struct FmhaFwdCommonQScaleKargs
{
const void* q_descale_ptr = nullptr;
const void* k_descale_ptr = nullptr;
const void* v_descale_ptr = nullptr;
};

struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<2>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<3>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -139,7 +151,10 @@ struct FmhaFwdV3Kernel
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<2>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<3>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -166,6 +181,9 @@ struct FmhaFwdV3Kernel
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* q_descale_ptr,
const void* k_descale_ptr,
const void* v_descale_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
@@ -218,6 +236,7 @@ struct FmhaFwdV3Kernel
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for qscale
{}, // placeholder for logits_soft_cap
batch_stride_q,
batch_stride_k,
@@ -237,6 +256,12 @@ struct FmhaFwdV3Kernel
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
@@ -252,6 +277,9 @@ struct FmhaFwdV3Kernel
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* q_descale_ptr,
const void* k_descale_ptr,
const void* v_descale_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
@@ -301,6 +329,7 @@ struct FmhaFwdV3Kernel
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for qscale
{}, // placeholder for logits_soft_cap
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
@@ -319,6 +348,12 @@ struct FmhaFwdV3Kernel
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
@@ -640,32 +675,80 @@ struct FmhaFwdV3Kernel
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();

const float scale_s = [&] {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
return kargs.scale_s * q_descale * k_descale;
}
else
{
return kargs.scale_s;
}
}();

AttentionVariant variant;
const auto variant_params = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
}
else
{
return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
return ck_tile::StandardAttentionParams<FmhaMask>{mask, scale_s};
}
}();

BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};

auto o_acc_tile = [&]() {
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lse_dram_window,
mask,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
float scale_o = v_descale / scale_p;

auto o_acc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
else
return ck_tile::scales{scale_o};
}();

return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{scale_p}, // p_compute_element_func
o_acc_element_func,
mask,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
}
else
{
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lse_dram_window,
mask,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
}
}();

// O DRAM and O DRAM window
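Taken together, the kernel changes above add per-tensor FP8 de-quantization at three points; a condensed sketch of the scaling arithmetic on the PERTENSOR path (names match the kernel code, the composition is just spelled out in one place):

    // 1) logits: S = (softmax_scale * q_descale * k_descale) * (Q_q @ K_q^T)
    float q_descale = *reinterpret_cast<const float*>(kargs.q_descale_ptr);
    float k_descale = *reinterpret_cast<const float*>(kargs.k_descale_ptr);
    float scale_s   = kargs.scale_s * q_descale * k_descale;

    // 2) probabilities: P is re-quantized into PDataType before the P*V gemm, so the
    //    pipeline multiplies it by scale_p = numeric<PDataType>::max() (p_compute_element_func).
    float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());

    // 3) output: the P*V accumulator is divided back by scale_p and de-scaled by v_descale,
    //    i.e. O = (v_descale / scale_p) * (P_q @ V_q); if ODataType is fp8 the result is
    //    additionally saturated (o_acc_element_func).
    float v_descale = *reinterpret_cast<const float*>(kargs.v_descale_ptr);
    float scale_o   = v_descale / scale_p;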
@@ -300,7 +300,6 @@ struct BlockFmhaFwdV3Pipeline
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ;
static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout &&
(QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) &&
!kSkipMinSeqlenQ),
"enable unsupported features");

@@ -437,7 +436,7 @@ struct BlockFmhaFwdV3Pipeline
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");

static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
// static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<kM0, kN0>());
@@ -239,10 +239,18 @@ struct BlockFmhaV3PipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
constexpr auto warp_gemm = [] {
if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
swizzle_factor>{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
/// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
/// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here
7 changes: 7 additions & 0 deletions include/ck_tile/ops/gemm/warp/warp_gemm.hpp
@@ -276,6 +276,13 @@ using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIter
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;

using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
2 changes: 2 additions & 0 deletions include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
@@ -98,6 +98,8 @@ template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true, true> { u
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, true, false, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, true, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<EDouble>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };
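The two new specializations expose the transposed-C fp8 warp gemm through the same dispatcher table; an illustrative lookup, assuming the surrounding namespace of Dispatcher and that EDouble names the double-access value of WGAttrNumAccessEnum as used above (the alias names on the left are hypothetical):

    // Template parameters (per the comment above): ADataType, BDataType, AccDataType,
    // MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity[, AttrNumAccess].
    using WarpGemmFp8CT       = Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, true, false, false>::Type;
    using WarpGemmFp8CTDouble = Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, true, false, false, EDouble>::Type;
    // These resolve to WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<> and <EDouble> respectively.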