Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
f7955d9
Add placeholder test.
Dec 18, 2025
b828d35
Merge remote-tracking branch 'origin/develop' into vpietila/ckb-bwd-w…
Dec 19, 2025
2460cf4
Initial conv bwd weight factory.
Dec 19, 2025
5a1c9c9
Conv builder test refactoring.
Dec 19, 2025
1df8077
Add missing pieces to bwd weight factory.
Dec 19, 2025
4d5b5b7
Improve compile time erros message when no matching factory is found.
Dec 22, 2025
4d20cc6
Use amcro to ensure automatic macthing between concepts are their str…
Dec 22, 2025
c6798d3
Improve compile time diagnostics.
Dec 22, 2025
8d40e6d
Small improvements.
Dec 22, 2025
9679d9b
Improve missing member/wrong type compile-time errors.
Dec 22, 2025
5ee99d8
Improve compile time diagnostics.
Dec 22, 2025
dacf82d
Concept bug fixes.
Dec 22, 2025
8eb6224
Remove debug assert.
Dec 22, 2025
a8e7edd
Update algorithm signature diagnostics.
Dec 22, 2025
96a4a5d
Factory bug fixes.
Dec 22, 2025
608266a
First functional version of bwd weight conv factory.
Dec 22, 2025
a1740c6
Refactor handing of GEMM-K batch template parameter in conv bwd weigh…
Dec 23, 2025
77e10c7
Concept improvements.
Dec 23, 2025
ff2fdd8
Improve concept diagnostics.
Dec 29, 2025
8c80e00
Introduve a common size type for concepts.
Dec 29, 2025
30a9686
Update compiletime diagnostics to use the size type.
Dec 29, 2025
027d943
Update conv specialization enum.
Dec 29, 2025
3bd0f05
Fix fwd conv builder tests.
Dec 29, 2025
52086b3
Fix smoke tests.
Dec 29, 2025
9926d94
Separate bwd weigth and bwd data tests into separate targets.
Dec 29, 2025
277981b
Clean-up CK Tile builder tests.
Dec 29, 2025
80f4482
Add bwd weight XDL CShuffle V3 factory.
Dec 29, 2025
a83790e
Build conv bwd weigth v3 instances successfully.
Dec 29, 2025
ab88cee
Add instance traits for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3.
Dec 29, 2025
3e16fa0
Test fix.
Dec 29, 2025
3c1e2b0
Add instance traits for bwd weight algorithms.
Dec 30, 2025
adfab9d
Add unit tests for instance strings.
Dec 30, 2025
30c10e2
Build new instance traits unit tests but exclude WMMA for now.
Dec 30, 2025
7571020
Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.
Dec 31, 2025
3b0777f
Conv bwd weight DL factory.
Dec 31, 2025
e1b4acd
Final implementation for bwd weight DL factory.
Dec 31, 2025
83be9c7
Add test for creating DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffl…
Dec 31, 2025
fba8040
Add factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
Dec 31, 2025
5be1ed6
Merge remote-tracking branch 'origin/develop' into vpietila/ckb-bwd-w…
Jan 2, 2026
09e188f
Treat ref algorithm the same way as real algorithms in the dispatcher.
Jan 2, 2026
d045923
Refactor large tensor support and WMMA configuration.
Jan 2, 2026
bc3cba8
Add factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3.
Jan 2, 2026
2e43e16
Update Readme.
Jan 2, 2026
8993427
Fix WMMA bwd weight tests.
Jan 2, 2026
aa10d65
Added factory and tests for DeviceGroupedConvBwdWeightTwoStage_Wmma_C…
Jan 2, 2026
1759db7
Factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffle.
Jan 2, 2026
c3a9044
Dispatching for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle.
Jan 2, 2026
4eea42c
Add factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
Jan 2, 2026
1dcea18
Fix DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 factory and …
Jan 5, 2026
829eabe
Fix fwd factories after refactoring.
Jan 5, 2026
881bf91
clang-format
Jan 5, 2026
2010396
Move compile-time diagnostics to a separate branch.
Jan 5, 2026
5f63955
WIP: Unify warp GEMM and thread distribution descriptions.
Jan 5, 2026
02243ca
Merge branch 'develop' into vpietila/ckb-bwd-weight-factories
vpietila-amd Jan 5, 2026
37e9547
Fix ref algorithm dispatching.
Jan 7, 2026
8280322
Fix smoke tests.
Jan 7, 2026
00f45cc
clang-format
Jan 7, 2026
c5cdd51
Fix factory for regular WMMA conv bwd weight.
Jan 7, 2026
d107b85
Merge branch 'develop' into vpietila/ckb-bwd-weight-factories
vpietila-amd Jan 7, 2026
7b3aca7
Merge remote-tracking branch 'origin/vpietila/ckb-bwd-weight-factorie…
Jan 7, 2026
6c41727
Merge remote-tracking branch 'origin/develop' into vpietila/ckb-bwd-w…
Jan 8, 2026
16b8680
Clarify builder Readme.
Jan 8, 2026
fd8edf9
Remove obsolete test file.
Jan 8, 2026
18c2631
Fix test after merge.
Jan 8, 2026
1abe9ab
Merge branch 'vpietila/ckb-bwd-weight-factories' into vpietila/ckb-re…
Jan 8, 2026
0336ac5
clang-format
Jan 8, 2026
3f0bac4
Fix conv algorithm types after refactoring.
Jan 8, 2026
f74e034
Adapt factories to warp GEMM and transfer parameters refactoring.
Jan 9, 2026
63fc27b
Refactor algorithm specialization and GEMM pipeline definitions.
Jan 9, 2026
6bcdc10
Fix fwd factories after refactoring.
Jan 9, 2026
2fe054e
Merge branch 'develop' into vpietila/ckb-bwd-weight-factories
vpietila-amd Jan 12, 2026
7e02790
Remove the C++26 extensions.
Jan 12, 2026
46afc66
Fix fwd/bwd conv factory tests after tile transfer XDL/WMMA concepts …
Jan 12, 2026
4f721ac
Fix remaining fwd/bwd instances tests.
Jan 12, 2026
3e8f390
Unify conv elementwise ops and layout definitions for fwd and bwd dir…
Jan 13, 2026
b5d060b
Remove old layout and elementwise ops.
Jan 13, 2026
97793cf
Unify handling of conv tensor types between fwd and bwd directions.
Jan 13, 2026
1d51979
Unify block transfer for fwd and bwd directions. Rename ThreadSliceDi…
Jan 13, 2026
a8f1d44
Make BlockTransferDescriptor concept parametrized. Introduce a common…
Jan 13, 2026
bf57fbf
clang-format
Jan 13, 2026
9b58c20
Merge remote-tracking branch 'origin/develop' into vpietila/ckb-bwd-w…
Jan 13, 2026
faf9126
Improve dispatcher error messages. Fix builder smoke tests.
Jan 13, 2026
f9f3844
Merge remote-tracking branch 'origin/vpietila/ckb-bwd-weight-factorie…
Jan 13, 2026
6e38ca6
Merge remote-tracking branch 'origin/develop' into vpietila/ckb-refac…
Jan 14, 2026
6210a83
Rename thread distribution to thread cluster.
Jan 14, 2026
07608d1
Rename LDS transfer related assets.
Jan 14, 2026
3f7b250
Small concepts clean-up.
Jan 14, 2026
096592e
Refactor conv algorithms into more categorized form.
Jan 14, 2026
75d20e0
clang-format
Jan 14, 2026
346b3fa
Merge remote-tracking branch 'origin/develop' into vpietila/ckb-refac…
Jan 14, 2026
5d90ee1
Merge branch 'develop' into vpietila/ckb-refactor-warp-gemm-descriptors
vpietila-amd Jan 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,62 +30,57 @@ concept ThreadBlockDescriptor = requires(T t) {

// Concept for parameters that describe a gridwise XDL GEMM problem.
template <typename T>
concept GridwiseXdlGemmDescriptor = requires(T t) {
{ t.m_per_xdl } -> SizeType;
{ t.n_per_xdl } -> SizeType;
{ t.m_xdl_per_wave } -> SizeType;
{ t.n_xdl_per_wave } -> SizeType;
concept WarpGemmDescriptor = requires(T t) {
{ t.matrix_instruction } -> std::convertible_to<MatrixInstructionType>;
{ t.gemm_m_per_instruction } -> SizeType;
{ t.gemm_n_per_instruction } -> SizeType;
{ t.gemm_m_iters_per_wave } -> SizeType;
{ t.gemm_n_iters_per_wave } -> SizeType;
};

// Concept for parameter that describe block GEMM problem.
// Concept for parameters that describe the GEMM pipeline.
template <typename T>
concept BlockGemmPipelineDescriptor = requires(T t) {
concept GemmPipelineDescriptor = requires(T t) {
{ t.num_conv_groups_to_merge } -> SizeType;
{ t.num_gemm_k_prefetch_stages } -> SizeType;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
};

// Concept for parameters that describe a gridwise WMMA GEMM problem.
template <typename T>
concept GridwiseWmmaGemmDescriptor = requires(T t) {
{ t.k1 } -> SizeType;
{ t.m_per_wmma } -> SizeType;
{ t.n_per_wmma } -> SizeType;
{ t.m_wmma_per_wave } -> SizeType;
{ t.n_wmma_per_wave } -> SizeType;
};

// Concept for vectorized data transfer for convolution input tensors.
template <typename T>
concept BlockTransferDescriptor3D = requires(T t) {
concept InputTileThreadClusterDescriptor3D = requires(T t) {
{ t.k0 } -> SizeType;
{ t.m_n } -> SizeType;
{ t.k1 } -> SizeType;
};

template <typename T>
concept BlockTransferDescriptor4D = requires(T t) {
concept InputTileThreadClusterDescriptor4D = requires(T t) {
{ t.k0 } -> SizeType;
{ t.m_n } -> SizeType;
{ t.k1 } -> SizeType;
{ t.k_batch_size } -> SizeType;
};

template <typename T, size_t ThreadClusterRank>
concept BlockTransferDescriptor = (ThreadClusterRank == 3 && BlockTransferDescriptor3D<T>) ||
(ThreadClusterRank == 4 && BlockTransferDescriptor4D<T>);
concept InputTileThreadClusterDescriptor =
(ThreadClusterRank == 3 && InputTileThreadClusterDescriptor3D<T>) ||
(ThreadClusterRank == 4 && InputTileThreadClusterDescriptor4D<T>);

// Concept for thread cluster dimensions for GEMM output tensor.
template <typename T>
concept ThreadClusterDescriptor = requires(T t) {
{ t.m_block } -> SizeType;
{ t.m_wave_per_xdl } -> SizeType;
{ t.n_block } -> SizeType;
{ t.n_wave_per_xdl } -> SizeType;
concept OutputTileThreadClusterDescriptor = requires(T t) {
{ t.gemm_m_block_size } -> SizeType;
{ t.gemm_m_per_block } -> SizeType;
{ t.gemm_n_block_size } -> SizeType;
{ t.gemm_n_per_block } -> SizeType;
};

// Concept for the LDS transfer for the convolution input tensors.
template <typename T>
concept LdsTransferDescriptor = requires(T t) {
{ t.global_memory_vector_load_size } -> SizeType;
{ t.src_vector_dim } -> SizeType;
{ t.src_scalar_per_vector } -> SizeType;
{ t.lds_dst_scalar_per_vector } -> SizeType;
Expand Down Expand Up @@ -172,45 +167,18 @@ concept SpecifiesTileThreadBlock = requires {
{ T::thread_block } -> TileThreadBlockDescriptor;
};

// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept GridwiseFwdXdlGemmDescriptor = requires(T t) {
{ t.ak1 } -> SizeType;
{ t.bk1 } -> SizeType;
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
};

// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept GridwiseBwdXdlGemmDescriptor = requires(T t) {
{ t.k1 } -> SizeType;
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
};

// Concept to check if a struct specifies gridwise XDL GEMM info.
// Concept to check if a struct specifies warp GEMM info.
template <typename T>
concept SpecifiesGridwiseFwdXdlGemm = requires(T t) {
{ t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor;
};

// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseBwdXdlGemm = requires(T t) {
{ t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
};

// Concept to check if a struct specifies gridwise WMMA GEMM info.
template <typename T>
concept SpecifiesGridwiseWmmaGemm = requires(T t) {
{ t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
concept SpecifiesWarpGemm = requires {
{ T::warp_gemm } -> WarpGemmDescriptor;
};

// Concept to check if a struct specifies convolution input and output block transfer info.
template <typename T, size_t ThreadClusterRank = 3>
concept SpecifiesBlockTransfer = requires(T t) {
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor<ThreadClusterRank>;
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor<ThreadClusterRank>;
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
concept SpecifiesThreadClusters = requires(T t) {
{ T::transfer.a.thread_cluster } -> InputTileThreadClusterDescriptor<ThreadClusterRank>;
{ T::transfer.b.thread_cluster } -> InputTileThreadClusterDescriptor<ThreadClusterRank>;
{ T::transfer.c.thread_cluster } -> OutputTileThreadClusterDescriptor;
};

// Concept to check if a struct specifies convolution scalar per vector infor for A, B and C.
Expand All @@ -232,8 +200,8 @@ concept SpecifiesLdsTransfer = requires(T t) {
// Concept to check if a struct specifies thread cluster access order info.
template <typename T>
concept SpecifiesThreadClusterAccessOrder = requires(T t) {
{ T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor;
{ T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor;
{ T::transfer.a.thread_cluster_access_order } -> AccessOrderDescriptor;
{ T::transfer.b.thread_cluster_access_order } -> AccessOrderDescriptor;
};

// Concept to check if a struct specifies source access order info.
Expand All @@ -245,13 +213,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) {

// Concept to check if struct specifies block GEMM.
template <typename T>
concept SpecifiesBlockGemm = requires {
{ T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor;
};

template <typename T>
concept SpecifiesGridwiseGemmPipeline = requires {
{ T::pipeline_version } -> std::convertible_to<PipelineVersion>;
concept SpecifiesGemmPipeline = requires {
{ T::gemm_pipeline } -> GemmPipelineDescriptor;
};

// Concept to check if struct specifies block GEMM (CK Tile).
Expand Down Expand Up @@ -307,16 +270,6 @@ concept SpecifiesNumGroupsToMerge = requires {
{ T::num_conv_groups_to_merge } -> SizeType;
};

template <typename T>
concept SpecifiesLoopScheduler = requires {
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
};

template <typename T>
concept SpecifiesGenericInstance = !requires {
{ T::specialization };
};

template <typename T>
concept SpecifiesTransposeTransfer = requires {
{ T::max_transpose_transfer_src_scalar_per_vector } -> SizeType;
Expand All @@ -333,38 +286,58 @@ template <typename T>
concept TransposeTransferWellDefinedIfProvided =
!HasTransposeTransfer<T> || SpecifiesTransposeTransfer<T>;

template <typename T>
concept SpecifiesGemmBatchOptions = requires {
{ T::num_conv_groups_to_merge } -> SizeType;
};

/******************************************** */
/* Algorithm specialization concepts */
/******************************************** */
template <typename T>
concept SpecifiesLargeTensorSupport = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR;
requires !!(T::specialization & ConvAlgorithmSpecialization::LARGE_TENSOR);
};

template <typename T>
concept SpecifiesReferenceAlgorithm = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::REFERENCE;
requires !!(T::specialization & ConvAlgorithmSpecialization::REFERENCE);
};

template <typename T>
concept SpecifiesTwoStageSupport = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE;
requires !!(T::specialization & ConvAlgorithmSpecialization::TWO_STAGE);
};

template <typename T>
concept SpecifiesMultipleDSupport = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D;
requires !!(T::specialization & ConvAlgorithmSpecialization::MULTIPLE_D);
};

template <typename T>
concept SpecifiesPipelineV3 = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires !!(T::specialization & ConvAlgorithmSpecialization::PIPELINE_V3);
};

template <typename T>
concept SpecifiesGenericInstance = !requires {
{ T::specialization };
} || requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires !!(T::specialization == ConvAlgorithmSpecialization::NONE);
};

template <auto Algorithm>
concept SpecifiesXdl =
requires { requires Algorithm.warp_gemm.matrix_instruction == MatrixInstructionType::XDL; };

template <auto Algorithm>
concept SpecifiesWmma =
requires { requires Algorithm.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA; };

template <auto Algorithm>
concept SpecifiesValidWarpGemm = SpecifiesXdl<Algorithm> || SpecifiesWmma<Algorithm>;

/******************************************** */
/* DL-specific descriptors and requirements */
/******************************************** */
Expand Down
Loading