Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/envvars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,12 @@ Kernel Configuration
:Default: ``0``
:Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar.

.. envvar:: NVTE_NVFP4_ENABLE_4OVER6

:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable per-block map-to-4 versus map-to-6 candidate selection for NVFP4 1D quantization in the ``NVFP4BlockScaling`` recipe. This mode currently requires RHT, stochastic rounding, and 2D quantization to be disabled, either with the corresponding recipe fields or with :envvar:`NVTE_NVFP4_DISABLE_RHT`, :envvar:`NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING`, and :envvar:`NVTE_NVFP4_DISABLE_2D_QUANTIZATION`.

Torch Compilation and Fusion
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
30 changes: 22 additions & 8 deletions tests/cpp/operator/test_dequantize_nvfp4.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data,
OType *output,
size_t rows,
size_t cols,
size_t scale_stride) {
constexpr float factor_inv = 1.0f / (6.0f * 448.0f);
size_t scale_stride,
                                  bool use_4over6) {
  // NOTE(review): must be `const`, not `constexpr` — `use_4over6` is a runtime
  // parameter, so a constexpr initializer is ill-formed and will not compile.
  // 4-over-6 mode uses a 256.0f decode factor instead of the default 448.0f
  // (FP8 E4M3 max), per the NVTE_NVFP4_ENABLE_4OVER6 recipe option.
  const float factor_inv = 1.0f / (6.0f * (use_4over6 ? 256.0f : 448.0f));
constexpr size_t BLOCK_SIZE = 16;
const size_t Mread = cols / BLOCK_SIZE;
const size_t bytes_per_block = BLOCK_SIZE / 2;
Expand Down Expand Up @@ -90,7 +91,8 @@ float compute_amax(const test::Tensor &t, size_t rows, size_t cols) {
// against a CPU reference computed from the quantized data.
template <typename OutputType>
void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
const bool row_scaled_nvfp4) {
const bool row_scaled_nvfp4,
const bool use_4over6) {
using namespace test;
DType otype = TypeInfo<OutputType>::dtype;

Expand All @@ -99,6 +101,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,

Tensor quantized("quantized", std::vector<size_t>{rows, cols},
DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
quantized.set_nvfp4_4over6(use_4over6);
if (row_scaled_nvfp4) {
quantized.set_tensor_amax_shape({rows});
quantized.set_row_scaled_nvfp4(true);
Expand Down Expand Up @@ -133,7 +136,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
std::make_unique<OutputType[]>(rows * cols);
compute_ref_dequantize_nvfp4<OutputType>(
fp4_data, scales, amax_val, ref_output.get(),
rows, cols, scale_stride);
rows, cols, scale_stride, use_4over6);

auto [atol, rtol] = getTolerances(otype);
compareResults("output_nvfp4", output, ref_output.get(), true, atol, rtol);
Expand All @@ -143,7 +146,8 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
// Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path.
template <typename OutputType>
void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
const bool row_scaled_nvfp4) {
const bool row_scaled_nvfp4,
const bool use_4over6) {
using namespace test;
DType otype = TypeInfo<OutputType>::dtype;

Expand All @@ -152,6 +156,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,

Tensor quantized_compact("quantized_compact", std::vector<size_t>{rows, cols},
DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
quantized_compact.set_nvfp4_4over6(use_4over6);
if (row_scaled_nvfp4) {
quantized_compact.set_tensor_amax_shape({rows});
quantized_compact.set_row_scaled_nvfp4(true);
Expand All @@ -174,6 +179,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
// Create tensor with same FP4 data but swizzled scales
Tensor quantized_swizzled("quantized_swizzled", std::vector<size_t>{rows, cols},
DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
quantized_swizzled.set_nvfp4_4over6(use_4over6);
if (row_scaled_nvfp4) {
quantized_swizzled.set_tensor_amax_shape({rows});
quantized_swizzled.set_row_scaled_nvfp4(true);
Expand Down Expand Up @@ -246,6 +252,7 @@ std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
class DequantizeNVFP4TestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
transformer_engine::DType,
bool,
bool>> {};

TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
Expand All @@ -257,10 +264,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
const auto tensor_size = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const bool row_scaled_nvfp4 = std::get<2>(GetParam());
const bool use_4over6 = std::get<3>(GetParam());

TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
performTest_dequantize_nvfp4<OutputType>(
tensor_size.first, tensor_size.second, row_scaled_nvfp4);
tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6);
);
}

Expand All @@ -270,20 +278,23 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(
::testing::ValuesIn(nvfp4_tensor_dims),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Bool(),
::testing::Bool()),
[](const testing::TestParamInfo<DequantizeNVFP4TestSuite::ParamType>& info)
{
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
(std::get<2>(info.param) ? "RowScaled" : "PerTensor");
(std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" +
(std::get<3>(info.param) ? "FourOverSix" : "Default");
return name;
}
);

class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
transformer_engine::DType,
bool,
bool>> {};

TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
Expand All @@ -295,10 +306,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
const auto tensor_size = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const bool row_scaled_nvfp4 = std::get<2>(GetParam());
const bool use_4over6 = std::get<3>(GetParam());

TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
performTest_dequantize_nvfp4_swizzled<OutputType>(
tensor_size.first, tensor_size.second, row_scaled_nvfp4);
tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6);
);
}

Expand All @@ -308,13 +320,15 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(
::testing::ValuesIn(nvfp4_tensor_dims),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Bool(),
::testing::Bool()),
[](const testing::TestParamInfo<DequantizeNVFP4SwizzledTestSuite::ParamType>& info)
{
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
(std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" +
(std::get<3>(info.param) ? "FourOverSix" : "Default") + "X" +
"Swizzled";
return name;
}
Expand Down
4 changes: 4 additions & 0 deletions tests/cpp/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ class Tensor {
tensor_.set_row_scaled_nvfp4(row_scaled_nvfp4);
}

// Enable/disable NVFP4 "4-over-6" per-block candidate selection (map-to-4 vs
// map-to-6 scaling) on the wrapped tensor; forwards directly to the underlying
// transformer_engine tensor's setter.
void set_nvfp4_4over6(bool nvfp4_4over6) {
  tensor_.set_nvfp4_4over6(nvfp4_4over6);
}

void to_cpu() const;
void from_cpu() const;
void set_scale(float scale);
Expand Down
21 changes: 21 additions & 0 deletions tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def check_nvfp4_gemm_versus_reference(
x_columnwise: bool = False,
w_columnwise: bool = False,
row_scaled_nvfp4: bool = False,
use_4over6: bool = False,
):
te_dtype = tex.DType.kFloat4E2M1

Expand Down Expand Up @@ -59,6 +60,7 @@ def check_nvfp4_gemm_versus_reference(
with_rht=False,
with_post_rht_amax=False,
row_scaled_nvfp4=row_scaled_nvfp4,
use_4over6=use_4over6,
)
w_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
Expand All @@ -68,6 +70,7 @@ def check_nvfp4_gemm_versus_reference(
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
use_4over6=use_4over6,
)

# Quantize x and w
Expand Down Expand Up @@ -123,6 +126,7 @@ def check_nvfp4_gemm_versus_reference(
eps=0.0,
quant_tile_shape=(1, 16),
row_scaled_nvfp4=row_scaled_nvfp4,
use_4over6=use_4over6,
)
w_ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
Expand All @@ -131,6 +135,7 @@ def check_nvfp4_gemm_versus_reference(
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
use_4over6=use_4over6,
)

# Create reference quantized tensors needed by reference GEMM
Expand Down Expand Up @@ -232,6 +237,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
*,
use_bias: bool,
single_output: bool,
use_4over6: bool = False,
):
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
Expand All @@ -249,6 +255,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
with_rht=False,
with_post_rht_amax=False,
row_scaled_nvfp4=True,
use_4over6=use_4over6,
)
w_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
Expand All @@ -258,6 +265,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
use_4over6=use_4over6,
)

