16 changes: 12 additions & 4 deletions tests/cpp/operator/test_cublaslt_gemm.cu
@@ -29,6 +29,8 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {
};

std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
{16, 128, 16},
{32, 128, 32},
{768, 3072, 4096},
};

@@ -345,8 +347,11 @@ void performTest(const TestParams& params) {
if (!has_fp8) {
GTEST_SKIP() << "MXFP8 scaling mode requires Float8 types";
}
if (params.m % 32 != 0 || params.n % 32 != 0 || params.k % 32 != 0) {
GTEST_SKIP() << "MXFP8 requires M, N, K to be multiples of 32";
if (params.m % 16 || params.n % 16) {
GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16";
}
if (params.k % 128) {
Collaborator: Is this a hipBLASLt limitation?

Contributor (author): Yes, these are the values the hipBLASLt team provided to us. I tested just in case, but nothing smaller than 128 works for K.

Collaborator: Is the 32x128x32 config still needed alongside 16x128x16, then?

Contributor (author): I would say it makes sense to keep it. It lets us test a TE-acceptable size that is a multiple of 32 while also verifying that unpadding and hipBLASLt work with 16.

Collaborator: In that case I'd change 32x128x32 to 32x128x16 so the two are tested together.

GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
}
}
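For quick reference, here is a minimal sketch (not part of the diff) of the dimension rule the thread above converges on — M and N multiples of 16, K a multiple of 128 — using a hypothetical helper name and assuming the test tuples are ordered (M, K, N):

```cpp
#include <cstddef>

// Hypothetical helper mirroring the GTEST_SKIP conditions added above:
// hipBLASLt MXFP8 GEMMs need M and N to be multiples of 16 and K a multiple of 128.
inline bool mxfp8_dims_supported(size_t m, size_t n, size_t k) {
  return (m % 16 == 0) && (n % 16 == 0) && (k % 128 == 0);
}

// Reading the new test sizes as (M, K, N):
//   16 x 128 x 16 -> supported; exercises unpadding at the minimum M/N of 16
//   32 x 128 x 32 -> supported; a size TE accepts without padding concerns
```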

@@ -560,8 +565,11 @@ void performDqTest(const TestParams &params) {
GTEST_ASSERT_TRUE(isFp8Type(atype) && isFp8Type(btype)) << "FP8/BF8 input datatype is expected";
GTEST_ASSERT_FALSE(isFp8Type(dtype)) << "Non FP8/BF8 output datatype is expected";

if (params.m % 32 != 0 || params.n % 32 != 0 || params.k % 32 != 0) {
GTEST_SKIP() << "MXFP8 requires M, N, K to be multiples of 32";
if (params.m % 16 || params.n % 16) {
GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16";
}
if (params.k % 128) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
}

cudaDeviceProp prop;
20 changes: 20 additions & 0 deletions tests/cpp/test_common.cu
@@ -147,7 +147,11 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,

scale_inv_meta ret_rowwise, ret_colwise;

#ifdef __HIP_PLATFORM_AMD__
auto block_alignment = std::vector<size_t>{1ul, 1ul};
#else
auto block_alignment = std::vector<size_t>{128ul, 4ul};
#endif
{
auto alignment = block_alignment[0];
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
@@ -181,12 +185,20 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,

{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
#ifdef __HIP_PLATFORM_AMD__
auto scale_dim_1 = DIVUP(last_dim, static_cast<size_t>(128));
#else
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(128)), 4) * 4;
#endif
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
#ifdef __HIP_PLATFORM_AMD__
auto scale_dim_1 = DIVUP(first_dim, static_cast<size_t>(128));
#else
auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(128)), 4) * 4;
#endif
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
@@ -207,12 +219,20 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,

{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
#ifdef __HIP_PLATFORM_AMD__
auto scale_dim_1 = first_dim;
#else
auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
#endif
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
#ifdef __HIP_PLATFORM_AMD__
auto scale_dim_1 = last_dim;
#else
auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
#endif
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
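To make the CUDA-vs-HIP difference in these scale shapes concrete, here is a small self-contained sketch (an editor's illustration, not code from the PR) of the padding rule the #ifdef branches encode:

```cpp
#include <cstddef>

// Same rounding the tests use via the DIVUP macro: integer division rounding up.
constexpr size_t div_up(size_t n, size_t d) { return (n + d - 1) / d; }

// Pad a raw scale dimension up to the backend's alignment. cuBLASLt expects
// MXFP8 scale-inverse tensors in a [128, 4] (rowwise) / [4, 128] (colwise)
// padded layout; hipBLASLt takes them unpadded, so the alignment is 1 and the
// dimension passes through unchanged.
constexpr size_t pad_to_alignment(size_t dim, size_t alignment) {
  return div_up(dim, alignment) * alignment;
}

static_assert(pad_to_alignment(32, 128) == 128, "CUDA: 32 scale rows pad to 128");
static_assert(pad_to_alignment(32, 1) == 32,    "HIP: 32 scale rows stay 32");
static_assert(pad_to_alignment(3, 4) == 4,      "CUDA: 3 scale columns pad to 4");
static_assert(pad_to_alignment(3, 1) == 3,      "HIP: 3 scale columns stay 3");
```

The static_asserts double as a compile-time usage example of the two alignment choices.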
7 changes: 7 additions & 0 deletions tests/cpp/test_common.h
@@ -330,10 +330,17 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127;
constexpr uint32_t FP32_MANTISSA_BITS = 23;

// [128,4] rowwise and [4,128] colwise alignment requirement
#ifdef __HIP_PLATFORM_AMD__
constexpr size_t scale_tensor_alignment_X_rowwise = 1;
constexpr size_t scale_tensor_alignment_Y_rowwise = 1;
constexpr size_t scale_tensor_alignment_X_colwise = 1;
constexpr size_t scale_tensor_alignment_Y_colwise = 1;
#else
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
#endif

inline size_t divide_round_up(const size_t N, const size_t M) {
return (N - 1 + M) / M;
5 changes: 5 additions & 0 deletions tests/pytorch/references/blockwise_quantizer_reference.py
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -8,6 +10,7 @@
from typing import Optional, Protocol, Tuple
from references.quantize_scale_calc import scale_from_amax_tensor

from torch.utils.cpp_extension import IS_HIP_EXTENSION

@dataclasses.dataclass()
class QuantizeResult:
@@ -36,6 +39,8 @@ def munge_scale_shapes_for_backend(
def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor:
if transpose:
s = s.transpose(-1, -2).contiguous()
if IS_HIP_EXTENSION: # HIP does not use scale padding
return s
M, K = s.shape
if K % 4 == 0:
return s
74 changes: 73 additions & 1 deletion tests/pytorch/test_sanity.py
@@ -41,7 +41,7 @@
Float8Quantizer,
Float8Tensor,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from utils import ModelConfig
@@ -913,6 +913,78 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
)
torch.cuda.synchronize()

@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("K", [128])
@pytest.mark.parametrize("M", [32])
Collaborator: Better to use a non-multiple of 32 here, to test that this path actually exercises unpadding.

Contributor (author): We require block sizes of 32 at the Python level, so a non-multiple is not possible. We are padding scales, so we will see a rowwise scale of (1, 4) padded to (128, 4) and a colwise scale of (4, 1) padded to (4, 128).

@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_mxfp8_gemm_with_padding(N, K, M, datatype):
"""Test the unpadding functionality in rocm"""
dtype = tex.DType.kFloat8E4M3
quantizer = MXFP8Quantizer(dtype)

input_dtype = torch.randn(M, K, device="cuda", dtype=datatype)
weight_dtype = torch.randn(N, K, device="cuda", dtype=datatype)

input_data = quantizer.make_empty((M, K), device="cuda")
weight_data = quantizer.make_empty((N, K), device="cuda")

quantizer.update_quantized(input_dtype, input_data)
quantizer.update_quantized(weight_dtype, weight_data)

out_ref = general_gemm(
weight_data,
input_data,
get_workspace(),
datatype,
bias=None,
use_split_accumulator=False,
)
torch.cuda.synchronize()

row_scale_inv = input_data._rowwise_scale_inv
rows, cols = row_scale_inv.shape
row_padded_scale_inv = torch.zeros((128, 4), dtype=row_scale_inv.dtype, device="cuda")
row_padded_scale_inv[:rows, :cols] = row_scale_inv

col_scale_inv = input_data._columnwise_scale_inv
rows, cols = col_scale_inv.shape
col_padded_scale_inv = torch.zeros((4, 128), dtype=col_scale_inv.dtype, device="cuda")
col_padded_scale_inv[:rows, :cols] = col_scale_inv


input_padded = MXFP8Tensor(
shape=input_data.shape,
rowwise_data=input_data._rowwise_data.clone(),
rowwise_scale_inv=row_padded_scale_inv,
columnwise_data=input_data._columnwise_data.clone(),
columnwise_scale_inv=col_padded_scale_inv,
fp8_dtype=tex.DType.kFloat8E4M3,
quantizer=quantizer,
dtype=datatype
)

out_pass1 = general_gemm(
weight_data,
input_padded,
get_workspace(),
datatype,
bias=None,
use_split_accumulator=False
)
torch.cuda.synchronize()

assert row_scale_inv.shape == input_padded._rowwise_scale_inv.shape, \
("Shape mismatch in rowwise scales")
assert col_scale_inv.shape == input_padded._columnwise_scale_inv.shape, \
("Shape mismatch in colwise scales")
torch.testing.assert_close(row_scale_inv, input_padded._rowwise_scale_inv,
rtol=1e-7, atol=1e-7, msg="rowwise scale mismatch")
torch.testing.assert_close(col_scale_inv, input_padded._columnwise_scale_inv,
rtol=1e-7, atol=1e-7, msg="colwise scale mismatch")
torch.testing.assert_close(out_pass1[0], out_ref[0],
rtol=1e-2, atol=1e-2, msg="GEMM output mismatch")


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
10 changes: 8 additions & 2 deletions transformer_engine/common/common.h
@@ -1,6 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -715,16 +715,22 @@ template <>
struct is_fp4<fp4e2m1> : std::true_type {};
#endif

#ifndef __HIP_PLATFORM_AMD__
// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;

#ifndef __HIP_PLATFORM_AMD__
// Alignment requirements for the Tensor Memory Accelerator (TMA)
constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment
constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment
#else
// HIP does not use scale padding
constexpr size_t scale_tensor_alignment_X_rowwise = 1;
constexpr size_t scale_tensor_alignment_Y_rowwise = 1;
constexpr size_t scale_tensor_alignment_X_colwise = 1;
constexpr size_t scale_tensor_alignment_Y_colwise = 1;
#endif

inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
7 changes: 6 additions & 1 deletion transformer_engine/common/transformer_engine.cpp
@@ -1,6 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -98,8 +98,13 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
} else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING ||
t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) {
#ifndef __HIP_PLATFORM_AMD__
// Need (4, 128) alignment even for e8 scaling factor
auto block_alignment = std::vector<size_t>{128ul, 4ul};
#else
// HIP does not use scale padding
auto block_alignment = std::vector<size_t>{1ul, 1ul};
#endif
size_t expected_x, expected_y, alignment;
const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16;
const size_t block_size_colwise = 32;
8 changes: 6 additions & 2 deletions transformer_engine/common/util/rocm_cast_kernels.cuh
@@ -233,7 +233,8 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
const e8m0_t biased_exponent =
ptx::float_to_e8m0(subwarp_amax * Quantized_Limits<OType>::max_norm_rcp);
// Only single thread writes the computed scaling factor
if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) {
const bool col_out_of_bounds = dbias_rowwise_offset_X >= cols;
if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && !(row_out_of_bounds || col_out_of_bounds)) {
const int global_scales_offset_Y =
iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y;
const int global_scales_offset_X =
@@ -297,7 +298,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X;
const int scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
scales_colwise[scale_idx] = biased_exponent;
const bool row_out_of_bounds = row_base >= rows;
if (!(row_out_of_bounds || col_out_of_bounds)) {
scales_colwise[scale_idx] = biased_exponent;
}

const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
#pragma unroll
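One plausible reading of why these guards are needed (an editor's note, not taken from the PR): once the HIP path stops padding scale tensors, edge threads that previously wrote into the [128, 4] padding region would write past the end of the exact-sized scale buffer, so the writes must be skipped whenever the block falls outside the tensor. A minimal sketch of the guarded-write pattern, with made-up names:

```cpp
#include <cstddef>
#include <cstdint>

// Guarded scale write, mirroring the pattern added to the kernel above.
// All names here are illustrative.
void write_scale_guarded(uint8_t* scales, size_t scale_rows, size_t scale_cols,
                         size_t scale_stride, size_t row, size_t col,
                         uint8_t biased_exponent) {
  const bool row_out_of_bounds = row >= scale_rows;
  const bool col_out_of_bounds = col >= scale_cols;
  if (!(row_out_of_bounds || col_out_of_bounds)) {
    scales[row * scale_stride + col] = biased_exponent;
  }
}
```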
6 changes: 6 additions & 0 deletions transformer_engine/jax/csrc/extensions/misc.h
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -83,7 +85,11 @@ constexpr struct BlockSize {
constexpr struct Alignment {
size_t x;
size_t y;
#ifndef __HIP_PLATFORM_AMD__
} MXFP8_ALIGNMENT{128, 4};
#else
} MXFP8_ALIGNMENT{1, 1};
#endif

std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);

9 changes: 8 additions & 1 deletion transformer_engine/jax/quantize/scaling_modes.py
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -23,6 +25,7 @@

from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout
from .device_utils import is_fp8_gemm_with_all_layouts_supported
from ..util import is_hip_extension


__all__ = [
@@ -366,7 +369,11 @@ def __init__(self, block_dims: Tuple[int]):
block_dims: Dimensions of the scaling blocks
"""
self._block_dims = block_dims
self._block_alignment = (128, 4)
if is_hip_extension():
self._block_alignment = (1, 1)
else:
self._block_alignment = (128, 4)


def get_scale_dtype(self) -> jnp.dtype:
"""Get the data type for scale tensors in block scaling.