Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
873688f
feat: test setup for batched contraction (aka batched gemm multiple d…
ErwinTerpstra Dec 19, 2025
c78c353
wip: device struct for WMMA batched contraction multiple d based on n…
ErwinTerpstra Dec 19, 2025
0fae879
feat: working batched contraction on RDNA, non-naive tensor descripto…
ErwinTerpstra Jan 7, 2026
82323d2
fix: failure to resolve template parameters when calling new function…
ErwinTerpstra Jan 8, 2026
c67a425
fix: passing reference type as parameter instead of underlying types
ErwinTerpstra Jan 8, 2026
ac82e53
Merge branch 'develop' into eterpstr/96-implement-device_batched_gemm…
ErwinTerpstra Jan 8, 2026
3c6fc61
fix: merge error caused duplicate definitions
ErwinTerpstra Jan 8, 2026
56e9620
fix: make sure constness of template and parameters types match
ErwinTerpstra Jan 8, 2026
918981d
fix: don't compile batched contraction test on unsupported architectures
ErwinTerpstra Jan 9, 2026
7a201c9
feat: add example for new wmma implementation, and consolidate exampl…
ErwinTerpstra Jan 15, 2026
d01b8f6
style: return inline instead of with branch
ErwinTerpstra Jan 15, 2026
55959b0
chore: add extra assert on vector memory access sizes
ErwinTerpstra Jan 15, 2026
51f8d41
chore: clean up some unused variables
ErwinTerpstra Jan 15, 2026
96612c1
fix: correct tail number calculation, added small cases and extra ins…
ErwinTerpstra Jan 15, 2026
7759d80
Merge branch 'develop' into eterpstr/96-implement-device_batched_gemm…
ErwinTerpstra Jan 15, 2026
b7e97c8
fix: merge caused duplicate function definition
ErwinTerpstra Jan 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"

using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
Expand Down Expand Up @@ -69,142 +69,6 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device::

using DeviceOpInstance = DeviceOpInstanceKKNN;

// Reference contraction specialised for one G (batch) dimension, two M
// dimensions, three N dimensions and one K dimension:
//
//   E(g0, m0, m1, n0, n1, n2) =
//       cde_op( sum_{k0} a_op(A(g0, m0, m1, k0)) * b_op(B(g0, n0, n1, n2, k0)) )
//
// NOTE(review): NumDimG is not a template parameter of this struct; the
// enable_if condition below relies on a NumDimG that is declared outside this
// block (presumably a file-scope constant -- confirm).
template <ck::index_t NumDimM,
          ck::index_t NumDimN,
          ck::index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename EDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          ck::enable_if_t<NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1, bool> =
              false>
struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator
{
    // Argument: bundles the operand tensors and the element-wise operators.
    // The tensors are held by reference, so they must outlive this object.
    struct Argument : public ck::tensor_operation::device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_gs_ms_ks,
                 const Tensor<BDataType>& b_gs_ns_ks,
                 Tensor<EDataType>& e_gs_ms_ns,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : a_gs_ms_ks_{a_gs_ms_ks},
              b_gs_ns_ks_{b_gs_ns_ks},
              e_gs_ms_ns_{e_gs_ms_ns},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
        }

        const Tensor<ADataType>& a_gs_ms_ks_; // input A, indexed (g0, m0, m1, k0)
        const Tensor<BDataType>& b_gs_ns_ks_; // input B, indexed (g0, n0, n1, n2, k0)
        Tensor<EDataType>& e_gs_ms_ns_;       // output E, indexed (g0, m0, m1, n0, n1, n2)

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker: evaluates the contraction described by an Argument.
    struct Invoker : public ck::tensor_operation::device::BaseInvoker
    {
        using Argument = ReferenceContraction_G1_M2_N3_K1::Argument;

        // Computes every element of arg.e_gs_ms_ns_ and always returns 0.
        float Run(const Argument& arg)
        {
            // Computes a single output element E(g0, m0, m1, n0, n1, n2).
            auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) {
                // A is indexed (g0, m0, m1, k0), so the K length is dimension 3.
                const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3];

                AccDataType v_acc = 0;

                for(int k0 = 0; k0 < K0; ++k0)
                {
                    AccDataType v_a;
                    AccDataType v_b;

                    // Operands are converted to AccDataType before the element-wise
                    // operators are applied.
                    arg.a_element_op_(
                        v_a, ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, k0)));
                    arg.b_element_op_(
                        v_b,
                        ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, n2, k0)));

                    v_acc += v_a * v_b;
                }

                AccDataType v_c;

                arg.cde_element_op_(v_c, v_acc);

                arg.e_gs_ms_ns_(g0, m0, m1, n0, n1, n2) = v_c;
            };

            // Apply the functor over the six output dimensions. The functor is
            // invoked with hardware_concurrency() (presumably the number of
            // worker threads -- confirm against make_ParallelTensorFunctor).
            make_ParallelTensorFunctor(f_gs_ms_ns,
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
                std::thread::hardware_concurrency());

            return 0;
        }

        // Type-erased entry point. p_arg must point to this operator's Argument;
        // the result of the dynamic_cast is dereferenced without a null check.
        float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // Accepts every argument; no validation is performed.
    bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
    {
        return true;
    }

    static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
                             const Tensor<BDataType>& b_gs_ns_ks,
                             Tensor<EDataType>& e_gs_ms_ns,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{
            a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // Returns the name of this operator followed by a newline.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceContraction_G1_M2_N3_K1"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

int main(int argc, char* argv[])
{
bool do_verification = true;
Expand Down Expand Up @@ -353,16 +217,18 @@ int main(int argc, char* argv[])
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});

