Changes from all commits
Commits
26 commits
7c72ee0
Initial refactor
Micky774 Sep 5, 2025
45e6236
Merge branch 'dev' into zain/triton-dispatch
Micky774 Sep 5, 2025
5b2ea1c
Minor API correction
Micky774 Sep 5, 2025
077b8fc
Corrected atomic behavior
Micky774 Sep 5, 2025
0011e5f
API update
Micky774 Sep 5, 2025
5ada1bd
Merge branch 'dev' into zain/triton-dispatch
Micky774 Sep 8, 2025
26298c8
Formatting
Micky774 Sep 8, 2025
cc02444
Merge branch 'dev' into zain/triton-dispatch
Micky774 Jan 28, 2026
18eb6e7
Added skip for failing HIP kernels
Micky774 Jan 28, 2026
a72b507
Updated to account for alignment args
Micky774 Jan 28, 2026
a64b5f1
Updated CI script for MI350 runs, minor code cleaning
Micky774 Jan 28, 2026
1d2554c
Streamlined implementation
Micky774 Jan 28, 2026
fd59057
Corrected alignment calculation
Micky774 Jan 28, 2026
bbd4240
Add copyright
Micky774 Jan 28, 2026
92aecf2
Updated alignment calculation
Micky774 Jan 28, 2026
12bb156
Corrected FP8_CS handling
Micky774 Jan 29, 2026
1925039
Corrected layernorm memory access bug
Micky774 Jan 30, 2026
6f9b6c5
Corrected amax dims
Micky774 Jan 30, 2026
dc3ed87
Adjusted amax init
Micky774 Jan 30, 2026
24fbcc4
Updated file names, and copyright
Micky774 Feb 6, 2026
0d6d00f
Corrected MXFP8 testing behavior
Micky774 Feb 6, 2026
499d14b
Update copyright, clarify test, clean imports
Micky774 Feb 9, 2026
c526c8d
Updated test script to respect renaming
Micky774 Feb 9, 2026
26cd12c
Merge branch 'dev' into zain/triton-dispatch
Micky774 Feb 9, 2026
d031266
Update copyrights, clean import
Micky774 Feb 10, 2026
8a5a786
Copyrights
Micky774 Feb 10, 2026
6 changes: 3 additions & 3 deletions ci/pytorch.sh
@@ -70,9 +70,9 @@ run_test_config(){
run_default_fa 1 triton_kernels/test_cast.py
run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_grouped_gemm.py
run_default_fa 1 triton_kernels/test_norm_common.py
run_default_fa 1 triton_kernels/test_norms.py
NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py
run_default_fa 1 triton_kernels/test_utils.py
NVTE_ROCM_ENABLE_MXFP8=1 run_default_fa 1 triton_kernels/test_norms.py
NVTE_ROCM_ENABLE_MXFP8=1 NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py
run_default_fa 1 test_parallel_cross_entropy.py
NVTE_USE_DEQUANTIZE_TRITON=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 NVTE_USE_LAYERNORM_TRITON=1 run_default_fa_lbl "triton" 3 test_numerics.py
NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py
79 changes: 39 additions & 40 deletions tests/pytorch/triton_kernels/test_norms.py
@@ -1,29 +1,29 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information


import math
import os
import torch
import pytest
from functools import partial
from itertools import product

from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import cpp_extensions as tex
from transformer_engine.pytorch.triton_kernels.norm_common import get_ln_sm_margin
from transformer_engine.pytorch.triton_kernels.utils import get_ln_sm_margin
from transformer_engine.pytorch.triton_kernels.common import (
torch_dtype_to_te_dtype,
te_dtype_to_torch_dtype
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor
from transformer_engine.pytorch.triton_kernels.rmsnorm import (
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton,
)
from transformer_engine.pytorch.triton_kernels.layernorm import (
from transformer_engine.pytorch.triton_kernels.norms_common import (
te_layernorm_bwd_triton,
te_layernorm_fwd_triton,
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton,
)
from test_common import dtype_tols, te_compare_results, str_to_torch_dtype, fill_uniform

@@ -406,11 +406,11 @@ def _compare_output_tensors(
quantization, fp8_dtype
):
tols = dtype_tols(out_triton.dtype if quantization is None else fp8_dtype)
_compare_func = partial(te_compare_results, **tols, use_torch_semantics=True)
compare_func = partial(te_compare_results, **tols, use_torch_semantics=True)

dq_out_triton = out_triton.dequantize()
dq_out_hip = out_hip.dequantize()
_compare_func(
compare_func(
actual=dq_out_triton,
expected=dq_out_hip,
msg=lambda msg: f"Output does not match triton <-> hip\n\n{msg}\n",
@@ -428,7 +428,7 @@ def _compare_output_tensors(
if not out_hip._transpose_invalid:
# The transpose data are generally uint8 so we must convert
# them for floating point comparison.
_compare_func(
compare_func(
actual=out_triton._transpose.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)).to(torch.float32),
expected=out_hip._transpose.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)).to(torch.float32),
msg=lambda msg: f"Output transpose does not match triton <-> hip\n\n{msg}\n",
@@ -437,67 +437,66 @@ def _compare_output_tensors(
elif quantization == "mxfp8":
if not isinstance(out_triton, MXFP8Tensor):
raise ValueError(f"Expected a MXFP8Tensor but got {type(out_triton)} instead.")

# TODO(micky774): Figure out if we need to apply the same view
# trick to MXFP8 data as we do to FP8 transpose data.
# I suspect not.
if out_hip._rowwise_data is not None:
_compare_func(
actual=out_triton,
expected=out_hip,
msg=lambda msg: f"Output rowwise data does not match triton <-> hip\n\n{msg}\n",
)
out_triton._rowwise_data = None
assert out_triton._rowwise_data is not None, "Expected rowwise data."
else:
assert out_triton._rowwise_data is None, "Expected no rowwise data."

# We use higher precision for the scales
_compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True)
compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True)
if quantization == "fp8":
_compare_func(
compare_func(
actual=out_triton._scale_inv,
expected=out_hip._scale_inv,
msg=lambda msg: f"Output scale inverse does not match triton <-> hip\n\n{msg}\n",
)
elif quantization == "mxfp8":
has_rscale_triton = out_triton._rowwise_scale_inv is not None
has_rscale_hip = out_hip._rowwise_scale_inv is not None

# The scale_inv values may differ slightly, but will still dequantize close enough to
# pass the earlier comparisons.
compare_func = partial(te_compare_results, atol=1, rtol=0, use_torch_semantics=True)
Collaborator

For MXFP8 data and scale_inv comparison, we can reuse the same logic as in the C++ gtest:

void adjust_ref_for_e8m0_scale_error(const std::string &name,

#ifdef __HIP_PLATFORM_AMD__
  if (::testing::Test::HasFatalFailure()) return;
  adjust_ref_for_e8m0_scale_error("scales", mismatches_scales_indices, gpu_scales_ptr,
                                  ref_output_scales.get(), scales_stride, rows, cols, rowwise,
                                  ref_output_c.get(), otype);
  mismatches_scales = 0;
#endif
  const size_t mismatches_elts = 32 * mismatches_scales;
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts);
  if (processing_method == ProcessingMethod::CAST_DBIAS
      || processing_method == ProcessingMethod::CAST_DBIAS_DACT)
  {
    auto [atol_dbias, rtol_dbias] = getTolerances(itype);
    if (itype == DType::kFloat32) {
      atol_dbias = 1e-4;
      rtol_dbias *= sqrt(static_cast<double>(rows));
    } else {
      rtol_dbias *= 4;
    }
    compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
  }
}

Contributor Author

We essentially already do this implicitly by relying on the dequantization of the MXFP8Tensors before comparison. While we could handle this explicitly as in the C tests, I don't think that's necessary given that the dequantization behavior has its own testing which passes. Let me know if you have other thoughts on the matter.
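
As an illustrative sketch only (the helper name compare_mxfp8_outputs is hypothetical; dequantize() and _rowwise_scale_inv are the attributes already used in this test), the dequantize-then-compare approach described above amounts to:

import torch

def compare_mxfp8_outputs(out_triton, out_hip, atol=1e-6, rtol=5e-5):
    # Compare dequantized values first: an e8m0 scale that is off by one step
    # but still reconstructs values within tolerance passes this check.
    torch.testing.assert_close(
        out_triton.dequantize(), out_hip.dequantize(), atol=atol, rtol=rtol
    )
    # Only then compare the raw e8m0 scale_inv buffers, loosely (one exponent
    # step of slack), mirroring the atol=1, rtol=0 comparison in the test.
    torch.testing.assert_close(
        out_triton._rowwise_scale_inv.to(torch.float32),
        out_hip._rowwise_scale_inv.to(torch.float32),
        atol=1.0,
        rtol=0.0,
    )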


# The MXFP8 tensors carry their scale_inv values in a padded
# format, hence we must omit the padded values.
input_shape = out_triton.shape
unpad_rscale_inv_shape = (math.prod(input_shape[:-1]), input_shape[-1] // MXFP8_BLOCK_SCALING_SIZE)
if has_rscale_triton != has_rscale_hip:
msg = "Expected rowwise scale to "
if has_rscale_hip:
msg += "not "
msg += "be None."
raise ValueError(msg)
if has_rscale_triton:
_compare_func(
actual=out_triton._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)),
expected=out_hip._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)),
compare_func(
actual=out_triton._rowwise_scale_inv[:unpad_rscale_inv_shape[0], :unpad_rscale_inv_shape[1]],
expected=out_hip._rowwise_scale_inv[:unpad_rscale_inv_shape[0], :unpad_rscale_inv_shape[1]],
msg=lambda msg: f"Output rowwise scale inverse does not match triton <-> hip\n\n{msg}\n",
)

has_cscale_triton = out_triton._columnwise_scale_inv is not None
has_cscale_hip = out_hip._columnwise_scale_inv is not None
if has_cscale_triton != has_cscale_hip:
msg = "Expected columnwwise scale to "
msg = "Expected columnwise scale to "
if has_cscale_hip:
msg += "not "
msg += "be None."
raise ValueError(msg)
if has_cscale_triton:
_compare_func(
actual=out_triton._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)),
expected=out_hip._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)),
compare_func(
actual=out_triton._columnwise_scale_inv[:unpad_rscale_inv_shape[1], :unpad_rscale_inv_shape[0]],
expected=out_hip._columnwise_scale_inv[:unpad_rscale_inv_shape[1], :unpad_rscale_inv_shape[0]],
msg=lambda msg: f"Output columnwise scale inverse does not match triton <-> hip\n\n{msg}\n",
)


def _compare_quantizers(
self,
quantizer_triton, quantizer_hip,
quantization
):
if quantization is None: return
_compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True)
compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True)

if quantizer_triton.dtype != quantizer_hip.dtype:
raise ValueError("Expected matching quantizer dtypes, but got "
@@ -511,12 +510,12 @@ def _compare_quantizers(
raise ValueError(f"Expected matching quantizer {usage} but got {qt_usage=} != {qh_usage=}")

if quantization == "fp8":
_compare_func(
compare_func(
actual=quantizer_triton.scale,
expected=quantizer_hip.scale,
msg=lambda msg: f"Quantizer scale does not match triton <-> hip\n\n{msg}\n",
)
_compare_func(
compare_func(
actual=quantizer_triton.amax,
expected=quantizer_hip.amax,
msg=lambda msg: f"Quantizer amax does not match triton <-> hip\n\n{msg}\n",
@@ -529,15 +528,15 @@ def _compare_stat_tensors(
norm
):
# We use higher precision for the remaining outputs
_compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True)
compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True)

_compare_func(
compare_func(
actual=rsigma_triton,
expected=rsigma_hip,
msg=lambda msg: f"rsigma does not match triton <-> hip\n\n{msg}\n",
)
if norm == "layer":
_compare_func(
compare_func(
actual=mu_triton,
expected=mu_hip,
msg=lambda msg: f"mu does not match triton <-> hip\n\n{msg}\n",
@@ -579,20 +578,20 @@ def _compare_bwd_tensors(
dbeta_triton, dbeta_hip,
norm
):
_compare_func = partial(te_compare_results, atol=1.5e-4, rtol=1e-4, use_torch_semantics=True)
compare_func = partial(te_compare_results, atol=1.5e-4, rtol=1e-4, use_torch_semantics=True)

_compare_func(
compare_func(
actual=dx_triton,
expected=dx_hip,
msg=lambda msg: f"dx does not match triton <-> hip\n\n{msg}\n",
)
_compare_func(
compare_func(
actual=dgamma_triton,
expected=dgamma_hip,
msg=lambda msg: f"dgamma does not match triton <-> hip\n\n{msg}\n",
)
if norm == "layer":
_compare_func(
compare_func(
actual=dbeta_triton,
expected=dbeta_hip,
msg=lambda msg: f"dbeta does not match triton <-> hip\n\n{msg}\n",
@@ -1,11 +1,11 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information


import torch
from transformer_engine.pytorch import cpp_extensions as tex
from transformer_engine.pytorch.triton_kernels.common import torch_dtype_to_te_dtype
from transformer_engine.pytorch.triton_kernels.norm_common import get_num_sms
from transformer_engine.pytorch.triton_kernels.utils import get_num_sms

def test_sm_margin():
num_sms = get_num_sms()
10 changes: 7 additions & 3 deletions transformer_engine/pytorch/module/_common.py
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -20,8 +20,12 @@
from ..export import is_in_onnx_export_mode

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton
from ..triton_kernels.norms_common import (
te_layernorm_fwd_triton,
te_layernorm_bwd_triton,
te_rmsnorm_fwd_triton,
te_rmsnorm_bwd_triton
)

def _get_normalization_func(normalization: str, forward: bool):
use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -79,8 +79,7 @@
)

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
from ..triton_kernels.norms_common import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton


__all__ = ["LayerNormLinear"]
5 changes: 2 additions & 3 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Collaborator

Copyright date

Contributor Author

Done, along with a few others I missed.

@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -83,8 +83,7 @@
from ...debug.pytorch.debug_state import TEDebugState

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
from ..triton_kernels.norms_common import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton

__all__ = ["LayerNormMLP"]

4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/layer_norm.py
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -17,7 +17,7 @@
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ...triton_kernels.norms_common import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/rmsnorm.py
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -17,7 +17,7 @@
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
from ...triton_kernels.rmsnorm import (
from ...triton_kernels.norms_common import (
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton
)