-
Notifications
You must be signed in to change notification settings - Fork 265
[CK-Tile] add persistent async input scheduler parameters to kernel device-side and host-side args #3520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
[CK-Tile] add persistent async input scheduler parameters to kernel device-side and host-side args #3520
Changes from all commits
cc72644
c11a659
a72ac43
92f9124
eb88bf2
3c2e0e0
f110598
8964cce
8a40d8f
e06e20c
d18488f
af1265a
46e68f0
b056203
452f228
bb99d97
e0ce8f5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,11 @@ | |
| // SPDX-License-Identifier: MIT | ||
| #pragma once | ||
| #include <functional> | ||
| #include <chrono> | ||
| #include <thread> | ||
| #include "gemm_utils.hpp" | ||
| #include "ck_tile/host/hip_check_error.hpp" | ||
| #include "ck_tile/host/device_memory.hpp" | ||
|
|
||
| struct UniversalInvoker | ||
| { | ||
|
|
@@ -150,4 +154,172 @@ struct UniversalInvoker | |
| preprocess, | ||
| ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs)); | ||
| } | ||
|
|
||
| template <typename GemmConfig, | ||
| typename ADataType, | ||
| typename BDataType, | ||
| typename DsDataType, | ||
| typename AccDataType, | ||
| typename CDataType, | ||
| typename ALayout, | ||
| typename BLayout, | ||
| typename DsLayout, | ||
| typename ELayout, | ||
| typename CDEElementWise> | ||
| static void test_async_input_scheduler(const ck_tile::GemmHostArgs& args, | ||
| const ck_tile::stream_config& s) | ||
| { | ||
| using GemmShape = ck_tile::TileGemmShape< | ||
| ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>, | ||
| ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>, | ||
| ck_tile:: | ||
| sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>, | ||
| GemmConfig::PermuteA, | ||
| GemmConfig::PermuteB>; | ||
|
|
||
| using TilePartitioner = | ||
| ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape, | ||
| GemmConfig::TileParitionerGroupNum, | ||
| GemmConfig::TileParitionerM01>; | ||
|
|
||
| using GemmUniversalTraits = | ||
| ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM, | ||
| GemmConfig::kPadN, | ||
| GemmConfig::kPadK, | ||
| GemmConfig::DoubleSmemBuffer, | ||
| ALayout, | ||
| BLayout, | ||
| ELayout, | ||
| GemmConfig::TransposeC, | ||
| GemmConfig::UseStructuredSparsity, | ||
| true, // Persistent = true for async test | ||
| GemmConfig::NumWaveGroups, | ||
| GemmConfig::Preshuffle>; | ||
|
|
||
| constexpr auto scheduler = GemmConfig::Scheduler; | ||
|
|
||
| using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType, | ||
| BDataType, | ||
| AccDataType, | ||
| GemmShape, | ||
| GemmUniversalTraits, | ||
| scheduler>; | ||
|
|
||
| using GemmPipeline = typename PipelineTypeTraits< | ||
| GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>; | ||
|
|
||
| using GemmEpilogue = ck_tile::CShuffleEpilogue< | ||
| ck_tile::CShuffleEpilogueProblem<ADataType, | ||
| BDataType, | ||
| DsDataType, | ||
| AccDataType, | ||
| CDataType, | ||
| DsLayout, | ||
| ELayout, | ||
| CDEElementWise, | ||
| TilePartitioner::MPerBlock, | ||
| TilePartitioner::NPerBlock, | ||
| GemmConfig::M_Warp, | ||
| GemmConfig::N_Warp, | ||
| GemmConfig::M_Warp_Tile, | ||
| GemmConfig::N_Warp_Tile, | ||
| GemmConfig::K_Warp_Tile, | ||
| UniversalGemmProblem::TransposeC, | ||
| GemmConfig::NumWaveGroups, | ||
| false, | ||
| 1, | ||
| false, | ||
| 1, | ||
| GemmConfig::DoubleSmemBuffer>>; | ||
|
|
||
| using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; | ||
|
|
||
| // Calculate number of M tiles and chunks | ||
| const ck_tile::index_t tiles_m = | ||
| (args.M + TilePartitioner::MPerBlock - 1) / TilePartitioner::MPerBlock; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we could use the integer_divide_ceil |
||
| const ck_tile::index_t tiles_per_chunk = 2; // 2 tiles per chunk | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could I know why we are using 2 tiles per chunk? |
||
| // Pivot: first chunk of tiles (tiles 0 to tiles_per_chunk-1) are immediately available | ||
| // Only tiles from tile_idx_pivot_m onwards need to wait for signals | ||
| const ck_tile::index_t tile_idx_pivot = tiles_per_chunk; // Skip first chunk | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The pivot point should be (m + tile_idx_pivot_m) % tiles_m |
||
| const ck_tile::index_t tiles_needing_signals = | ||
| (tiles_m > tile_idx_pivot) ? (tiles_m - tile_idx_pivot) : 0; | ||
| const ck_tile::index_t num_chunks = | ||
| (tiles_needing_signals + tiles_per_chunk - 1) / tiles_per_chunk; | ||
|
|
||
| std::cout << "Async Input Scheduler Test:" << std::endl; | ||
| std::cout << " M tiles: " << tiles_m << std::endl; | ||
| std::cout << " Tiles per chunk: " << tiles_per_chunk << std::endl; | ||
| std::cout << " Tile index pivot: " << tile_idx_pivot << " (first " << tile_idx_pivot | ||
| << " tiles don't wait)" << std::endl; | ||
| std::cout << " Tiles needing signals: " << tiles_needing_signals << std::endl; | ||
| std::cout << " Number of signal chunks: " << num_chunks << std::endl; | ||
|
|
||
| // Allocate chunk signals using ck_tile::DeviceMem (initialized to zero) | ||
| // Only need signals for chunks beyond the pivot | ||
| ck_tile::DeviceMem signal_buf(num_chunks * sizeof(uint32_t)); | ||
| signal_buf.SetZero(); | ||
| uint32_t* d_chunk_signals = static_cast<uint32_t*>(signal_buf.GetDeviceBuffer()); | ||
|
|
||
| // Setup async input scheduler | ||
| ck_tile::PersistentAsyncInputScheduler async_scheduler; | ||
| async_scheduler.tiles_per_chunk_m = tiles_per_chunk; | ||
| async_scheduler.chunk_signals = d_chunk_signals; | ||
| async_scheduler.tile_idx_pivot_m = tile_idx_pivot; | ||
|
|
||
| // Create modified host args with async scheduler | ||
| ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args({args.a_ptr}, | ||
| {args.b_ptr}, | ||
| {}, | ||
| args.e_ptr, | ||
| args.k_batch, | ||
| args.M, | ||
| args.N, | ||
| args.K, | ||
| {args.stride_A}, | ||
| {args.stride_B}, | ||
| {}, | ||
| args.stride_E, | ||
| async_scheduler); | ||
|
|
||
| auto kargs = Kernel::UniversalGemmKernel::MakeKernelArgs(host_args); | ||
|
|
||
| const dim3 grids = Kernel::MaxOccupancyGridSize(s); | ||
| const dim3 blocks = Kernel::BlockSize(); | ||
|
|
||
| std::cout << " Grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" | ||
| << std::endl; | ||
| std::cout << " Blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" | ||
| << std::endl; | ||
|
|
||
| // Create a separate stream for setting signals | ||
| // (using the same stream would deadlock - memcpy waits for kernel, kernel waits for signal) | ||
| hipStream_t signal_stream; | ||
| HIP_CHECK_ERROR(hipStreamCreateWithFlags(&signal_stream, hipStreamNonBlocking)); | ||
|
|
||
| const auto start = std::chrono::high_resolution_clock::now(); | ||
|
|
||
| ck_tile::launch_kernel( | ||
| s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs)); | ||
|
|
||
| // Set signals with interleaved sleep using a separate stream | ||
| const int sleep_us = 100; | ||
| for(ck_tile::index_t i = 0; i < num_chunks; ++i) | ||
| { | ||
| std::this_thread::sleep_for(std::chrono::microseconds(sleep_us)); | ||
| const uint32_t signal_val = 1; | ||
| HIP_CHECK_ERROR(hipMemcpyAsync(d_chunk_signals + i, | ||
| &signal_val, | ||
| sizeof(uint32_t), | ||
| hipMemcpyHostToDevice, | ||
| signal_stream)); | ||
| } | ||
| HIP_CHECK_ERROR(hipStreamSynchronize(signal_stream)); | ||
| HIP_CHECK_ERROR(hipStreamDestroy(signal_stream)); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a line of HIP_CHECK_ERROR(hipDeviceSynchronize()); ? |
||
|
|
||
| auto duration = std::chrono::duration_cast<std::chrono::microseconds>( | ||
| std::chrono::high_resolution_clock::now() - start); | ||
|
|
||
| std::cout << " Total time: " << duration.count() << " us" << std::endl; | ||
| std::cout << " Sleep time: " << (num_chunks * sleep_us) << " us" << std::endl; | ||
| } | ||
| }; | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| namespace ck_tile { | ||
|
|
||
| // the fields may be set on the client side | ||
| struct PersistentAsyncInputScheduler | ||
| { | ||
| uint32_t tiles_per_chunk_m = 0; | ||
|
|
||
| uint32_t* chunk_signals = nullptr; | ||
|
|
||
| int32_t tile_idx_pivot_m = 0; | ||
| }; | ||
|
|
||
| } // namespace ck_tile |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,8 @@ | |
| #include "ck_tile/host/stream_utils.hpp" | ||
| #include "ck_tile/core/utility/env.hpp" | ||
| #include "ck_tile/core/utility/type_traits.hpp" | ||
| #include "ck_tile/core/utility/persistent_async_input_scheduler.hpp" | ||
| #include "ck_tile/core/arch/workgroup_barrier.hpp" | ||
|
|
||
| namespace ck_tile { | ||
|
|
||
|
|
@@ -30,18 +32,20 @@ namespace ck_tile { | |
| template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0> | ||
| struct UniversalGemmHostArgs | ||
| { | ||
| CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_, | ||
| const std::array<const void*, NumBTensor>& bs_ptr_, | ||
| const std::array<const void*, NumDTensor>& ds_ptr_, | ||
| void* e_ptr_, | ||
| index_t k_batch_, | ||
| index_t M_, | ||
| index_t N_, | ||
| index_t K_, | ||
| const std::array<index_t, NumATensor>& stride_As_, | ||
| const std::array<index_t, NumBTensor>& stride_Bs_, | ||
| const std::array<index_t, NumDTensor>& stride_Ds_, | ||
| index_t stride_E_) | ||
| CK_TILE_HOST UniversalGemmHostArgs( | ||
| const std::array<const void*, NumATensor>& as_ptr_, | ||
| const std::array<const void*, NumBTensor>& bs_ptr_, | ||
| const std::array<const void*, NumDTensor>& ds_ptr_, | ||
| void* e_ptr_, | ||
| index_t k_batch_, | ||
| index_t M_, | ||
| index_t N_, | ||
| index_t K_, | ||
| const std::array<index_t, NumATensor>& stride_As_, | ||
| const std::array<index_t, NumBTensor>& stride_Bs_, | ||
| const std::array<index_t, NumDTensor>& stride_Ds_, | ||
| index_t stride_E_, | ||
| PersistentAsyncInputScheduler async_input_scheduler_ = PersistentAsyncInputScheduler{}) | ||
| : as_ptr(as_ptr_), | ||
| bs_ptr(bs_ptr_), | ||
| ds_ptr(ds_ptr_), | ||
|
|
@@ -53,7 +57,8 @@ struct UniversalGemmHostArgs | |
| stride_Bs(stride_Bs_), | ||
| stride_Ds(stride_Ds_), | ||
| stride_E(stride_E_), | ||
| k_batch(k_batch_) | ||
| k_batch(k_batch_), | ||
| async_input_scheduler(async_input_scheduler_) | ||
| { | ||
| } | ||
|
|
||
|
|
@@ -78,6 +83,7 @@ struct UniversalGemmHostArgs | |
| }; | ||
|
|
||
| index_t k_batch; | ||
| PersistentAsyncInputScheduler async_input_scheduler; | ||
| }; | ||
|
|
||
| /// @brief The GEMM kernel device arguments. | ||
|
|
@@ -111,6 +117,8 @@ struct UniversalGemmKernelArgs | |
| /// (in memory) of E tensor. | ||
| index_t stride_E; | ||
| index_t k_batch; | ||
| /// @brief Persistent async input scheduler for chunk-based tile scheduling. | ||
| PersistentAsyncInputScheduler async_input_scheduler = {}; | ||
| }; | ||
|
|
||
| /// @brief The Universal GEMM kernel template. | ||
|
|
@@ -315,7 +323,8 @@ struct UniversalGemmKernel | |
| hostArgs.stride_Bs, | ||
| hostArgs.stride_Ds, | ||
| hostArgs.stride_E, | ||
| hostArgs.k_batch}; | ||
| hostArgs.k_batch, | ||
| hostArgs.async_input_scheduler}; | ||
| } | ||
|
|
||
| CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() | ||
|
|
@@ -1183,6 +1192,25 @@ struct UniversalGemmKernel | |
| const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); | ||
| const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); | ||
|
|
||
| // Wait for async input chunk if scheduler is configured | ||
| if(kargs.async_input_scheduler.chunk_signals != nullptr) | ||
| { | ||
| const auto tiles_per_chunk = | ||
| amd_wave_read_first_lane(kargs.async_input_scheduler.tiles_per_chunk_m); | ||
| const auto tile_idx_pivot = | ||
| amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m); | ||
| // Only wait for tiles at or beyond the pivot (tiles before pivot are already | ||
| // available) | ||
| if(tiles_per_chunk > 0 && iM >= tile_idx_pivot) | ||
| { | ||
| // Chunk index is relative to the pivot | ||
| const auto chunk_idx = | ||
| amd_wave_read_first_lane((iM - tile_idx_pivot) / tiles_per_chunk); | ||
| workgroup_barrier chunk_barrier(kargs.async_input_scheduler.chunk_signals); | ||
| chunk_barrier.wait_eq(/*value=*/1, /*offset=*/chunk_idx); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just checked the wait_eq function may be add a line of "__builtin_amdgcn_s_sleep(1); // Reduce power consumption"? |
||
| } | ||
| } | ||
|
|
||
| // Get the SplitK offset for this block | ||
| const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles); | ||
| const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment for the false and 1 meaning?