91 changes: 11 additions & 80 deletions include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
@@ -194,83 +194,6 @@ struct BlockUniversalGemmAsBsCr
{
};

template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
{
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};

using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));

ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;

// C += A * B
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");

load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;

a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;

b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
};

template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
{
@@ -450,7 +373,9 @@ struct BlockUniversalGemmAsBsCr
// hot loop:
static_for<0, KRepeat, 1>{}([&](auto kIter) {
LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(
0); // Complete scheduling all pending instruction groups before this point

// NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, except the first, as we can shorten the non-MAC cluster a bit
// and there's no observable negative impact. The desired effect is waves in
@@ -460,8 +385,14 @@
// sync point.
if constexpr(kIter.value != 0 || KRepeat == 1)
{
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
// This pattern ensures:
// At runtime: All waves synchronize (hardware barrier)
// At compile-time: Instructions after the barrier don't get moved before it
// (scheduling barrier)
__builtin_amdgcn_s_barrier(); // Blocks execution until all waves (threads) in
// the workgroup reach this point
__builtin_amdgcn_sched_barrier(
0); // Prevents instruction reordering across this boundary
}

static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
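The barrier pairing described by the comments above can be shown in isolation. A minimal standalone sketch, assuming a ROCm/HIP toolchain where the __builtin_amdgcn_* intrinsics are available; the kernel name and loop body are hypothetical and not part of this change:

#include <hip/hip_runtime.h>

// Hypothetical kernel: each iteration prefetches data, then waits for the whole
// workgroup before issuing its MAC cluster, mirroring the pattern in the hot loop.
__global__ void mac_cluster_sketch(const float* a, const float* b, float* c, int k_tiles)
{
    float acc = 0.f;
    for(int k = 0; k < k_tiles; ++k)
    {
        // ... local prefetch of the next A/B tile would go here ...

        __builtin_amdgcn_s_barrier();      // runtime: every wave in the workgroup reaches this point
        __builtin_amdgcn_sched_barrier(0); // compile time: nothing below is scheduled above the barrier

        // ... MAC cluster for tile k ...
        acc += a[k] * b[k];
    }
    c[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}

A mask of 0 is the strictest setting: no instructions may be moved across the scheduling barrier in either direction.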
5 changes: 1 addition & 4 deletions include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
@@ -1035,7 +1035,6 @@ struct UniversalGemmKernel
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
template <bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
const std::array<const BDataType*, NumBTensor>& bs_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
@@ -1161,9 +1160,7 @@ struct UniversalGemmKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];

constexpr auto scheduler_type =
GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1);
RunGemm<scheduler_type>(
RunGemm(
as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}

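For the LDS handling visible in this hunk, a self-contained sketch of the same pattern: a raw __shared__ buffer sized by a constexpr helper in the kernel entry and handed to device-side code as an opaque pointer. All names here are hypothetical stand-ins, not the kernel's real API:

#include <hip/hip_runtime.h>

__device__ constexpr unsigned GetSmemSizeSketch() { return 4 * 1024; }

// Stand-in for RunGemm: receives the LDS block as an untyped pointer and is
// free to carve it up into per-pipeline tiles internally.
__device__ void run_gemm_sketch(void* p_smem)
{
    (void)p_smem; // ... A/B LDS tile windows would be built on top of p_smem ...
}

__global__ void kernel_entry_sketch()
{
    // Allocate LDS once per workgroup, sized at compile time.
    __shared__ char smem[GetSmemSizeSketch()];
    run_gemm_sketch(smem);
}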
@@ -80,7 +80,7 @@ struct GemmPipelineProblemBase
static constexpr bool kPadK = Traits::kPadK;

static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;

// In the base situation, the Preshuffle setting should be false.
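The Scheduler constant above is used by the block GEMM to pick a BlockGemmImpl specialization at compile time. A minimal, self-contained sketch of that selection mechanism, with simplified hypothetical names rather than the real ck_tile types:

#include <type_traits>

// Simplified stand-ins for the real scheduler enum and block-GEMM implementations.
enum class SchedulerSketch { Default, Intrawave, Interwave };

template <SchedulerSketch S>
struct BlockGemmImplSketch; // primary template: only specializations are defined

template <>
struct BlockGemmImplSketch<SchedulerSketch::Intrawave>
{
    // ... intrawave-scheduled MAC clusters would live here ...
};

// A "problem" type exposes the scheduler as a constexpr member, like
// GemmPipelineProblemBase::Scheduler above.
struct ProblemSketch
{
    static constexpr SchedulerSketch Scheduler = SchedulerSketch::Intrawave;
};

// The block GEMM instantiates whichever specialization the problem selects.
using SelectedImpl = BlockGemmImplSketch<ProblemSketch::Scheduler>;
static_assert(std::is_same_v<SelectedImpl, BlockGemmImplSketch<SchedulerSketch::Intrawave>>);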
@@ -164,7 +164,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
};

template <>
struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;

@@ -491,7 +491,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
void* p_smem,
index_t m = 0) const
{
return PipelineImpl<GemmPipelineScheduler::Interwave>{}
return PipelineImpl<GemmPipelineScheduler::Intrawave>{}
.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const BDataType& a) { return a; },
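The ".template operator()<HasHotLoop, TailNum>(...)" spelling in this hunk is how C++ calls a member operator() that is itself a template when explicit template arguments must be supplied on an object of dependent type. A small self-contained sketch of the construct, with hypothetical names:

// A callable whose operator() is itself a template, like PipelineImpl above.
struct PipelineImplSketch
{
    template <bool HasHotLoop, int TailNum>
    int operator()(int x) const
    {
        return HasHotLoop ? x + TailNum : x;
    }
};

template <typename Impl>
int run_sketch(const Impl& impl, int x)
{
    // "impl" has a dependent type inside this function template, so the
    // ".template" disambiguator is required before the explicit template
    // arguments; this mirrors the call in the diff above.
    return impl.template operator()<true, 2>(x);
}

int main() { return run_sketch(PipelineImplSketch{}, 1) == 3 ? 0 : 1; }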