Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
26ff5be
Add unswizzling functions for scaling factors in swizzle module
int-smart Mar 4, 2026
6a064cf
Add swizzle/unswizzle roundtrip test for scaling factors
int-smart Mar 4, 2026
7d1567e
Added another unswizzling functionality test for scaling factors
int-smart Mar 4, 2026
fd3ff05
Moved swizzle_row_scaling_kernel implementation at its original place
int-smart Mar 4, 2026
64b86d7
Add multi-tensor unswizzling functions for scaling factors
int-smart Mar 4, 2026
621bc16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
17dbb33
Added greptile suggestions
int-smart Mar 5, 2026
bd0e4e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
57c8532
Removed unused check from tests and reading input directly as const r…
int-smart Mar 5, 2026
d7b6d2d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
672d517
Refactor unswizzling functions and update test cases for scaling factors
int-smart Mar 13, 2026
4410e9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
8e272a7
Refactor unswizzling scaling factors to use a launch function
int-smart Mar 17, 2026
bc1fb51
Change unswizzling to use output as gt.
int-smart Mar 17, 2026
cf262c0
Refactor unswizzling scaling factors to improve input validation and …
int-smart Mar 19, 2026
f592746
Fix multi_tensor_unswizzle_scaling_factors to correctly reference out…
int-smart Mar 20, 2026
38cec8c
Enhance swizzle tests and unswizzling validation
int-smart Mar 21, 2026
abb0b29
Fix typos and update validation checks in swizzle.cu
int-smart Mar 21, 2026
dbf6c34
Update validation checks in multi_tensor_unswizzle_scaling_factors to…
int-smart Mar 21, 2026
ed009f2
Typo
int-smart Mar 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 249 additions & 0 deletions tests/cpp/operator/test_swizzle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,35 @@ void compute_ref_swizzle(const uint8_t *h_input, uint8_t *h_output,
}
}

// CPU reference for unswizzling GEMM-swizzled scaling factors back to a plain
// layout.  Each SF_TILE_DIM_M x SF_TILE_DIM_K tile of the logical M x K scale
// matrix is stored in the swizzled buffer as an interleaved
// (SF_TILE_DIM_M/4) x (SF_TILE_DIM_K*4) sub-tile; tiles along K are laid out
// consecutively within a row of tiles.  The output is row-major when
// row_scaling is true and column-major otherwise.
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, bool row_scaling>
void compute_ref_unswizzle(const uint8_t *h_input, uint8_t *h_output,
                           const size_t M, const size_t K) {
  constexpr int kInterleavedRows = SF_TILE_DIM_M / 4;
  constexpr int kInterleavedCols = SF_TILE_DIM_K * 4;
  constexpr int kTileElems = SF_TILE_DIM_M * SF_TILE_DIM_K;

  for (size_t m = 0; m < M; ++m) {
    // Position of logical row m inside its tile's interleaved layout is
    // independent of k, so hoist it out of the inner loop.
    const int within_tile_m = static_cast<int>(m % SF_TILE_DIM_M);
    const int interleaved_row = within_tile_m % kInterleavedRows;
    const int col_base = (within_tile_m / kInterleavedRows) * SF_TILE_DIM_K;
    const size_t tile_row_base = (m / SF_TILE_DIM_M) * SF_TILE_DIM_M * K;
    for (size_t k = 0; k < K; ++k) {
      const size_t tile_base = tile_row_base + (k / SF_TILE_DIM_K) * kTileElems;
      const size_t src = tile_base + interleaved_row * kInterleavedCols +
                         col_base + static_cast<int>(k % SF_TILE_DIM_K);
      const size_t dst = row_scaling ? (m * K + k) : (k * M + m);
      h_output[dst] = h_input[src];
    }
  }
}

void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) {
using namespace test;

Expand Down Expand Up @@ -110,6 +139,66 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row
}
}

void performTestUnswizzle1D(const size_t M, const size_t K, bool rowwise, bool columnwise, const bool transa) {
using namespace test;

int SF_MODE_X, SF_MODE_Y;
if (rowwise) {
SF_MODE_X = 1;
SF_MODE_Y = 32;
}
if (columnwise) {
SF_MODE_X = 32;
SF_MODE_Y = 1;
}

if (!rowwise && !columnwise) {
GTEST_SKIP() << "TEST SKIPPED, Either rowwise or columnwise scaling mode must be true.";
}
if (rowwise && columnwise) {
GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" +
std::to_string(SF_MODE_Y) + " is not implemented.";
}

DType dtype = DType::kFloat8E4M3;

const auto data_shape = transa ? std::vector<size_t>{M, K} : std::vector<size_t>{K, M};

Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
input.set_with_gemm_swizzled_scales(true);
Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);

fillUniform(&input);

// Use the actual padded compact scale shape from the tensor for both the reference
// and the comparison. This correctly covers padded cases where M is not a multiple
// of 128 or K/32 is not a multiple of 4.
const auto padded_scale_shape = rowwise
? input.rowwise_scale_inv_shape()
: input.columnwise_scale_inv_shape();
const size_t padded_dim0 = padded_scale_shape.data[0];
const size_t padded_dim1 = padded_scale_shape.data[1];
std::unique_ptr<uint8_t[]> ref_output = std::make_unique<uint8_t[]>(padded_dim0 * padded_dim1);

