Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,11 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
}
}

const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype();
const auto b_dtype = convertNVTETensorCheck(B_use[0])->dtype();
const auto& A0_data = use_a_colwise_data ? A0_te->columnwise_data : A0_te->data;
const auto& B0_data = use_b_colwise_data ? B0_te->columnwise_data : B0_te->data;

const auto a_dtype = A0_data.dtype;
const auto b_dtype = B0_data.dtype;

Tensor* D0_te = convertNVTETensorCheck(D[0]);
const auto d_dtype = D0_te->dtype();
Expand Down Expand Up @@ -156,6 +159,7 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
B_use,
D,
static_cast<int>(n),
static_cast<int>(kA),
group_num,
transA_use,
transB_use,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ struct GroupedGemmRunContext {
const NVTETensor* B = nullptr;
NVTETensor* D = nullptr;
int64_t N = 0;
int64_t K = 0;

int group_num = 0;
bool transA = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class GroupedGemmRunner : public RunnerInterface {
}
};

#define MAKE_RUNNER(TileCfg_) \
#define MAKE_FP16_RUNNER(TileCfg_) \
TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \
using Runner = GroupedGemmRunner<AType, \
BType, \
Expand Down Expand Up @@ -231,11 +231,11 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
using CType = typename TETypeToCKType<d_te_type>::type;

if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64);
MAKE_FP16_RUNNER(TileCfg_256x256x64);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64);
MAKE_FP16_RUNNER(TileCfg_256x128x64);
} else {
MAKE_RUNNER(TileCfg_256x128x64_padding);
MAKE_FP16_RUNNER(TileCfg_256x128x64_padding);
}
});
});
Expand All @@ -249,7 +249,7 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
return runner->run(s, ctx);
}

#undef MAKE_RUNNER
#undef MAKE_FP16_RUNNER

} // namespace grouped_gemm
} // namespace transformer_engine
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ enum class GPUArch {
UNKNOWN
};

struct TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
struct TileCfg_256x256x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;

static constexpr ck_tile::index_t M_Warp = 2;
Expand All @@ -45,13 +45,41 @@ struct TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t TilePartitionerM01 = 8;
};

// 128x128x128 tile configuration (the type FP8TileCfg<GPUArch::GFX950> maps to).
// Inherits every setting from the 256x256x128 configuration and shadows only
// the M and N tile extents with 128; K_Tile and all other members come from
// the base unchanged.
// NOTE(review): any base-class constant computed from M_Tile/N_Tile would keep
// the 256 value, since shadowing does not re-evaluate it — confirm none exists.
struct TileCfg_128x128x128_16x16x128_2x2x1
: TileCfg_256x256x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
};

// 256x256x128 tile configuration with kPadK shadowed to true.
// The gfx950 FP8 dispatcher picks this variant when every group's N is a
// multiple of 256 but at least one group's K is not a multiple of 128.
// NOTE(review): presumably kPadK turns on K-dimension padding in the CK
// kernel — confirm against the base config and runner.
struct TileCfg_256x256x128_16x16x128_2x2x1_kpad
: TileCfg_256x256x128_16x16x128_2x2x1 {
static constexpr bool kPadK = true;
};

// 128x128x128 tile configuration with kPadK shadowed to true.
// The gfx950 FP8 dispatcher picks this variant when every group's N is a
// multiple of 128 (but not all are multiples of 256) and at least one group's
// K is not a multiple of 128.
// NOTE(review): presumably kPadK turns on K-dimension padding in the CK
// kernel — confirm against the base config and runner.
struct TileCfg_128x128x128_16x16x128_2x2x1_kpad
: TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr bool kPadK = true;
};

// 128x128x128 tile configuration with kPadN shadowed to true.
// The gfx950 FP8 dispatcher picks this variant when some group's N is not a
// multiple of 128 while every group's K is a multiple of 128.
// NOTE(review): presumably kPadN turns on N-dimension padding in the CK
// kernel — confirm against the base config and runner.
struct TileCfg_128x128x128_16x16x128_2x2x1_npad
: TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr bool kPadN = true;
};

// 128x128x128 tile configuration with both kPadN and kPadK shadowed to true.
// This is the gfx950 FP8 dispatcher's last-resort variant: it is chosen when
// some group's N is not a multiple of 128 and some group's K is not a multiple
// of 128.
// NOTE(review): presumably these flags turn on N- and K-dimension padding in
// the CK kernel — confirm against the base config and runner.
struct TileCfg_128x128x128_16x16x128_2x2x1_nkpad
: TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
};

// gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile
// configuration due to an unsupported warp GEMM dispatcher configuration.
// See: ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp for supported variants.
//
// To preserve the existing type name in shared template code, this struct
// inherits from the gfx950-safe 16x16x128 configuration in the gfx950 device
// compilation path, effectively reusing those parameters without redefining them.
// inherits from the gfx950-safe 128x128x128 16x16x128 configuration in the
// gfx950 device compilation path, effectively reusing those parameters without
// redefining them.
//
// In all other compilation paths, the struct overrides the relevant fields to
// provide the intended 32x32x16 configuration.
Expand Down Expand Up @@ -261,7 +289,9 @@ class QuantGroupedGemmRunner : public RunnerInterface {
if (descs.empty()) {
return false;
}
return launch_grouped_gemm_kernel<Kernel>(descs, ctx, stream_cfg);

const bool launched = launch_grouped_gemm_kernel<Kernel>(descs, ctx, stream_cfg);
return launched;
}
};

Expand Down Expand Up @@ -290,6 +320,78 @@ struct FP8TileCfg<GPUArch::GFX950> {
using type = TileCfg_128x128x128_16x16x128_2x2x1;
};

// Summary of shape alignment across all groups of an FP8 grouped GEMM.
// Every flag starts out true and is cleared by the scanner as soon as one
// group violates the corresponding divisibility requirement.
struct FP8GroupedShapeAlignment {
  // True while every inspected group's N is a multiple of 256.
  bool all_n_256_aligned{true};
  // True while every inspected group's N is a multiple of 128.
  bool all_n_128_aligned{true};
  // True while every inspected group's K is a multiple of 128.
  bool all_k_128_aligned{true};
};

// Inspects each of the ctx.group_num groups and reports whether all groups'
// N extents are multiples of 256 / 128 and all K extents are multiples of 128.
//
// For every group the 2D dims of A and B are read either from the tensor's
// columnwise_data (when the matching ctx.use_*_columnwise_data flag is set) or
// from the flattened tensor itself. K is then taken from A and N from B, with
// ctx.transA / ctx.transB choosing which of the two dims is used.
//
// Raises via NVTE_ERROR when a group's dims cannot be obtained. With zero
// groups the returned flags are all still true.
static FP8GroupedShapeAlignment get_fp8_grouped_shape_alignment(
    const GroupedGemmRunContext& ctx) {
  FP8GroupedShapeAlignment summary;

  for (int group = 0; group < ctx.group_num; ++group) {
    const transformer_engine::Tensor* const a_tensor =
        transformer_engine::convertNVTETensorCheck(ctx.A[group]);
    const transformer_engine::Tensor* const b_tensor =
        transformer_engine::convertNVTETensorCheck(ctx.B[group]);

    int64_t a_dim0 = 0;
    int64_t a_dim1 = 0;
    int64_t b_dim0 = 0;
    int64_t b_dim1 = 0;

    // Operand A: pick the storage that the context says is in use.
    if (!ctx.use_a_columnwise_data) {
      if (!get_flat_2d_dims(*a_tensor, a_dim0, a_dim1)) {
        NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized A in group ", group);
      }
    } else {
      if (!get_columnwise_storage_2d_dims(a_tensor->columnwise_data, a_dim0, a_dim1)) {
        NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for A in group ", group);
      }
    }

    // Operand B: same selection logic as for A.
    if (!ctx.use_b_columnwise_data) {
      if (!get_flat_2d_dims(*b_tensor, b_dim0, b_dim1)) {
        NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized B in group ", group);
      }
    } else {
      if (!get_columnwise_storage_2d_dims(b_tensor->columnwise_data, b_dim0, b_dim1)) {
        NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for B in group ", group);
      }
    }

    const int64_t k_extent = ctx.transA ? a_dim0 : a_dim1;
    const int64_t n_extent = ctx.transB ? b_dim0 : b_dim1;

    // A flag, once cleared, stays cleared for the rest of the scan.
    summary.all_n_256_aligned = summary.all_n_256_aligned && (n_extent % 256 == 0);
    summary.all_n_128_aligned = summary.all_n_128_aligned && (n_extent % 128 == 0);
    summary.all_k_128_aligned = summary.all_k_128_aligned && (k_extent % 128 == 0);

    // Nothing left to disprove: later groups cannot change the answer.
    const bool any_flag_still_set = summary.all_n_256_aligned ||
                                    summary.all_n_128_aligned ||
                                    summary.all_k_128_aligned;
    if (!any_flag_still_set) {
      break;
    }
  }

  return summary;
}

