Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions example/01_gemm/gemm_wmma_fp16_v3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,22 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
PassThrough, PassThrough, PassThrough, GemmSpec,
256,
128, 256, 64,
8, 8,
16, 16,
2, 8,
S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 8, 1,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
1, 8, 8, 1,
1, 1,
S<1, 64, 1, 4>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ struct ThreadGroupTransferGlobal
// check if src element is valid
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
oob_thread_scratch_.template SetAsType<bool>(vgpr_data_idx_seq, is_src_valid);

// Vector length of elementwise operation
constexpr auto get_elem_op_vec_len = []() {
Expand Down Expand Up @@ -195,14 +196,12 @@ struct ThreadGroupTransferGlobal
using dst_vector_type = vector_type_maker_t<DstData, VectorSize>;
using dst_vector_t = typename dst_vector_type::type;

using vector_t = typename vector_type_maker<DstData, VectorSize>::type::type;

dst_vector_type op_r_v;

// Load data from memory in src_vector first
src_vector_container src_vector =
src_vector_container{grid_buf.template Get<src_vector_container_t, DoTranspose>(
src_coord_.GetOffset(), true)};
auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0;
src_vector_container src_vector = src_vector_container{
grid_buf.template Get<src_vector_container_t, DoTranspose>(index, true)};

// apply the src elementwise op and convert to DstData under the hood if needed
static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) {
Expand All @@ -213,9 +212,8 @@ struct ThreadGroupTransferGlobal
// store result in dvgpr_ (static array holding loaded data).
// At this point data is already converted to DstData type and
// the elementwise operation has been applied
dvgpr_.template SetAsType<dst_vector_t>(
vgpr_data_idx_seq,
is_src_valid ? op_r_v.template AsType<dst_vector_t>()[I0] : vector_t(0));
src_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq,
op_r_v.template AsType<dst_vector_t>()[I0]);

// For each dimension move fwd, bwd or don't move
static_for<0, nDim, 1>{}([&](auto i) {
Expand Down Expand Up @@ -248,6 +246,39 @@ struct ThreadGroupTransferGlobal
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
constexpr auto ordered_fwd_step = StepsPerIteration{};

// OOB check
// Walks every index of ordered_src_access_lengths at compile time. For each one it
// reads the vector held in src_dvgpr_ and the validity flag held in
// oob_thread_scratch_ at the same scratch index, then writes either that vector or an
// all-zero vector into dst_dvgpr_.
// NOTE(review): src_dvgpr_ and oob_thread_scratch_ appear to be filled by the read
// path of this struct before this pass runs - confirm against the caller.
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// calculate src data index and make sequence
// (copy the ordered access index into an Index, then reorder it with
// src_dim_access_order via container_reorder_given_old2new)
constexpr auto src_data_idx = [&]() {
Index ordered_idx;

static_for<0, nDim, 1>{}(
[&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; });

return container_reorder_given_old2new(ordered_idx, src_dim_access_order);
}();

// make sequence to access vgpr data. Add zero as last element of src_data_idx,
// so the resulting sequence has src_data_idx.Size() + 1 entries.
constexpr auto vgpr_data_idx_seq = generate_sequence_v2(
[&](auto i) {
if constexpr(i.value < src_data_idx.Size())
{
return Number<src_data_idx[i]>{};
}
else
{
return Number<0>{};
}
},
Number<src_data_idx.Size() + 1>{});

// Value stored for this access, and the flag saying whether its source was valid.
auto op_r = src_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq);
const bool is_src_valid =
oob_thread_scratch_.template GetAsType<bool>(vgpr_data_idx_seq);
// Invalid sources are replaced by dst_vector_t(0) instead of the stored value.
auto op_r_v = is_src_valid ? op_r : dst_vector_t(0);
dst_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq, op_r_v);
});

// make forward steps
// forward step for each iteration just add 1
const auto dst_forward_steps = generate_tuple(
Expand Down Expand Up @@ -352,7 +383,7 @@ struct ThreadGroupTransferGlobal
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
true,
dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
dst_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));

// For each dimension move fwd, bwd or don't move
static_for<0, nDim, 1>{}([&](auto i) {
Expand Down Expand Up @@ -389,14 +420,32 @@ struct ThreadGroupTransferGlobal
return make_naive_tensor_descriptor_packed(access_lengths_as_tuple);
}

// Returns a packed naive tensor descriptor whose lengths are the entries of
// NumberOfIterations followed by one extra trailing dimension of length 1.
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
    return make_naive_tensor_descriptor_packed(
        container_push_back(sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{}));
}