nvte_unswizzle_scaling_factors(input.data(), output.data(), 0);

if (rowwise)
compute_ref_unswizzle<128, 4, true>(input.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), padded_dim0, padded_dim1);
else
compute_ref_unswizzle<128, 4, false>(input.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), padded_dim1, padded_dim0);

cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

output.to_cpu();
if (rowwise) {
compareResults("output_unswizzle", output.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), padded_dim0 * padded_dim1);
} else {
compareResults("output_unswizzle", output.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), padded_dim0 * padded_dim1);
}
}

class SwizzleTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<int, int>, std::pair<bool, bool>, bool>> {};


Expand All @@ -126,6 +215,21 @@ TEST_P(SwizzleTestSuite, TestSwizzle) {
transa);
}

// Parameterized over {data shape {M, K}, {rowwise, columnwise}, transa}.
class UnswizzleTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<size_t, size_t>, std::pair<bool, bool>, bool>> {};

TEST_P(UnswizzleTestSuite, TestUnswizzle) {
  using namespace transformer_engine;
  using namespace test;

  const auto &[shape, mode, transa] = GetParam();
  performTestUnswizzle1D(shape.first, shape.second, mode.first, mode.second, transa);
}

namespace {

std::vector<std::pair<int, int>> num_tiles = {
Expand All @@ -138,6 +242,24 @@ std::vector<std::pair<int, int>> num_tiles = {
{65, 259},
};

// Raw {M, K} data shapes for unswizzle tests. Includes aligned cases (scale dims
// already multiples of 128 and 4) and padded cases where M or K/32 are not yet
// aligned, forcing the compact scale_inv to carry a padded tail.
// All K values must be multiples of 32 (MXFP8 block size).
std::vector<std::pair<size_t, size_t>> unswizzle_data_shapes = {
// Aligned: scale dims are already multiples of 128 and 4
{128, 128},
{128, 16896}, // K = 132 * 128, large K
{16896, 128}, // M = 132 * 128, large M
// M-padding only: M not a multiple of 128 (scale-M needs padding to 256)
{160, 128},
// scale-K padding only: K/32 = 3, padded to 4
{128, 96},
// Both M and scale-K need padding
{160, 96},
{16896, 16896},
};

std::vector<std::pair<bool, bool>> scaling_mode = {
{true, false},
Comment on lines +249 to 264
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Roundtrip test only covers aligned matrix dimensions

performTestSwizzleUnswizzleRoundtrip is instantiated exclusively with the existing num_tiles vector, which always produces M = num_tiles_M * MAT_TILE_DIM_M — values that are exact multiples of 128 (the scale-M alignment). The standalone performTestUnswizzle1D intentionally adds padded shapes (e.g., M=160, K=96) via unswizzle_data_shapes, but no equivalent padded cases exist for the roundtrip.

If the output-size validation or padding-mask logic ever diverges between the swizzle and unswizzle paths for non-aligned M/K, the roundtrip test would pass while standalone tests fail (or vice-versa). Consider adding a few padded shapes (e.g., {4, 3} tile-count pairs or raw {160, 96} shapes) to num_tiles or creating a separate data-shape vector for the roundtrip suite.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@int-smart Is there a reason for that difference between the tests?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just wanted to keep one test similar to the existing swizzle tests, which use aligned test cases. I have now switched the roundtrip suite to use `unswizzle_data_shapes` as well, so it covers both the aligned and the padded cases in that vector.

{false, true}
Expand All @@ -164,3 +286,130 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<2>(info.param));
return name;
});

INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  UnswizzleTestSuite,
  ::testing::Combine(
    ::testing::ValuesIn(unswizzle_data_shapes),
    ::testing::ValuesIn(scaling_mode),
    ::testing::ValuesIn(transa)
  ),
  // Encode the full parameter tuple in the test name, e.g. MK128X96smode1X0trans0.
  [](const testing::TestParamInfo<UnswizzleTestSuite::ParamType>& info) {
    const auto &[shape, mode, trans] = info.param;
    return "MK" + std::to_string(shape.first) + "X" + std::to_string(shape.second) +
           "smode" + std::to_string(mode.first) + "X" + std::to_string(mode.second) +
           "trans" + std::to_string(trans);
  });

void performTestSwizzleUnswizzleRoundtrip(const size_t M, const size_t K, bool rowwise, bool columnwise, const bool transa) {
using namespace test;

int SF_MODE_X, SF_MODE_Y;
if (rowwise) {
SF_MODE_X = 1;
SF_MODE_Y = 32;
}
if (columnwise) {
SF_MODE_X = 32;
SF_MODE_Y = 1;
}

if (!rowwise && !columnwise) {
GTEST_SKIP() << "TEST SKIPPED, Either rowwise or columnwise scaling mode must be true.";
}
if (rowwise && columnwise){
GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" +
std::to_string(SF_MODE_Y) + " is not implemented.";
}

DType dtype = DType::kFloat8E4M3;

const auto data_shape = transa ? std::vector<size_t>{M, K} : std::vector<size_t>{K, M};
const size_t logical_dim0 = data_shape[0] / SF_MODE_X;
const size_t logical_dim1 = data_shape[1] / SF_MODE_Y;

Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
Tensor swizzled("swizzled", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
swizzled.set_with_gemm_swizzled_scales(true);
Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);

fillUniform(&input);

// fillUniform fills all scale_inv entries including the padded region with random bytes.
// After swizzle, the swizzle kernel zeroes padded positions in the swizzled output, so
// after unswizzle those positions come back as zero in the compact output. Zero them in
// the input now so the full-buffer comparison is valid.
const auto padded_scale_shape = rowwise
? input.rowwise_scale_inv_shape()
: input.columnwise_scale_inv_shape();
const size_t padded_dim0 = padded_scale_shape.data[0];
const size_t padded_dim1 = padded_scale_shape.data[1];

if (padded_dim0 != logical_dim0 || padded_dim1 != logical_dim1) {
auto* scale_ptr = rowwise
? input.rowwise_cpu_scale_inv_ptr<uint8_t>()
: input.columnwise_cpu_scale_inv_ptr<uint8_t>();
for (size_t r = 0; r < padded_dim0; r++) {
for (size_t c = 0; c < padded_dim1; c++) {
if (r >= logical_dim0 || c >= logical_dim1) {
scale_ptr[r * padded_dim1 + c] = 0;
}
}
}
input.from_cpu();
}

nvte_swizzle_scaling_factors(input.data(), swizzled.data(), 0);
nvte_unswizzle_scaling_factors(swizzled.data(), output.data(), 0);

cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

input.to_cpu();
output.to_cpu();
if (rowwise) {
compareResults("roundtrip_rowwise", output.rowwise_cpu_scale_inv_ptr<uint8_t>(),
input.rowwise_cpu_scale_inv_ptr<uint8_t>(), padded_dim0 * padded_dim1);
} else {
compareResults("roundtrip_columnwise", output.columnwise_cpu_scale_inv_ptr<uint8_t>(),
input.columnwise_cpu_scale_inv_ptr<uint8_t>(), padded_dim0 * padded_dim1);
}
}

// Parameterized over {data shape {M, K}, {rowwise, columnwise}, transa}.
class SwizzleUnswizzleRoundtripTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<size_t, size_t>, std::pair<bool, bool>, bool>> {};

TEST_P(SwizzleUnswizzleRoundtripTestSuite, TestSwizzleUnswizzleRoundtrip) {
  using namespace transformer_engine;
  using namespace test;

  const auto &[shape, mode, transa] = GetParam();
  performTestSwizzleUnswizzleRoundtrip(shape.first, shape.second, mode.first, mode.second,
                                       transa);
}

INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  SwizzleUnswizzleRoundtripTestSuite,
  ::testing::Combine(
    ::testing::ValuesIn(unswizzle_data_shapes),
    ::testing::ValuesIn(scaling_mode),
    ::testing::ValuesIn(transa)
  ),
  // Encode the full parameter tuple in the test name,
  // e.g. roundtrip_MK160X96smode0X1trans1.
  [](const testing::TestParamInfo<SwizzleUnswizzleRoundtripTestSuite::ParamType>& info) {
    const auto &[shape, mode, trans] = info.param;
    return "roundtrip_MK" + std::to_string(shape.first) + "X" + std::to_string(shape.second) +
           "smode" + std::to_string(mode.first) + "X" + std::to_string(mode.second) +
           "trans" + std::to_string(trans);
  });
32 changes: 30 additions & 2 deletions transformer_engine/common/include/transformer_engine/swizzle.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ extern "C" {
* Requirements:
* - scale_inv is stored in row-major.
* - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
* - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
* - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
*/
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream);

Expand All @@ -40,11 +40,39 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
* Requirements:
* - scale_inv is stored in row-major.
* - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
* - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
* - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
*/
void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
const size_t num_tensors, cudaStream_t stream);

/*! \brief Unswizzling scaling factors from the interleaved layout used by GEMM back to row-major
 *
 *  \param[in]     input     Input tensor with swizzled scale_inv.
 *  \param[in,out] output    Output tensor which hosts non-swizzled scale_inv.
 *  \param[in]     stream    CUDA stream used for the operation.
 *
 * Requirements:
 * - scale_inv is stored in row-major in output.
 * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
 * - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
 */
void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream);

/*! \brief Unswizzling scaling factors from the interleaved layout used by GEMM back to row-major
 *
 *  \param[in]     inputs       Input tensors with swizzled scale_inv.
 *  \param[in,out] outputs      Output tensors which host non-swizzled scale_inv.
 *  \param[in]     num_tensors  Number of input and output tensors.
 *  \param[in]     stream       CUDA stream used for the operation.
 *
 * Requirements:
 * - scale_inv is stored in row-major in each output.
 * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
 * - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
 */
void nvte_multi_tensor_unswizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
                                                 const size_t num_tensors, cudaStream_t stream);

/*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM
*
* \param[in] input Input FP8 block-scaled tensor.
Expand Down
Loading