x_nvfp4 = []
Expand Down Expand Up @@ -321,6 +329,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated(
M: int,
K: int,
N: int,
use_4over6: bool = False,
):
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
Expand All @@ -339,6 +348,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated(
with_rht=False,
with_post_rht_amax=False,
row_scaled_nvfp4=True,
use_4over6=use_4over6,
)
x_tensorwise_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
Expand All @@ -348,6 +358,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated(
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
use_4over6=use_4over6,
)
w_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
Expand All @@ -357,6 +368,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated(
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
use_4over6=use_4over6,
)

x_row_scaled = x_row_scaled_quantizer.update_quantized(
Expand Down Expand Up @@ -417,6 +429,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated(
ids=["rowxrow", "colxrow", "colxcol"],
)
@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"])
@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
def test_nvfp4_gemm_versus_reference(
M: int,
K: int,
Expand All @@ -428,6 +441,7 @@ def test_nvfp4_gemm_versus_reference(
is_x_columnwise: bool,
is_w_columnwise: bool,
row_scaled_nvfp4: bool,
use_4over6: bool,
):
if row_scaled_nvfp4:
if accumulate:
Expand All @@ -446,6 +460,7 @@ def test_nvfp4_gemm_versus_reference(
x_columnwise=is_x_columnwise,
w_columnwise=is_w_columnwise,
row_scaled_nvfp4=row_scaled_nvfp4,
use_4over6=use_4over6,
)


Expand All @@ -471,6 +486,7 @@ def test_nvfp4_gemm_versus_reference(
@pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"])
@pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"])
@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
m_splits: list[int],
k: int,
Expand All @@ -480,6 +496,7 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
out_dtype: torch.dtype,
use_bias: bool,
single_output: bool,
use_4over6: bool,
):
check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
x_dtype=x_dtype,
Expand All @@ -490,6 +507,7 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
n=n,
use_bias=use_bias,
single_output=single_output,
use_4over6=use_4over6,
)


Expand All @@ -513,13 +531,15 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
def test_nvfp4_row_scaled_gemm_matches_emulated(
M: int,
K: int,
N: int,
x_dtype: torch.dtype,
w_dtype: torch.dtype,
out_dtype: torch.dtype,
use_4over6: bool,
):
check_nvfp4_row_scaled_gemm_matches_emulated(
x_dtype=x_dtype,
Expand All @@ -528,4 +548,5 @@ def test_nvfp4_row_scaled_gemm_matches_emulated(
M=M,
K=K,
N=N,
use_4over6=use_4over6,
)
Loading