static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){};
using ThreadScratchData = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
VectorSize,
decltype(thread_data_scratch_desc_),
true>;

ThreadScratchData dvgpr_;
static constexpr auto src_oob_thread_scratch_desc_ =
decltype(GetSrcThreadScratchDescriptor()){};
using OOBThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
bool,
1,
decltype(src_oob_thread_scratch_desc_),
true>;

ThreadScratchData src_dvgpr_;
ThreadScratchData dst_dvgpr_;
OOBThreadScratch oob_thread_scratch_;
SrcCoord src_coord_;
DstCoord dst_coord_;
const ElementwiseOperation element_op_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,26 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3
return false;
}

// Return false when GridwiseGemm::CheckValidityAWaveTransfer rejects (M, K).
// The check is only evaluated when ck::is_gfx12_supported() returns true.
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix A " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

// Same gate for matrix B, checked with (N, K).
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix B " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

return GridwiseGemm::CheckValidity(arg);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,26 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
return false;
}

// Return false when GridwiseGemm::CheckValidityAWaveTransfer rejects (M, K).
// The check is only evaluated when ck::is_gfx12_supported() returns true.
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix A " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

// Same gate for matrix B, checked with (N, K).
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix B " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

return GridwiseGemm::CheckValidity(arg);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,26 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
return false;
}

// Return false when GridwiseGemm::CheckValidityAWaveTransfer rejects (M, K).
// The check is only evaluated when ck::is_gfx12_supported() returns true.
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix A " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

// Same gate for matrix B, checked with (N, K).
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix B " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

return GridwiseGemm::CheckValidity(arg);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,28 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
return false;
}

// Return false when GridwiseGemm::CheckValidityAWaveTransfer rejects (MRaw_, KRaw_).
// The check is only evaluated when ck::is_gfx12_supported() returns true.
if(ck::is_gfx12_supported() &&
   !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix A " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

// Same gate for matrix B, checked with (NRaw_, KRaw_).
if(ck::is_gfx12_supported() &&
   !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix B " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

typename GridwiseGemm::Argument gemm_arg{
std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,28 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
return false;
}

// Return false when GridwiseGemmWelford::CheckValidityAWaveTransfer rejects
// (MRaw_, KRaw_). The check is only evaluated when ck::is_gfx12_supported()
// returns true.
if(ck::is_gfx12_supported() &&
   !GridwiseGemmWelford::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix A " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

// Same gate for matrix B, checked with (NRaw_, KRaw_).
if(ck::is_gfx12_supported() &&
   !GridwiseGemmWelford::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix B " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

typename GridwiseGemmWelford::Argument gemm_arg{
std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,28 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
return false;
}

// Return false when GridwiseGemm::CheckValidityAWaveTransfer rejects (MRaw_, KRaw_).
// The check is only evaluated when ck::is_gfx12_supported() returns true.
if(ck::is_gfx12_supported() &&
   !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix A " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

// Same gate for matrix B, checked with (NRaw_, KRaw_).
if(ck::is_gfx12_supported() &&
   !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix B " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
std::array<const void*, 0>{},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,26 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
}
}

// Return false when GridwiseGemm::CheckValidityAWaveTransfer rejects (M, K).
// The check is only evaluated when ck::is_gfx12_supported() returns true.
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix A " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

// Same gate for matrix B, checked with (N, K).
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix B " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

return GridwiseGemm::CheckValidity(arg);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,26 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
return false;
}

// Return false when GridwiseGemm::CheckValidityAWaveTransfer rejects (M, K).
// The check is only evaluated when ck::is_gfx12_supported() returns true.
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix A " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

// Same gate for matrix B, checked with (N, K).
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message from running into the file path.
        std::cout << "Wave Transfer not applicable for matrix B " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}

return GridwiseGemm::CheckValidity(
*dynamic_cast<const typename GridwiseGemm::Argument*>(&arg));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
AComputeType,
BComputeType,
false,
false>;
false,
false,
true>;

#define GridwiseGemmCTransposeTemplateParameters \
ALayout, BLayout, DsLayout, ELayout, Tuple<ADataType>, Tuple<BDataType>, AccDataType, \
Expand All @@ -467,7 +469,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \
AComputeType, false, false
AComputeType, false, false, false, true

using GridwiseGemmCTranspose =
std::conditional_t<CTranspose,
Expand Down
Loading