using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceBatchedContraction_G1_M2_N3_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;

auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
Expand Down Expand Up @@ -399,7 +265,13 @@ int main(int argc, char* argv[])
}
}

return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
bool pass = ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result);
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;

if(!pass)
{
return 1;
}
}

return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"

using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::make_ParallelTensorFunctor;
Expand Down Expand Up @@ -67,142 +69,6 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device::

using DeviceOpInstance = DeviceOpInstanceKKNN;

// Reference contraction specialised for one G (batch) dimension, three M
// dimensions, two N dimensions and one K dimension:
//
//   E(g, m0, m1, m2, n0, n1) =
//       cde_op( sum_{k} a_op(A(g, m0, m1, m2, k)) * b_op(B(g, n0, n1, k)) )
template <ck::index_t NumDimG,
          ck::index_t NumDimM,
          ck::index_t NumDimN,
          ck::index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename EDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          ck::enable_if_t<NumDimG == 1 && NumDimM == 3 && NumDimN == 2 && NumDimK == 1, bool> =
              false>
struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
    // Operand tensors (held by reference) together with the element-wise operators.
    struct Argument : public ck::tensor_operation::device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_gs_ms_ks,
                 const Tensor<BDataType>& b_gs_ns_ks,
                 Tensor<EDataType>& e_gs_ms_ns,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : a_gs_ms_ks_{a_gs_ms_ks},
              b_gs_ns_ks_{b_gs_ns_ks},
              e_gs_ms_ns_{e_gs_ms_ns},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
        }

        const Tensor<ADataType>& a_gs_ms_ks_; // A, indexed (g, m0, m1, m2, k)
        const Tensor<BDataType>& b_gs_ns_ks_; // B, indexed (g, n0, n1, k)
        Tensor<EDataType>& e_gs_ms_ns_;       // E, indexed (g, m0, m1, m2, n0, n1)

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Evaluates the contraction described by an Argument.
    struct Invoker : public ck::tensor_operation::device::BaseInvoker
    {
        using Argument = ReferenceContraction_G1_M3_N2_K1::Argument;

        // Fills every element of arg.e_gs_ms_ns_; the returned value is always 0.
        float Run(const Argument& arg)
        {
            // Lengths of the six output dimensions.
            const auto& out_lengths = arg.e_gs_ms_ns_.mDesc.GetLengths();

            // A is indexed (g, m0, m1, m2, k), so the reduction length is dimension 4.
            const int k_length = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];

            // Produces one output element E(ig, im0, im1, im2, in0, in1).
            auto compute_element =
                [&](auto ig, auto im0, auto im1, auto im2, auto in0, auto in1) {
                    AccDataType sum = 0;

                    for(int ik = 0; ik < k_length; ++ik)
                    {
                        AccDataType a_val;
                        AccDataType b_val;

                        arg.a_element_op_(a_val,
                                          ck::type_convert<const AccDataType>(
                                              arg.a_gs_ms_ks_(ig, im0, im1, im2, ik)));
                        arg.b_element_op_(
                            b_val,
                            ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(ig, in0, in1, ik)));

                        sum += a_val * b_val;
                    }

                    AccDataType result;
                    arg.cde_element_op_(result, sum);

                    arg.e_gs_ms_ns_(ig, im0, im1, im2, in0, in1) = result;
                };

            auto parallel_functor = make_ParallelTensorFunctor(compute_element,
                                                               out_lengths[0],
                                                               out_lengths[1],
                                                               out_lengths[2],
                                                               out_lengths[3],
                                                               out_lengths[4],
                                                               out_lengths[5]);

            parallel_functor(std::thread::hardware_concurrency());

            return 0;
        }

        // Type-erased overload: downcasts p_arg to Argument and forwards to Run above.
        float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            const auto* typed_arg = dynamic_cast<const Argument*>(p_arg);
            return Run(*typed_arg);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // Every argument is reported as supported.
    bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
    {
        return true;
    }

    static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
                             const Tensor<BDataType>& b_gs_ns_ks,
                             Tensor<EDataType>& e_gs_ms_ns,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{
            a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>();
    }

    // Name of this operator followed by a newline.
    std::string GetTypeString() const override
    {
        std::stringstream type_name;

        // clang-format off
        type_name << "ReferenceContraction_G1_M3_N2_K1"
                  << std::endl;
        // clang-format on

        return type_name.str();
    }
};

int main(int argc, char* argv[])
{
bool do_verification = true;
Expand Down Expand Up @@ -353,17 +219,18 @@ int main(int argc, char* argv[])
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});

using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceBatchedContraction_G1_M3_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;

auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
Expand Down Expand Up @@ -400,7 +267,13 @@ int main(int argc, char* argv[])
}
}

return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
bool pass = ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result);
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;

if(!pass)
{
return 1;
}
}

return 0;
Expand Down
1 change: 1 addition & 0 deletions example/29_batched_gemm_bias_e_permute/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@

add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_bias_e_permute_wmma_v3_fp16 batched_gemm_bias_e_permute_wmma_v3_fp16.cpp)
Loading