Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_an
static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight.

#if 1
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t MPerBlock = 64;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
Expand Down Expand Up @@ -156,7 +156,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, true>;
#else

static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
Expand All @@ -171,7 +172,8 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, false>;
#endif
// clang-format on

Expand All @@ -182,12 +184,14 @@ int main(int argc, char* argv[])
bool time_kernel = true;
#if 1
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 6144;
ck::index_t N = 1536;
ck::index_t K = 4096;
// ck::index_t N = 4096;
// ck::index_t K = 6144;
// ck::index_t N = 128;
// ck::index_t K = 512;
ck::index_t experts = 8;
ck::index_t topk = 2;
ck::index_t experts = 16;
ck::index_t topk = 8;
// ck::index_t sorted_tile_num = 515;
// ck::index_t valid_tile_num = 512;
// ck::index_t tokens = 208;
Expand All @@ -196,9 +200,9 @@ int main(int argc, char* argv[])
// ck::index_t sorted_tile_num = 259;
// ck::index_t valid_tile_num = 256;
// ck::index_t tokens = 4096;
ck::index_t sorted_tile_num = 2;
ck::index_t valid_tile_num = 2;
ck::index_t tokens = 32;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 16;
ck::index_t tokens = 4;
#else
// deepseek
ck::index_t N = 2048;
Expand All @@ -209,7 +213,7 @@ int main(int argc, char* argv[])
ck::index_t sorted_tile_num = 261;
ck::index_t valid_tile_num = 256;
#endif
ck::index_t KBatch = 6;
ck::index_t KBatch = 1;
if(argc == 1)
{
// use default case
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ template <typename ALayout,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB>
typename LDSTypeB = ComputeTypeB,
bool NonTemporalLoadB = false>
struct DeviceMoeGemmBlockScale
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
BLayout,
Expand Down Expand Up @@ -163,7 +164,8 @@ struct DeviceMoeGemmBlockScale
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>;
LDSTypeB,
NonTemporalLoadB>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ template <typename ALayout,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
typename LDSTypeB = BDataType>
typename LDSTypeB = BDataType,
bool NonTemporalLoadB = false>
struct GridwiseMoeGemmBlockScale
{
using AScaleType = float;
Expand Down Expand Up @@ -1202,6 +1203,13 @@ struct GridwiseMoeGemmBlockScale
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
#if defined(__gfx942__) || defined(__gfx950__)
constexpr auto b_coherence_flag = NonTemporalLoadB
? AmdBufferCoherenceEnum::WAVE_NT1
: AmdBufferCoherenceEnum::DefaultCoherence;
#else
constexpr auto b_coherence_flag = AmdBufferCoherenceEnum::DefaultCoherence;
#endif
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
Expand Down Expand Up @@ -1300,15 +1308,16 @@ struct GridwiseMoeGemmBlockScale

const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());

const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid + expert_id * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const auto b_scale_grid_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_scale_grid + expert_id * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());

// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
Expand Down Expand Up @@ -1465,9 +1474,11 @@ struct GridwiseMoeGemmBlockScale
if constexpr(IsInputGemm && !IsSplitK)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
const auto b_grid_buf_up =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid_up +
expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
Expand All @@ -1485,9 +1496,10 @@ struct GridwiseMoeGemmBlockScale
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const auto b_scale_grid_buf_up =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
Expand Down Expand Up @@ -1958,6 +1970,13 @@ struct GridwiseMoeGemmBlockScale
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
#if defined(__gfx942__) || defined(__gfx950__)
constexpr auto b_coherence_flag = NonTemporalLoadB
? AmdBufferCoherenceEnum::WAVE_NT1
: AmdBufferCoherenceEnum::DefaultCoherence;
#else
constexpr auto b_coherence_flag = AmdBufferCoherenceEnum::DefaultCoherence;
#endif
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
Expand Down Expand Up @@ -2054,15 +2073,16 @@ struct GridwiseMoeGemmBlockScale

const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());

const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid + expert_id * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const auto b_scale_grid_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_scale_grid + expert_id * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());

// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
Expand Down Expand Up @@ -2227,9 +2247,11 @@ struct GridwiseMoeGemmBlockScale
if constexpr(IsInputGemm && !IsSplitK)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
const auto b_grid_buf_up =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid_up +
expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
Expand All @@ -2247,9 +2269,10 @@ struct GridwiseMoeGemmBlockScale
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const auto b_scale_grid_buf_up =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
Expand Down