@@ -91,7 +91,7 @@ concept EpilogueDescriptor = requires(T t) {

// Concept for the thread cluster access order
template <typename T>
-concept AccessOrderDescriptor = requires(T t) {
concept ThreadClusterOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
};
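// Illustrative sketch (not part of this change): a type satisfies
// ThreadClusterOrderDescriptor as long as it exposes an `order` member
// convertible to std::array<size_t, 3>, e.g. a hypothetical
//
//   struct ExampleOrder { std::array<size_t, 3> order{1, 0, 2}; };
//   static_assert(ThreadClusterOrderDescriptor<ExampleOrder>);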

@@ -195,16 +195,16 @@ concept SpecifiesLdsTransfer = requires(T t) {

// Concept to check if a struct specifies thread cluster access order info.
template <typename T>
-concept SpecifiesThreadClusterAccessOrder = requires(T t) {
-{ T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor;
-{ T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor;
concept SpecifiesThreadClusterArrangeOrder = requires(T t) {
{ T::transfer.a.thread_cluster_arrange_order } -> ThreadClusterOrderDescriptor;
{ T::transfer.b.thread_cluster_arrange_order } -> ThreadClusterOrderDescriptor;
};

// Concept to check if a struct specifies source access order info.
template <typename T>
concept SpecifiesSourceAccessOrder = requires(T t) {
-{ T::transfer.a.src_access_order } -> AccessOrderDescriptor;
-{ T::transfer.b.src_access_order } -> AccessOrderDescriptor;
{ T::transfer.a.src_access_order } -> ThreadClusterOrderDescriptor;
{ T::transfer.b.src_access_order } -> ThreadClusterOrderDescriptor;
};

// Concept to check if struct specifies block GEMM.
@@ -5,6 +5,9 @@

#include <type_traits>
#include <concepts>
#include <utility>
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/arch/arch.hpp"

namespace ck_tile::builder {

@@ -35,4 +38,178 @@ concept AccessOrderLimits = requires {
(Value[2] >= 0 && Value[2] < 3));
};

namespace detail {

// Helper concept to detect a static compile-time Size() member (e.g. ck::Array)
template <typename T>
concept HasStaticSize = requires {
{ T::Size() } -> std::convertible_to<size_t>;
};

// Helper concept to detect std::tuple_size support (std::array and similar)
template <typename T>
concept HasTupleSize = requires {
{ std::tuple_size<T>::value } -> std::convertible_to<size_t>;
};

// Helper for dependent static_assert
template <typename>
constexpr bool always_false = false;
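// Note: the dependent always_false<Range> keeps the static_assert in
// get_range_size() below from firing unconditionally; a plain
// static_assert(false, ...) in the discarded constexpr branch would be
// ill-formed before C++23 even though that branch is never instantiated.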

// Get compile-time size of a range
template <typename Range>
constexpr size_t get_range_size()
{
if constexpr(HasStaticSize<Range>)
{
return Range::Size();
}
else if constexpr(HasTupleSize<Range>)
{
return std::tuple_size_v<Range>;
}
else
{
static_assert(always_false<Range>, "Unsupported type of range object.");
}
}

// Fold expression implementation for product calculation
template <typename Range, size_t... Is>
constexpr auto get_cluster_size_impl(const Range& range, std::index_sequence<Is...>)
{
using value_type = std::remove_cvref_t<decltype(range[0])>;
return ((range[Is]) * ... * value_type{1});
}

// Generic function that calculates the product of all elements in a range
// Works with any indexable range with compile-time size (ck::Array, std::array, etc.)
template <typename Range>
requires requires(Range r) {
r[0]; // Must be indexable
get_range_size<Range>(); // Must have compile-time size
}
constexpr auto get_cluster_size(const Range& range)
{
return get_cluster_size_impl(range, std::make_index_sequence<get_range_size<Range>()>{});
}
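// Illustrative usage (assumed values, not part of this header): for a thread
// cluster of {4, 64, 1} the fold yields 4 * 64 * 1 = 256, which
// ValidBlockTransferClusterSize below compares against the workgroup size.
// With <array> included this would hold:
//
//   static_assert(get_cluster_size(std::array<size_t, 3>{4, 64, 1}) == 256);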

// Calculate K dimension coverage (k0 * k1, with vectorization if applicable)
template <auto BlockTransfer>
constexpr auto get_k_coverage()
{
auto k0 = BlockTransfer.thread_cluster_dims[0];
auto k1 = BlockTransfer.thread_cluster_dims[2];
auto k_total = k0 * k1;

// If vectorization is on k0 (dim 0) or k1 (dim 2), multiply by vector size
if constexpr(BlockTransfer.src_vector_dim == 0 || BlockTransfer.src_vector_dim == 2)
{
k_total *= BlockTransfer.src_scalar_per_vector;
}

return k_total;
}

// Calculate M/N dimension coverage (m_n, with vectorization if applicable)
template <auto BlockTransfer>
constexpr auto get_mn_coverage()
{
auto mn = BlockTransfer.thread_cluster_dims[1];

// If vectorization is on m_n (dim 1), multiply by vector size
if constexpr(BlockTransfer.src_vector_dim == 1)
{
mn *= BlockTransfer.src_scalar_per_vector;
}

return mn;
}
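// Worked example (assumed values): for an A transfer with thread_cluster_dims
// {4, 64, 4}, src_vector_dim == 2 and src_scalar_per_vector == 8, the K
// coverage is 4 * 4 * 8 = 128 and the M/N coverage is 64, so a tile with
// k % 128 == 0 and m % 64 == 0 satisfies ThreadsCoverATile below.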

template <size_t DataTypeSize>
constexpr auto get_data_max_vec_size()
{
constexpr auto max_vec_inst_size_bytes = get_max_mem_vec_inst_width();
static_assert(max_vec_inst_size_bytes % DataTypeSize == 0,
"The max vec instruction size is not a multiple of the given data type size.");
return max_vec_inst_size_bytes / DataTypeSize;
}
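// Example (assuming get_max_mem_vec_inst_width() reports 16 bytes): for a
// 2-byte type such as fp16 this returns 8, i.e. at most 8 scalars per
// vectorized memory instruction.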

} // namespace detail

// product of thread cluster lengths must be <= workgroup size
template <auto BlockTransfer, size_t BlockSize>
concept ValidBlockTransferClusterSize =
requires { requires detail::get_cluster_size(BlockTransfer.thread_cluster_dims) <= BlockSize; };

// Check that thread cluster covers the K and M dimensions for A transfer
template <auto ABlockTransfer, auto TileSize>
concept ThreadsCoverATile = requires {
// K dimension: k0 * k1 * (vectorization) must divide K
requires TileSize.k % detail::get_k_coverage<ABlockTransfer>() == 0;
// M dimension: m_n * (vectorization) must divide M
requires TileSize.m % detail::get_mn_coverage<ABlockTransfer>() == 0;
};

// Check that thread cluster covers the K and N dimensions for B transfer
template <auto BBlockTransfer, auto TileSize>
concept ThreadsCoverBTile = requires {
// K dimension: k0 * k1 * (vectorization) must divide K
requires TileSize.k % detail::get_k_coverage<BBlockTransfer>() == 0;
// N dimension: m_n * (vectorization) must divide N
requires TileSize.n % detail::get_mn_coverage<BBlockTransfer>() == 0;
};

template <auto CBlockTransfer, auto TileSize>
concept ThreadsCoverCTile = requires {
// M dimension: m_wave_per_xdl must divide M
requires TileSize.m % CBlockTransfer.thread_cluster_dims[1] == 0;
// N dimension: n_wave_per_xdl * (vectorization) must divide N
requires TileSize.n % (CBlockTransfer.thread_cluster_dims[3] *
CBlockTransfer.scalar_per_vector) == 0;
};
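// Worked example (assumed values): with C thread_cluster_dims {1, 32, 1, 8}
// and scalar_per_vector == 8, the tile's M extent must be a multiple of 32
// and its N extent a multiple of 8 * 8 = 64.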

template <size_t Value>
concept IsPowerOf2 = (Value > 0) && ((Value & (Value - 1)) == 0);

template <size_t ScalarPerVec, size_t DataTypeSize>
concept IsVectorSizeValid =
IsPowerOf2<ScalarPerVec> && (ScalarPerVec <= detail::get_data_max_vec_size<DataTypeSize>());
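// Illustrative checks (assuming an 8-scalar limit from get_data_max_vec_size):
// IsPowerOf2<8> and IsVectorSizeValid<8, 2> hold, while IsPowerOf2<6> and
// IsVectorSizeValid<16, 2> do not.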

// Composite concept for input block transfer validation (A)
// Includes all validations: vector transfer limits, access order, cluster size,
// vector size validity, and tile coverage
template <auto A_BLOCK_TRANSFER, typename DataType, size_t BLOCK_SIZE, auto TILE_SIZE>
concept ValidABlockTransfer =
InputVectorTransferLimits<A_BLOCK_TRANSFER> &&
AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order> &&
AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order> &&
ValidBlockTransferClusterSize<A_BLOCK_TRANSFER, BLOCK_SIZE> &&
IsVectorSizeValid<A_BLOCK_TRANSFER.src_scalar_per_vector, sizeof(DataType)> &&
IsVectorSizeValid<A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, sizeof(DataType)> &&
ThreadsCoverATile<A_BLOCK_TRANSFER, TILE_SIZE>;

// Composite concept for input block transfer validation (B)
template <auto B_BLOCK_TRANSFER, typename DataType, size_t BLOCK_SIZE, auto TILE_SIZE>
concept ValidBBlockTransfer =
InputVectorTransferLimits<B_BLOCK_TRANSFER> &&
AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order> &&
AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order> &&
ValidBlockTransferClusterSize<B_BLOCK_TRANSFER, BLOCK_SIZE> &&
IsVectorSizeValid<B_BLOCK_TRANSFER.src_scalar_per_vector, sizeof(DataType)> &&
IsVectorSizeValid<B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, sizeof(DataType)> &&
ThreadsCoverBTile<B_BLOCK_TRANSFER, TILE_SIZE>;

// Composite concept for output block transfer validation (C)
template <auto C_BLOCK_TRANSFER, typename DataType, size_t BLOCK_SIZE, auto TILE_SIZE>
concept ValidCBlockTransfer =
OutputVectorTransferLimits<C_BLOCK_TRANSFER> &&
ValidBlockTransferClusterSize<C_BLOCK_TRANSFER, BLOCK_SIZE> &&
IsVectorSizeValid<C_BLOCK_TRANSFER.scalar_per_vector, sizeof(DataType)> &&
ThreadsCoverCTile<C_BLOCK_TRANSFER, TILE_SIZE>;

// Usage: IsValidLayout<ACTUAL_LAYOUT, VALID_LAYOUT_1, VALID_LAYOUT_2, ...>
template <auto ACTUAL_LAYOUT, auto... VALID_LAYOUTS>
concept IsValidLayout = ck_tile::is_any_value_of(ACTUAL_LAYOUT, VALID_LAYOUTS...);

} // namespace ck_tile::builder
@@ -104,15 +104,15 @@ concept IsTileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock
template <typename T>
concept IsXdlV3Algorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
-SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterArrangeOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;

// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply)
template <typename T>
concept IsXdlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
-SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterArrangeOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
@@ -121,7 +121,7 @@ concept IsXdlAlgorithm =
template <typename T>
concept IsWmmaAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
-SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterArrangeOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;

@@ -49,14 +49,55 @@ struct ConvFwdLargeTensorFactory
static constexpr auto C_BLOCK_TRANSFER =
internal::SetCBlockTransfer<SIGNATURE, BASE_ALGORITHM>();

-// Check limits for the algorithm parameters.
-static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
-static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
-static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
-static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
-static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
-static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
-static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
// Check limits for the data transfer parameters.
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
typename Types::ADataType,
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
typename Types::BDataType,
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
typename Types::EDataType,
BLOCK.block_size,
BLOCK.per_block>);

using enum TensorLayout;
static_assert(IsValidLayout<SIGNATURE.input.config.layout,
G_NW_C_strided,
G_NHW_C_strided,
G_NDHW_C_strided,
GNWC,
GNHWC,
GNDHWC,
NWGC,
NHWGC,
NDHWGC> &&
A_BLOCK_TRANSFER.src_vector_dim == 2);

static_assert(IsValidLayout<SIGNATURE.weight.config.layout,
G_K_X_C_strided,
G_K_YX_C_strided,
G_K_ZYX_C_strided,
GKXC,
GKYXC,
GKZYXC,
KXGC,
KYXGC,
KZYXGC> &&
B_BLOCK_TRANSFER.src_vector_dim == 2);

static_assert(IsValidLayout<SIGNATURE.output.config.layout,
G_NW_K_strided,
G_NHW_K_strided,
G_NDHW_K_strided,
GNWK,
GNHWK,
GNDHWK,
NWGK,
NHWGK,
NDHWGK>);

// The forward convolution kernel class instance with large tensor support.
using Instance =
@@ -51,14 +51,64 @@ struct ConvFwdXdlV3Factory
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();

// Check limits for the algorithm parameters.
-// TODO: Add more limits checks as needed.
-static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
-static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
-static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
-static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
-static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
-static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
-static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
typename Types::ADataType,
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
typename Types::BDataType,
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
typename Types::EDataType,
BLOCK.block_size,
BLOCK.per_block>);

// Layout validations
using enum TensorLayout;
static_assert(IsValidLayout<SIGNATURE.input.config.layout,
G_NW_C_strided,
G_NHW_C_strided,
G_NDHW_C_strided,
GNWC,
GNHWC,
GNDHWC,
NWGC,
NHWGC,
NDHWGC,
NGCW,
NGCHW,
NGCDHW> &&
A_BLOCK_TRANSFER.src_vector_dim == 2);

static_assert(IsValidLayout<SIGNATURE.weight.config.layout,
G_K_X_C_strided,
G_K_YX_C_strided,
G_K_ZYX_C_strided,
GKXC,
GKYXC,
GKZYXC,
KXGC,
KYXGC,
KZYXGC,
GKCX,
GKCYX,
GKCZYX> &&
B_BLOCK_TRANSFER.src_vector_dim == 2);

static_assert(IsValidLayout<SIGNATURE.output.config.layout,
G_NW_K_strided,
G_NHW_K_strided,
G_NDHW_K_strided,
GNWK,
GNHWK,
GNDHWK,
NWGK,
NHWGK,
NDHWGK,
NGKW,
NGKHW,
NGKDHW>);

// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<