1 change: 1 addition & 0 deletions ci/pytorch.sh
@@ -69,6 +69,7 @@ run_test_config(){
run_default_fa 1 attention/test_kv_cache.py
run_default_fa 1 triton_kernels/test_cast.py
run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_cast_mxfp4.py
run_default_fa 1 triton_kernels/test_grouped_gemm.py
run_default_fa 1 triton_kernels/test_norm_common.py
run_default_fa 1 triton_kernels/test_norms.py
192 changes: 192 additions & 0 deletions tests/pytorch/triton_kernels/test_cast_mxfp4.py
@@ -0,0 +1,192 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
Collaborator: You will need to add this pytest into our CI script (somewhere near "run_default_fa 1 triton_kernels/test_norms.py"), otherwise it won't be tested.

Collaborator Author: done

# License for AMD contributions = MIT. See LICENSE for more information

import math
import pytest
import torch
import numpy as np
import os

os.environ["NVTE_USE_CAST_TRANSPOSE_TRITON"] = "1"

from transformer_engine.pytorch.tensor.mxfp4_tensor import MXFP4Quantizer, MXFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.triton_kernels.cast import te_quantize_triton
from test_common import te_compare_results, fill_uniform


def mxfp4_quantize_cpu(input_tensor, axis='row'):
"""CPU reference for MXFP4 quantization matching Triton kernel behavior with shuffle."""
original_shape = input_tensor.shape
if input_tensor.dim() > 2:
input_tensor = input_tensor.view(-1, input_tensor.shape[-1])

M, N = input_tensor.shape

if axis == 'col':
input_tensor = input_tensor.t().contiguous()
M, N = N, M

data = input_tensor.cpu().float().numpy()

BLOCK_SIZE = 32
assert N % BLOCK_SIZE == 0, f"N={N} must be divisible by {BLOCK_SIZE}"

num_blocks = N // BLOCK_SIZE

# E2M1 FP4 lookup table
fp4_values = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

# Reshape to blocks: [M, num_blocks, BLOCK_SIZE]
data_blocks = data.reshape(M, num_blocks, BLOCK_SIZE)
amax_blocks = np.max(np.abs(data_blocks), axis=2)

# Triton's amax rounding: (amax + 0x200000) & 0xFF800000
amax_int = amax_blocks.astype(np.float32).view(np.uint32)
amax_int = ((amax_int + 0x200000) & 0xFF800000).astype(np.uint32)
amax_rounded = amax_int.view(np.float32)

# E8M0 scale computation: floor(log2(amax)) - 2 + 127
scale_unbiased = np.floor(np.log2(np.maximum(amax_rounded, 1e-45))) - 2
scale_unbiased = np.clip(scale_unbiased, -127, 127)
scales = (scale_unbiased + 127).astype(np.uint8)
scales = np.where(amax_blocks == 0, 0, scales)
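# Worked example for the scale formula above: amax = 6.0 -> floor(log2(6.0)) = 2,
# unbiased exponent 2 - 2 = 0, biased scale 127, scale factor 2**0 = 1, so the
# block's largest value maps exactly onto FP4's max magnitude 6.0. For
# amax = 8.0: floor(3) - 2 = 1, biased 128, factor 2**-1, scaled max 4.0 <= 6.0.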

# Scale values for quantization
scale_vals = np.where(scales[:, :, None] > 0,
2.0 ** (-(scales[:, :, None] - 127)),
1.0)

scaled_blocks = data_blocks * scale_vals

# Quantize to FP4
signs = (scaled_blocks < 0).astype(np.uint8)
abs_vals = np.abs(scaled_blocks)
diffs = np.abs(abs_vals[:, :, :, None] - fp4_values[None, None, None, :])
indices = np.argmin(diffs, axis=3).astype(np.uint8)
fp4_encoded = (signs << 3) | indices

fp4_flat = fp4_encoded.reshape(M, N)

# Pack: (odd_col << 4) | even_col
fp4_even = fp4_flat[:, 0::2]
fp4_odd = fp4_flat[:, 1::2]
fp4_packed = ((fp4_odd << 4) | fp4_even).astype(np.uint8)
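# Packing example: +1.5 in an even column encodes as 0x3 (sign 0, index 3) and
# -1.0 in the adjacent odd column as 0xA (sign 1, index 2), giving the packed
# byte (0xA << 4) | 0x3 = 0xA3.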

def cdiv(a, b): return (a + b - 1) // b

scale_M_pad = cdiv(M, 256) * 256
scale_N_pad = cdiv(num_blocks, 8) * 8
scales_padded = np.full((scale_M_pad, scale_N_pad), 127, dtype=np.uint8)

# Copy scales directly (no data shuffle support in Triton kernel)
scales_padded[:M, :num_blocks] = scales

fp4_packed_torch = torch.from_numpy(fp4_packed).to(input_tensor.device)
scales_torch = torch.from_numpy(scales_padded).to(input_tensor.device)

return fp4_packed_torch, scales_torch


@pytest.mark.parametrize("shape", [
(128, 128),
(256, 256),
(256, 1024),
(2048, 6144),
(16384, 128),
(32768, 160),
(4096, 1632),
(8, 32, 1024),
(16, 8, 4, 512),
Collaborator: Can we add some prime numbers like

{1, 3221}, // Prime 456
{2333, 1}, // Prime 345
{1481, 677}}; // Primes 234, 123

Collaborator Author: MXFP4 requires dimensions divisible by 32 for per-block scaling compatibility with AITER gemm_a4w4. I have added shapes that should raise the expected assertion error; a sketch of the shape check follows below.
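A minimal sketch of that shape check, assuming only the divisibility rule stated in this thread; the helper name is illustrative, not the actual MXFP4Quantizer.is_quantizable() implementation:

import math

def mxfp4_shape_ok(shape, block=32):
    # Rowwise 32-element scaling blocks run over the last dim; columnwise blocks
    # run over the flattened leading dims, so both must be multiples of 32.
    return shape[-1] % block == 0 and math.prod(shape[:-1]) % block == 0

assert mxfp4_shape_ok((2048, 6144))      # valid shape from the list above
assert not mxfp4_shape_ok((1481, 677))   # both dims prime -> rejected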

# MXFP4 requires: shape[-1] % 32 == 0 and prod(shape[:-1]) % 32 == 0
(32, 3221), # Last dimension 3221 (prime) not divisible by 32
(2333, 32), # First dimension 2333 (prime) not divisible by 32 when flattened
(1481, 677), # Both dimensions are primes, neither divisible by 32
])
@pytest.mark.parametrize("in_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize(("rowwise", "columnwise"), [
(True, True),
(False, True),
(True, False)
])
@pytest.mark.parametrize("shuffle_B_matrix", [False, True])
def test_quantize_mxfp4(shape, in_dtype, rowwise, columnwise, shuffle_B_matrix):
"""Test MXFP4 quantization for rowwise/columnwise modes with/without FP4 shuffle.
MXFP4 requires dimensions divisible by 32 for per-block scaling compatibility with AITER gemm_a4w4.
Note: FP4 data shuffle (shuffle_B_matrix_for_aiter) is not yet supported in the Triton kernel.
"""

Collaborator: If FP4 data shuffle is not yet supported in the Triton kernel, why do we need to add it here?

Collaborator Author: This is kept to ensure API consistency between Triton and the upcoming HIP kernel, for which I'll create a separate PR. In the HIP kernel we were able to fuse the shuffle.

Collaborator Author: HIP vs Triton flow:

HIP path:
Input: BF16 [M, N]
MXFP4Quantizer.update_quantized()
tex.cast_transpose_mxfp4_fused_shuffle()  [single HIP kernel]
├─→ Rowwise FP4 [M, K/2] (MFMA shuffled)
├─→ Rowwise Scale [M_pad, K/32_pad] (shuffled)
├─→ Colwise FP4 [N, M/2] (MFMA shuffled)
└─→ Colwise Scale [N_pad, M/32_pad] (shuffled)
AITER gemm_a4w4 (zero-copy)

Triton path:
Input: BF16 [M, N]
MXFP4Quantizer.update_quantized()
te_cast_transpose_mxfp4_triton()  [Triton JIT kernel]
├─→ Rowwise FP4 [M, K/2] (linear layout)
├─→ Rowwise Scale [M_pad, K/32_pad] (shuffled)
├─→ Colwise FP4 [N, M/2] (linear layout)
└─→ Colwise Scale [N_pad, M/32_pad] (shuffled)
aiter.ops.shuffle.shuffle_weight()  [external call: FP4 data → MFMA layout]
AITER gemm_a4w4

if shuffle_B_matrix:
pytest.skip("FP4 data shuffle not yet supported in Triton kernel")

input_tensor = fill_uniform(shape, dtype=in_dtype)
quantizer = MXFP4Quantizer(
rowwise=rowwise,
columnwise=columnwise,
shuffle_B_matrix_for_aiter=shuffle_B_matrix
)

# Invalid shapes: is_quantizable() must return False and make_empty() must
# raise the expected assertion error
if not quantizer.is_quantizable(input_tensor):
    with pytest.raises(AssertionError, match="must be divisible by"):
        quantizer.make_empty(shape, dtype=in_dtype)
    return

out = quantizer.make_empty(input_tensor.shape, dtype=in_dtype)
quantized_out = te_quantize_triton(input_tensor, quantizer=quantizer, output=out)

# Tolerance: allow 1 nibble diff for rare edge cases near FP4 boundaries
data_atol = 20.0 if in_dtype != torch.float32 else 16.0
scale_atol = 2.0 if in_dtype != torch.float32 else 1.0
Collaborator (on the tolerance lines above): Data tol seems to be quite large. You can follow our mxfp8 scale and data adjustment scheme:

void adjust_ref_for_e8m0_scale_error(const std::string &name,
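A rough Python sketch of the idea behind that scheme, as an assumption about what adjust_ref_for_e8m0_scale_error does (the real helper is C++ test code in this repo and may behave differently); it reuses the fp4_values lookup table defined in the CPU reference above:

import numpy as np

def adjust_ref_for_e8m0_scale_error(ref_scale, kernel_scale, ref_codes, fp4_values):
    """Sketch: where reference and kernel E8M0 exponents differ by exactly one
    step (a legal rounding difference near power-of-two amax values), adopt the
    kernel's scale and re-quantize the reference block so data and scales can
    be compared tightly instead of with a large atol."""
    if abs(int(ref_scale) - int(kernel_scale)) != 1:
        return ref_codes, ref_scale
    signs = np.where(ref_codes & 0x8, -1.0, 1.0)
    dequant = signs * fp4_values[ref_codes & 0x7] * 2.0 ** (int(ref_scale) - 127)
    rescaled = dequant * 2.0 ** (127 - int(kernel_scale))
    idx = np.argmin(np.abs(np.abs(rescaled)[:, None] - fp4_values), axis=1)
    new_codes = ((rescaled < 0).astype(np.uint8) << 3) | idx.astype(np.uint8)
    return new_codes, kernel_scale

With a per-block adjustment along these lines, the data_atol above could in principle be tightened toward the one-nibble tolerance the comment describes.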


if rowwise:
ref_data, ref_scale = mxfp4_quantize_cpu(input_tensor, axis='row')
M = math.prod(input_tensor.shape[:-1])
K = input_tensor.shape[-1]
num_blocks = K // MXFP4_BLOCK_SCALING_SIZE

te_compare_results(
quantized_out._rowwise_data.view(torch.uint8),
ref_data,
atol=data_atol,
rtol=0.0,
msg="rowwise FP4 data mismatch",
use_torch_semantics=True
)

# Compare only valid (non-padded) region - no shuffle extraction needed
Collaborator: What is fp4 shuffle?

Collaborator Author: FP4 shuffle rearranges the [M, K/2] linear layout into the MFMA instruction layout (16×16). The current training flow when the TE MXFP4 quantization kernel is used is:

TE Triton kernel → linear FP4 [N, K/2] → aiter.ops.shuffle_weight() → MFMA FP4 → aiter.gemm_a4w4()

You can find the shuffle code in aiter/aiter/ops/shuffle.py; a hedged usage sketch follows below.
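A minimal usage sketch of that Triton-path sequence, under stated assumptions: the shuffle_weight import path and single-tensor call follow the comment above but the exact aiter signatures are assumptions, and gemm_a4w4 is only referenced in a comment rather than called:

import torch
from transformer_engine.pytorch.tensor.mxfp4_tensor import MXFP4Quantizer
from transformer_engine.pytorch.triton_kernels.cast import te_quantize_triton
from aiter.ops.shuffle import shuffle_weight  # import path per the comment above; assumed

M, K = 256, 1024
weight = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

quantizer = MXFP4Quantizer(rowwise=True, columnwise=True,
                           shuffle_B_matrix_for_aiter=False)
q = te_quantize_triton(weight, quantizer=quantizer,
                       output=quantizer.make_empty(weight.shape, dtype=weight.dtype))

# The Triton kernel emits linear-layout packed FP4, so the MFMA shuffle is an
# explicit external step before the GEMM (the upcoming HIP kernel fuses it):
w_fp4_mfma = shuffle_weight(q._rowwise_data)  # linear FP4 -> MFMA layout (assumed signature)
# w_fp4_mfma and q._rowwise_scale then feed aiter's gemm_a4w4.

In the fused HIP path described in the earlier comment, the shuffle_weight call would disappear because the kernel already emits MFMA-shuffled FP4.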

te_compare_results(
quantized_out._rowwise_scale.view(torch.uint8)[:M, :num_blocks],
ref_scale[:M, :num_blocks],
atol=scale_atol,
rtol=0.0,
msg="rowwise E8M0 scales mismatch",
use_torch_semantics=True
)

if columnwise:
ref_data, ref_scale = mxfp4_quantize_cpu(input_tensor, axis='col')
M = math.prod(input_tensor.shape[:-1])
K = input_tensor.shape[-1]
num_blocks = M // MXFP4_BLOCK_SCALING_SIZE

te_compare_results(
quantized_out._columnwise_data.view(torch.uint8),
ref_data,
atol=data_atol,
rtol=0.0,
msg="columnwise FP4 data mismatch",
use_torch_semantics=True
)

# Compare only valid (non-padded) region - no shuffle extraction needed
te_compare_results(
quantized_out._columnwise_scale.view(torch.uint8)[:K, :num_blocks],
ref_scale[:K, :num_blocks],
atol=scale_atol,
rtol=0.0,
msg="columnwise E8M0 scales mismatch",
use_torch_semantics=True
)
3 changes: 2 additions & 1 deletion transformer_engine/common/util/pybind_helper.h
@@ -108,7 +108,8 @@
.value("kFloat16", transformer_engine::DType::kFloat16) \
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \
Collaborator: If we are going to enable kFloat4E2M1, there are other related changes needed. Search for https://github.com/search?q=repo%3AROCm%2FTransformerEngine%20kFloat4E2M1&type=code for more details.

pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \