Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions example/ck_tile/03_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
}
else
{
std::cout << "invoke_gemm non-persistent" << std::endl;
ave_time = Invoker::template gemm<GemmConfig,
ADataType,
BDataType,
Expand Down Expand Up @@ -219,12 +220,16 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
std::cout << __func__ << std::endl;
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;

ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");

std::cout << "M: " << M << " N: " << N << " K: " << K << std::endl;


ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
Expand Down Expand Up @@ -339,6 +344,9 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();

// Debug output: print the names of the A, B and C tensor layouts.
std::cout << "a_layout: " << a_layout.name << " b_layout: " << b_layout.name << " c_layout: " << c_layout.name << std::endl;

float ave_time = invoke_gemm<GemmConfig,
Invoker,
ADataType,
Expand Down
1 change: 1 addition & 0 deletions example/ck_tile/03_gemm/run_gemm_example_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ int run_gemm_example_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
{
std::cout << __func__ << std::endl;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
bool preshuffle = GemmConfig::Preshuffle;
Expand Down
5 changes: 4 additions & 1 deletion example/ck_tile/03_gemm/universal_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
template <template <typename PrecType> typename GemmConfig>
int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
//std::cout << __func__ << std::endl;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
Expand All @@ -23,6 +24,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)

if(data_type == "fp16")
{
std::cout << "fp16" << std::endl;
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, Invoker, ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
Expand Down Expand Up @@ -109,6 +111,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)

int main(int argc, char* argv[])
{
//std::cout << __func__ << std::endl;
auto arg_parser = create_args();
auto result = arg_parser.parse(argc, argv);

Expand All @@ -120,7 +123,7 @@ int main(int argc, char* argv[])
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
#else
return !run_gemm_example<GemmConfigComputeV3_2>(arg_parser);
return !run_gemm_example<GemmConfigMemoryInterwave>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,11 @@ struct BlockUniversalGemmAsBsCr
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
// hot loop:
printf("GemmTraits::KIterPerWarp: %d\n", GemmTraits::KIterPerWarp);
printf("MIterPerWarp: %d\n", MIterPerWarp);
printf("NIterPerWarp: %d\n", NIterPerWarp);
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
printf("kIter: %d\n", kIter);
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
Expand Down Expand Up @@ -312,6 +316,13 @@ struct BlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("operator()(CBlockTensor& c_block_tensor, const ASmemBlockWindow&, const BSmemBlockWindow&, bool_constant<ALoadTranspose> = {}, bool_constant<BLoadTranspose> = {})\n");
printf("KIterPerWarp: %d\n", KIterPerWarp);
printf("MIterPerWarp: %d\n", MIterPerWarp);
printf("NIterPerWarp: %d\n", NIterPerWarp);
}
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
Expand Down Expand Up @@ -443,6 +454,15 @@ struct BlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("Interwave");
printf("operator()(CBlockTensor& c_block_tensor, const ASmemBlockWindow&, const BSmemBlockWindow&, bool_constant<ALoadTranspose> = {}, bool_constant<BLoadTranspose> = {})\n");
printf("KRepeat: %d\n", KRepeat);
printf("KInnerLoopIter: %d\n", KInnerLoopIter);
printf("MIterPerWarp: %d\n", MIterPerWarp);
printf("NIterPerWarp: %d\n", NIterPerWarp);
}
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
Expand Down
8 changes: 8 additions & 0 deletions include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,10 @@ struct UniversalGemmKernel
const index_t block_idx_m,
const index_t block_idx_n)
{
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("RunGemm\n");
// }
// Create block windows using specialized methods
const auto& as_block_window =
MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
Expand Down Expand Up @@ -1130,6 +1134,10 @@ struct UniversalGemmKernel
template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
{
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("Non-persistent kernel entry point\n");
// }
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
index_t num_loop,
void* p_smem) const
{
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem)\n");
printf("num_loop: %d\n", num_loop);
printf("PrefetchStages: %d\n", PrefetchStages);
printf("HasHotLoop: %d\n", HasHotLoop);
}
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
Expand Down Expand Up @@ -286,6 +293,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");

// Debug output: print the A and B block window lengths (block 0, thread 0 only).
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("A block window lengths: %d %d\n", ADramBlockWindowTmp{}.get_window_lengths()[I0{}].value, ADramBlockWindowTmp{}.get_window_lengths()[I1{}].value);
printf("B block window lengths: %d %d\n", BDramBlockWindowTmp{}.get_window_lengths()[I0{}].value, BDramBlockWindowTmp{}.get_window_lengths()[I1{}].value);
}

// ------------------------------------------------------------------------------------
// Definitions of all needed tiles

Expand Down Expand Up @@ -508,6 +522,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
};

if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("TailNum = %d\n", static_cast<int>(num_loop % PrefetchStages));
}
if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
Expand Down Expand Up @@ -963,6 +981,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
index_t num_loop,
void* p_smem) const
{
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem)\n");
// }
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
a_element_func,
ck_tile::make_tuple(b_dram_block_window_tmp),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
};

template <>
struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;

Expand Down Expand Up @@ -491,7 +491,11 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
void* p_smem,
index_t m = 0) const
{
return PipelineImpl<GemmPipelineScheduler::Interwave>{}
if(get_thread_id() == 0 && get_block_id() == 0)
{
printf("scheduler: %s\n", Scheduler == GemmPipelineScheduler::Interwave ? "interwave" : "intrawave");
}
return PipelineImpl<GemmPipelineScheduler::Intrawave>{}
.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const BDataType& a) { return a; },
Expand Down
Loading