3 changes: 2 additions & 1 deletion example/ck_tile/03_gemm/gemm_utils.hpp
@@ -456,7 +456,8 @@ inline auto create_args()
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "gemm.json", "json file name to dump results")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1000", "rotating count, defaults to 1000");
.insert("rotating_count", "1000", "rotating count, defaults to 1000")
.insert("test_async", "0", "0: normal gemm, 1: test async input scheduler");
return arg_parser;
}
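With this flag in place, the async scheduler path can be toggled at run time by passing --test_async=1 to the example binary (the exact binary name depends on the build target for this directory, so treat any name as illustrative).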

49 changes: 49 additions & 0 deletions example/ck_tile/03_gemm/run_gemm_example.inc
@@ -339,6 +339,55 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();

// Check for async input scheduler test mode
bool test_async = arg_parser.get_int("test_async");
if(test_async)
{
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
stride_C};

Invoker::template test_async_input_scheduler<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough>(
args, ck_tile::stream_config{nullptr, false, 1});

// Copy result from device for verification
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

// Compute CPU reference
ck_tile::HostTensor<CDataType> c_m_n_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_ref);

// Verify results
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
bool pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");

std::cout << "Async input scheduler test: " << (pass ? "PASS" : "FAIL") << std::endl;
return pass;
}

float ave_time = invoke_gemm<GemmConfig,
Invoker,
ADataType,
172 changes: 172 additions & 0 deletions example/ck_tile/03_gemm/universal_gemm_invoker.hpp
@@ -2,7 +2,11 @@
// SPDX-License-Identifier: MIT
#pragma once
#include <functional>
#include <chrono>
#include <thread>
#include "gemm_utils.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/device_memory.hpp"

struct UniversalInvoker
{
@@ -150,4 +154,172 @@ struct UniversalInvoker
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename CDEElementWise>
static void test_async_input_scheduler(const ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;

using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;

using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
true, // Persistent = true for async test
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;

constexpr auto scheduler = GemmConfig::Scheduler;

using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;

using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;

using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
GemmConfig::NumWaveGroups,
false,
Review comment (Contributor): Could you add a comment explaining what the false and the 1 arguments mean here?

1,
false,
1,
GemmConfig::DoubleSmemBuffer>>;

using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

// Calculate number of M tiles and chunks
const ck_tile::index_t tiles_m =
(args.M + TilePartitioner::MPerBlock - 1) / TilePartitioner::MPerBlock;
Review comment (Contributor): We could use integer_divide_ceil here.

const ck_tile::index_t tiles_per_chunk = 2; // 2 tiles per chunk
Review comment (Contributor): Could I know why we are using 2 tiles per chunk?

// Pivot: first chunk of tiles (tiles 0 to tiles_per_chunk-1) are immediately available
// Only tiles from tile_idx_pivot_m onwards need to wait for signals
const ck_tile::index_t tile_idx_pivot = tiles_per_chunk; // Skip first chunk
Review comment (Contributor): The pivot point should be (m + tile_idx_pivot_m) % tiles_m.

const ck_tile::index_t tiles_needing_signals =
(tiles_m > tile_idx_pivot) ? (tiles_m - tile_idx_pivot) : 0;
const ck_tile::index_t num_chunks =
(tiles_needing_signals + tiles_per_chunk - 1) / tiles_per_chunk;
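// Worked example (illustrative numbers, not from this PR): M = 1024 with
// MPerBlock = 128 gives tiles_m = 8; with tiles_per_chunk = 2 and
// tile_idx_pivot = 2 this yields tiles_needing_signals = 8 - 2 = 6 and
// num_chunks = ceil(6 / 2) = 3.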

std::cout << "Async Input Scheduler Test:" << std::endl;
std::cout << " M tiles: " << tiles_m << std::endl;
std::cout << " Tiles per chunk: " << tiles_per_chunk << std::endl;
std::cout << " Tile index pivot: " << tile_idx_pivot << " (first " << tile_idx_pivot
<< " tiles don't wait)" << std::endl;
std::cout << " Tiles needing signals: " << tiles_needing_signals << std::endl;
std::cout << " Number of signal chunks: " << num_chunks << std::endl;

// Allocate chunk signals using ck_tile::DeviceMem (initialized to zero)
// Only need signals for chunks beyond the pivot
ck_tile::DeviceMem signal_buf(num_chunks * sizeof(uint32_t));
signal_buf.SetZero();
uint32_t* d_chunk_signals = static_cast<uint32_t*>(signal_buf.GetDeviceBuffer());

// Setup async input scheduler
ck_tile::PersistentAsyncInputScheduler async_scheduler;
async_scheduler.tiles_per_chunk_m = tiles_per_chunk;
async_scheduler.chunk_signals = d_chunk_signals;
async_scheduler.tile_idx_pivot_m = tile_idx_pivot;

// Create modified host args with async scheduler
ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args({args.a_ptr},
{args.b_ptr},
{},
args.e_ptr,
args.k_batch,
args.M,
args.N,
args.K,
{args.stride_A},
{args.stride_B},
{},
args.stride_E,
async_scheduler);

auto kargs = Kernel::UniversalGemmKernel::MakeKernelArgs(host_args);

const dim3 grids = Kernel::MaxOccupancyGridSize(s);
const dim3 blocks = Kernel::BlockSize();

std::cout << " Grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< std::endl;
std::cout << " Blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;

// Create a separate stream for setting signals
// (using the same stream would deadlock - memcpy waits for kernel, kernel waits for signal)
hipStream_t signal_stream;
HIP_CHECK_ERROR(hipStreamCreateWithFlags(&signal_stream, hipStreamNonBlocking));

const auto start = std::chrono::high_resolution_clock::now();

ck_tile::launch_kernel(
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

// Set signals with interleaved sleep using a separate stream
const int sleep_us = 100;
for(ck_tile::index_t i = 0; i < num_chunks; ++i)
{
std::this_thread::sleep_for(std::chrono::microseconds(sleep_us));
const uint32_t signal_val = 1;
HIP_CHECK_ERROR(hipMemcpyAsync(d_chunk_signals + i,
&signal_val,
sizeof(uint32_t),
hipMemcpyHostToDevice,
signal_stream));
}
HIP_CHECK_ERROR(hipStreamSynchronize(signal_stream));
HIP_CHECK_ERROR(hipStreamDestroy(signal_stream));
Review comment (Contributor): Should we add a line of HIP_CHECK_ERROR(hipDeviceSynchronize()); here?
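// Hedged sketch of the review suggestion above (not part of the PR as written):
// a full device sync at this point would also guarantee the compute kernel
// itself has finished before the timing below is read.
// HIP_CHECK_ERROR(hipDeviceSynchronize());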


auto duration = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::high_resolution_clock::now() - start);

std::cout << " Total time: " << duration.count() << " us" << std::endl;
std::cout << " Sleep time: " << (num_chunks * sleep_us) << " us" << std::endl;
}
};
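Side note on the integer_divide_ceil suggestion in the review thread: a minimal sketch of how the two hand-rolled ceil divisions in test_async_input_scheduler could be rewritten, assuming the helper takes (numerator, denominator) and rounds the quotient up:

```cpp
// Sketch only: equivalent to the (a + b - 1) / b expressions in the function above.
const ck_tile::index_t tiles_m =
    ck_tile::integer_divide_ceil(args.M, TilePartitioner::MPerBlock);
const ck_tile::index_t num_chunks =
    ck_tile::integer_divide_ceil(tiles_needing_signals, tiles_per_chunk);
```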
1 change: 1 addition & 0 deletions include/ck_tile/core.hpp
@@ -91,6 +91,7 @@
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/print.hpp"
#include "ck_tile/core/utility/random.hpp"
20 changes: 20 additions & 0 deletions include/ck_tile/core/utility/persistent_async_input_scheduler.hpp (new file)
@@ -0,0 +1,20 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <cstdint>

namespace ck_tile {

// the fields may be set on the client side
struct PersistentAsyncInputScheduler
{
uint32_t tiles_per_chunk_m = 0;

uint32_t* chunk_signals = nullptr;

int32_t tile_idx_pivot_m = 0;
};

} // namespace ck_tile
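For reference, a minimal host-side wiring sketch for this struct, condensed from the example invoker above (the function name and d_signals are illustrative; the signal buffer must be zero-initialized device memory holding one uint32_t per chunk):

```cpp
#include <cstdint>
#include "ck_tile/core.hpp"

// Sketch: build a scheduler that releases tiles in chunks of two, with the
// first two M-tiles available immediately (they sit before the pivot).
inline ck_tile::PersistentAsyncInputScheduler make_async_scheduler(uint32_t* d_signals)
{
    ck_tile::PersistentAsyncInputScheduler sched;
    sched.tiles_per_chunk_m = 2;         // tiles covered by each ready-signal
    sched.chunk_signals     = d_signals; // device array, one flag per chunk
    sched.tile_idx_pivot_m  = 2;         // tiles before this index never wait
    return sched; // passed to UniversalGemmHostArgs as the new trailing argument
}
```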
@@ -129,7 +129,8 @@ struct StreamKKernel
host_args.stride_Bs,
host_args.stride_Ds,
host_args.stride_E,
-            host_args.k_batch},
+            host_args.k_batch,
+            host_args.async_input_scheduler},
// The workspace pointer is set to nullptr because we must first
// instantiate the TilePartitioner to get the necessary size
workspace_ptr{nullptr},
56 changes: 42 additions & 14 deletions include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
@@ -13,6 +13,8 @@
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
#include "ck_tile/core/arch/workgroup_barrier.hpp"

namespace ck_tile {

@@ -30,18 +32,20 @@ namespace ck_tile {
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct UniversalGemmHostArgs
{
-    CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
-                                       const std::array<const void*, NumBTensor>& bs_ptr_,
-                                       const std::array<const void*, NumDTensor>& ds_ptr_,
-                                       void* e_ptr_,
-                                       index_t k_batch_,
-                                       index_t M_,
-                                       index_t N_,
-                                       index_t K_,
-                                       const std::array<index_t, NumATensor>& stride_As_,
-                                       const std::array<index_t, NumBTensor>& stride_Bs_,
-                                       const std::array<index_t, NumDTensor>& stride_Ds_,
-                                       index_t stride_E_)
+    CK_TILE_HOST UniversalGemmHostArgs(
+        const std::array<const void*, NumATensor>& as_ptr_,
+        const std::array<const void*, NumBTensor>& bs_ptr_,
+        const std::array<const void*, NumDTensor>& ds_ptr_,
+        void* e_ptr_,
+        index_t k_batch_,
+        index_t M_,
+        index_t N_,
+        index_t K_,
+        const std::array<index_t, NumATensor>& stride_As_,
+        const std::array<index_t, NumBTensor>& stride_Bs_,
+        const std::array<index_t, NumDTensor>& stride_Ds_,
+        index_t stride_E_,
+        PersistentAsyncInputScheduler async_input_scheduler_ = PersistentAsyncInputScheduler{})
: as_ptr(as_ptr_),
bs_ptr(bs_ptr_),
ds_ptr(ds_ptr_),
@@ -53,7 +57,8 @@ struct UniversalGemmHostArgs
stride_Bs(stride_Bs_),
stride_Ds(stride_Ds_),
stride_E(stride_E_),
-          k_batch(k_batch_)
+          k_batch(k_batch_),
+          async_input_scheduler(async_input_scheduler_)
{
}

@@ -78,6 +83,7 @@ struct UniversalGemmHostArgs
};

index_t k_batch;
PersistentAsyncInputScheduler async_input_scheduler;
};

/// @brief The GEMM kernel device arguments.
@@ -111,6 +117,8 @@ struct UniversalGemmKernelArgs
/// (in memory) of E tensor.
index_t stride_E;
index_t k_batch;
/// @brief Persistent async input scheduler for chunk-based tile scheduling.
PersistentAsyncInputScheduler async_input_scheduler = {};
};

/// @brief The Universal GEMM kernel template.
@@ -315,7 +323,8 @@ struct UniversalGemmKernel
hostArgs.stride_Bs,
hostArgs.stride_Ds,
hostArgs.stride_E,
-                    hostArgs.k_batch};
+                    hostArgs.k_batch,
+                    hostArgs.async_input_scheduler};
}

CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -1183,6 +1192,25 @@ struct UniversalGemmKernel
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

// Wait for async input chunk if scheduler is configured
if(kargs.async_input_scheduler.chunk_signals != nullptr)
{
const auto tiles_per_chunk =
amd_wave_read_first_lane(kargs.async_input_scheduler.tiles_per_chunk_m);
const auto tile_idx_pivot =
amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
// Only wait for tiles at or beyond the pivot (tiles before pivot are already
// available)
if(tiles_per_chunk > 0 && iM >= tile_idx_pivot)
{
// Chunk index is relative to the pivot
const auto chunk_idx =
amd_wave_read_first_lane((iM - tile_idx_pivot) / tiles_per_chunk);
workgroup_barrier chunk_barrier(kargs.async_input_scheduler.chunk_signals);
chunk_barrier.wait_eq(/*value=*/1, /*offset=*/chunk_idx);
Review comment (Contributor): Just checked the wait_eq function; maybe add a line of __builtin_amdgcn_s_sleep(1); // reduce power consumption?

}
}

// Get the SplitK offset for this block
const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles);
const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
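To make the wait above concrete: with tiles_per_chunk_m = 2 and tile_idx_pivot_m = 2, a workgroup landing on iM = 5 computes chunk_idx = (5 - 2) / 2 = 1 and blocks until chunk_signals[1] == 1, which the host sets via hipMemcpyAsync. A rough mental model of that wait, as a hedged sketch (this is not the actual ck_tile::workgroup_barrier implementation), with the reviewer's s_sleep suggestion folded in:

```cpp
// Hypothetical stand-in for workgroup_barrier::wait_eq, for illustration only.
__device__ void wait_eq_sketch(const uint32_t* signals, uint32_t value, uint32_t offset)
{
    // Poll the per-chunk flag that the host writes with hipMemcpyAsync.
    while(__atomic_load_n(signals + offset, __ATOMIC_ACQUIRE) != value)
    {
        __builtin_amdgcn_s_sleep(1); // reviewer's suggestion: nap between polls
    }
}
```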