// Instantiates the FP8 grouped-GEMM runner for one tile configuration.
//
// Expands to two statements: a local alias `Runner` naming
// QuantGroupedGemmRunner specialised on TileCfg_ (always with
// ck_tile::memory_operation_enum::set), followed by an assignment of a freshly
// made instance to `runner`.
//
// Requirements at the expansion site:
//   - AType, BType, CType, ALayout, BLayout and CTypeLayout name types in scope;
//   - `runner` is an assignable variable able to hold std::unique_ptr<Runner>;
//   - the expansion has no trailing semicolon, so the caller supplies it;
//   - each use must sit in its own braced scope, because two expansions with
//     different TileCfg_ in one scope would redeclare `Runner` as a different
//     type.
#define MAKE_FP8_RUNNER(TileCfg_) \
using Runner = QuantGroupedGemmRunner<AType, \
BType, \
CType, \
ALayout, \
BLayout, \
CTypeLayout, \
TileCfg_, \
ck_tile::memory_operation_enum::set>; \
runner = std::make_unique<Runner>()

template <GPUArch Arch>
static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype,
DType b_dtype,
Expand All @@ -299,33 +401,55 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype,
std::unique_ptr<RunnerInterface> runner = nullptr;

using CTypeLayout = RowMajor;
using TileCfg = typename FP8TileCfg<Arch>::type;

TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, {
using ALayout = std::conditional_t<kTransA, ColMajor, RowMajor>;

TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, {
using BLayout = std::conditional_t<kTransB, ColMajor, RowMajor>;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, {
using AType = typename TETypeToCKType<a_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, {
using BType = typename TETypeToCKType<b_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;
using Runner = QuantGroupedGemmRunner<AType,
BType,
CType,
ALayout,
BLayout,
CTypeLayout,
TileCfg,
ck_tile::memory_operation_enum::set>;
runner = std::make_unique<Runner>();
});
});

// FP8 grouped GEMM is only compiled for CK's preferred NT presentation:
// transA=false, transB=true
// which maps to:
// ALayout=RowMajor, BLayout=ColMajor.
//
// The caller is responsible for rewriting other FP8 layouts into this form
// using columnwise_data when needed. Reject anything that did not normalize
// successfully so we do not instantiate unreachable/unsupported layout variants.
if (ctx.transA || !ctx.transB) {
return false;
}

using ALayout = RowMajor;
using BLayout = ColMajor;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, {
using AType = typename TETypeToCKType<a_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, {
using BType = typename TETypeToCKType<b_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;

if constexpr (Arch == GPUArch::GFX950) {
const auto alignment = get_fp8_grouped_shape_alignment(ctx);

if (alignment.all_n_256_aligned) {
if (alignment.all_k_128_aligned) {
MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1);
} else {
MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1_kpad);
}
} else if (alignment.all_n_128_aligned) {
if (alignment.all_k_128_aligned) {
MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1);
} else {
MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_kpad);
}
} else if (alignment.all_k_128_aligned) {
MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_npad);
} else {
MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_nkpad);
}
} else {
using TileCfg = typename FP8TileCfg<Arch>::type;
MAKE_FP8_RUNNER(TileCfg);
}
});
});
});
Expand All @@ -334,9 +458,12 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype,
return false;
}

return runner->run(s, ctx);
const bool ok = runner->run(s, ctx);
return ok;
}

#undef MAKE_FP8_RUNNER

bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype,
DType b_dtype,
DType d_dtype,
Expand Down
15 changes: 13 additions & 2 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1160,9 +1160,20 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]);
#ifdef __HIP_PLATFORM_AMD__
auto A_dt = inputA->data.dtype;
auto B_dt = inputB->data.dtype;
// Picks the dtype used for the FP8 check on a tensor: the rowwise data dtype
// when that is already FP8; otherwise the columnwise dtype if columnwise data
// is present and FP8; otherwise the (non-FP8) rowwise data dtype.
auto effective_dtype = [](const transformer_engine::Tensor* t) {
  // Short-circuit order matters: columnwise data is consulted only when the
  // rowwise dtype is not FP8 and columnwise data actually exists.
  const bool use_columnwise_dtype = !is_fp8_dtype(t->data.dtype) &&
                                    t->has_columnwise_data() &&
                                    is_fp8_dtype(t->columnwise_data.dtype);
  return use_columnwise_dtype ? t->columnwise_data.dtype : t->data.dtype;
};

auto A_dt = effective_dtype(inputA);
auto B_dt = effective_dtype(inputB);
auto D_dt = OutputD->data.dtype;

return (
(is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt))
) ||
Expand Down
Loading