Changes from all commits
26 commits
4d77856
Make some functions return void explicitly instead of auto
SamiAario-AMD Dec 6, 2025
9691569
Use decltype for consistency in Interwave variant of BlockGemmImpl
SamiAario-AMD Nov 21, 2025
bda5a7a
Add braces
SamiAario-AMD Nov 19, 2025
825d17c
Fix a comment
SamiAario-AMD Dec 11, 2025
ca71cd7
Reduce the scope of KPack in MakeALdsBlockDescriptor
SamiAario-AMD Dec 17, 2025
994b8f4
Minor refactoring of load_interleaved_pk_type
SamiAario-AMD Nov 12, 2025
74533b4
Rename load_interleaved_pk_type to load_and_convert_tile
SamiAario-AMD Nov 27, 2025
3a094e2
Include ck_tile/core.hpp in load_interleaved_pk_type.hpp for better I…
SamiAario-AMD Nov 26, 2025
cfa11f2
Rename InterleavedPKTypeLoader to ConverterLoader, and load_int4_tile…
SamiAario-AMD Nov 27, 2025
9559a93
Make explicit that the tile window argument to load_tile_with_element…
SamiAario-AMD Dec 12, 2025
9633d3f
In GetAWindows and GetBWindows, use DataType from LDS tensor view
SamiAario-AMD Dec 17, 2025
9af4498
Remove the defaults for SrcDataType and DstDataType in GemmPipelineAg…
SamiAario-AMD Jan 7, 2026
514035e
In BQuantGemmPipelineAgBgCrCompV3, always convert BDatatype pk_int4_t…
SamiAario-AMD Jan 7, 2026
3d55a1e
No need to specify SrcDataType in load_and_convert_tile as WarpWindow…
SamiAario-AMD Dec 16, 2025
63a4559
No need to specify DstDataType in load_and_convert_tile as WarpTile k…
SamiAario-AMD Dec 16, 2025
8fc4030
Add an instance of load_tile_transpose that takes a reference to the …
SamiAario-AMD Jan 2, 2026
3216110
Remove an unused overload of load_tile_transpose_with_offset
SamiAario-AMD Jan 2, 2026
ca17ac3
When possible, use the overload of load_tile_transpose that does not …
SamiAario-AMD Jan 2, 2026
2edd077
Adjust whitespace with clang-format
SamiAario-AMD Jan 7, 2026
b91efe5
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 7, 2026
0a4388d
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 8, 2026
e62c96f
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 8, 2026
ea4e543
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 14, 2026
c020a42
Fix a build break introduced when merging
SamiAario-AMD Jan 14, 2026
35c620e
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 14, 2026
2ab79eb
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 16, 2026
18 changes: 9 additions & 9 deletions include/ck_tile/core/tensor/load_tile.hpp
@@ -48,19 +48,19 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
* is additionally applied during a single read.
*/
template <typename TileWindow_,
template <typename... TileWindow_,
typename ElementWise_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
// TODO: Tile windows should works with unknow number of params
// Load element_wise API works only when the input typle is a tuple-tyupe
return tile_window[number<0>{}].load(
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
// TODO: Tile windows should work with unknown number of params
// Load element_wise API works only when the input type is a tuple-type
return tile_windows[number<0>{}].load(
tile_windows, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
@@ -85,12 +85,12 @@ template <typename DistributedTensor_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
CK_TILE_DEVICE void load_tile(DistributedTensor_& dst_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

/**
@@ -131,7 +131,7 @@ template <typename T,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
CK_TILE_DEVICE void load_tile_raw(T& tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
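As a usage sketch only (not part of this diff): the reworked free function now takes a `ck_tile::tuple` of windows plus an elementwise functor. The helper name, the two illustrative windows, the summing lambda, and the use of `ck_tile::make_tuple` to build the tuple are all assumptions; the callback shape (destination first, then one source per window) is inferred from the doc comment above.

```cpp
#include "ck_tile/core.hpp"

// Hypothetical device-side helper: loads matching elements from two windows
// that share one static tile distribution and combines them during a single
// vectorized read, as described by the load_tile_with_elementwise doc comment.
template <typename WindowA0, typename WindowA1>
CK_TILE_DEVICE auto load_sum_of_two_tiles(const WindowA0& a0, const WindowA1& a1)
{
    const auto windows = ck_tile::make_tuple(a0, a1);
    return ck_tile::load_tile_with_elementwise(
        windows, [](auto& dst, const auto& x, const auto& y) { dst = x + y; });
}
```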
54 changes: 40 additions & 14 deletions include/ck_tile/core/tensor/load_tile_transpose.hpp
@@ -373,25 +373,27 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
* @tparam NumCoord The number of coordinates (dimensions).
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* the last is SFINAE to ensure the tile distribution encoding is valid.
*
* @param out_tensor A statically distributed tensor containing the transposed tile
* data.
* @param tile_window The tile window with static distribution to load and transpose.
* @param offset The offset (in elements) added to the base address before
* indexing.
*
* @return A statically distributed tensor containing the transposed tile data.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
@@ -401,21 +403,17 @@ template <
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
CK_TILE_DEVICE void load_tile_transpose_with_offset(
DistributedTensor_& out_tensor,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& __restrict__ tile_window,
index_t offset)
{
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};

constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
@@ -442,8 +440,6 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
number<iAccess>{},
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
});

return out_tensor;
}

/**
@@ -455,23 +451,45 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
* @tparam NumCoord The number of coordinates (dimensions).
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* the last is SFINAE to ensure the tile distribution encoding is valid.
*
* @param out_tensor A statically distributed tensor containing the transposed tile
* data.
* @param tile_window The tile window with static distribution to load and transpose.
* indexing.
*
* @return A statically distributed tensor containing the transposed tile data.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE void
load_tile_transpose(DistributedTensor_& out_tensor,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& __restrict__ tile_window)
{
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
}

template <
typename BottomTensorView_,
typename WindowLengths_,
@@ -488,7 +506,15 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
TileDistribution_,
NumCoord>& __restrict__ tile_window)
{
return load_tile_transpose_with_offset(tile_window, 0);
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));

load_tile_transpose_with_offset(out_tensor, tile_window, 0);

return out_tensor;
}

} // namespace ck_tile
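A minimal sketch (not from the PR) contrasting the retained value-returning overload of `load_tile_transpose` with the reference-taking one added here; the window name and the assumption that the returned distributed tensor is default-constructible are illustrative. The motivation follows the review discussion later in this PR about avoiding assignment of a temporary.

```cpp
#include "ck_tile/core.hpp"

// Sketch: b_window is a hypothetical tile window with a static distribution
// over a data type supported by DefaultTranspose.
template <typename BWindow>
CK_TILE_DEVICE void transpose_two_ways(const BWindow& b_window)
{
    // 1) Value-returning overload: the transposed tensor is created inside
    //    load_tile_transpose and handed back.
    auto b_t = ck_tile::load_tile_transpose(b_window);

    // 2) Reference-taking overload added in this PR: the caller owns the
    //    destination, so repeated loads can overwrite it in place.
    decltype(b_t) b_t_again; // assumes the distributed tensor is default-constructible
    ck_tile::load_tile_transpose(b_t_again, b_window);
}
```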
18 changes: 9 additions & 9 deletions include/ck_tile/core/tensor/tile_window.hpp
@@ -182,32 +182,32 @@ struct tile_window_with_static_distribution
* The same thread, during vectorized reading, accesses the same set of
* data from A0, A1, A2, … AN.
*/
template <typename TileWindow_,
template <typename... TileWindow_,
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
CK_TILE_DEVICE auto load(const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
load(dst_tensor,
tile_window,
tile_windows,
elementwise,
number<i_access_unsupport_>{},
bool_constant<oob_conditional_check>{});
return dst_tensor;
}

template <typename DistributedTensor,
typename TileWindow_,
typename... TileWindow_,
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
const TileWindow_& tile_window,
const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
@@ -218,14 +218,14 @@ struct tile_window_with_static_distribution
using SFC_Ys = typename Traits::SFC_Ys;

constexpr auto tile_dstr = typename Base::TileDstr{};
constexpr auto sizeOfTuple = TileWindow_::size();
constexpr auto sizeOfTuple = remove_cvref_t<decltype(tile_windows)>::size();
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord =
tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord =
tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I1];

static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
@@ -236,7 +236,7 @@ struct tile_window_with_static_distribution
// read from bottom tensor
const auto idx_vec_value = generate_tuple(
[&](auto jj) {
return tile_window[number<jj>{}]
return tile_windows[number<jj>{}]
.get_bottom_tensor_view()
.template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
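For context, a hedged sketch of how two windows feeding this member `load` might be built: the doc comment requires every window in the tuple to read the same coordinates, so both are created with one shared static distribution. `make_tile_window`, the view and origin parameters, and the summing lambda are assumptions about the surrounding ck_tile API rather than code from this PR.

```cpp
#include "ck_tile/core.hpp"

// Sketch: both windows use the same window lengths, origin, and static tile
// distribution, since load() takes its thread coordinates from the first
// window of the tuple only.
template <typename ViewA0, typename ViewA1, typename Lengths, typename Origin, typename Distribution>
CK_TILE_DEVICE auto load_elementwise_sum(const ViewA0& a0_view,
                                         const ViewA1& a1_view,
                                         const Lengths& lengths,
                                         const Origin& origin,
                                         const Distribution& dstr)
{
    auto a0 = ck_tile::make_tile_window(a0_view, lengths, origin, dstr);
    auto a1 = ck_tile::make_tile_window(a1_view, lengths, origin, dstr);

    return a0.load(ck_tile::make_tuple(a0, a1),
                   [](auto& dst, const auto& x, const auto& y) { dst = x + y; });
}
```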
2 changes: 1 addition & 1 deletion include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
@@ -8,7 +8,7 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
2 changes: 1 addition & 1 deletion include/ck_tile/ops/batched_contraction.hpp
@@ -6,7 +6,7 @@
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
2 changes: 1 addition & 1 deletion include/ck_tile/ops/batched_transpose.hpp
@@ -11,7 +11,7 @@
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
2 changes: 1 addition & 1 deletion include/ck_tile/ops/common.hpp
@@ -3,7 +3,7 @@
#pragma once

#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/common/{load_interleaved_pk_type.hpp → load_and_convert_tile.hpp}
@@ -3,48 +3,43 @@

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"

namespace ck_tile {

template <typename DstDataType, index_t UnaryOpSize>
struct InterleavedPKTypeLoader
struct ConverterLoader
{
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
const WarpWindow& warp_window)
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src)
{
const element_wise::PassThroughPack8 elementwise_op{};

static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
const auto tmp = load_tile(src);
Collaborator: tmp is not the best name. What about naming it src and the window src_window?

Contributor Author: I will do this in a separate PR as discussed.

using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<DstVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
const element_wise::PassThroughPack8 elementwise_op{};

elementwise_op(dst.get_thread_buffer().template get_as<DstVectorType>()(i),
tmp.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
};

template <typename SrcDataType,
typename DstDataType,
index_t UnaryOpSize,
bool LoadTranspose = false,
typename WarpTile,
typename WarpWindow>
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
template <index_t UnaryOpSize, bool LoadTranspose = false, typename WarpTile, typename WarpWindow>
CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src)
{
if constexpr(std::is_same_v<SrcDataType, pk_int4_t>)
if constexpr(std::is_same_v<typename WarpWindow::Base::DataType, pk_int4_t>)
{
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t");
InterleavedPKTypeLoader<DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
ConverterLoader<typename WarpTile::DataType, UnaryOpSize>::load_interleaved_pk_type(dst,
src);
}
else if constexpr(LoadTranspose)
{
dst = load_tile_transpose(src);
load_tile_transpose(dst, src);
Contributor Author: Copilot seems to agree with me that, when assigning to a complex type, the compiler cannot optimize the assignment to avoid creating temporaries. It therefore makes sense to avoid assignment where possible and instead pass the object to be assigned to by reference.

Collaborator: @SamiAario-AMD Have you checked the assembly for this change? Is the register usage different?

}
else
{
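A brief usage sketch of the renamed entry point (the warp tile/window names and the unary-op width of 8 are illustrative): when the source window's data type is `pk_int4_t` the tile is widened through `ConverterLoader`, when `LoadTranspose` is set it goes through `load_tile_transpose`, and otherwise it falls back to the plain load in the collapsed `else` branch.

```cpp
#include "ck_tile/ops/common/load_and_convert_tile.hpp"

// Sketch: dst is a warp tile whose per-thread buffer size is divisible by the
// unary-op width (8 matches PassThroughPack8); the source window's data type
// selects the conversion path at compile time.
template <typename WarpTile, typename WarpWindow>
CK_TILE_DEVICE void load_b_warp_tile(WarpTile& dst, const WarpWindow& src)
{
    ck_tile::load_and_convert_tile<8, /*LoadTranspose=*/false>(dst, src);
}
```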
2 changes: 1 addition & 1 deletion include/ck_tile/ops/elementwise.hpp
@@ -9,7 +9,7 @@
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
2 changes: 1 addition & 1 deletion include/ck_tile/ops/epilogue.hpp
@@ -11,7 +11,7 @@
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
2 changes: 1 addition & 1 deletion include/ck_tile/ops/flatmm.hpp
@@ -22,7 +22,7 @@
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
2 changes: 1 addition & 1 deletion include/ck_tile/ops/fmha.hpp
@@ -61,7 +61,7 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"