3 changes: 3 additions & 0 deletions tests/pytorch/distributed/run_numerics.py
@@ -215,6 +215,9 @@ def _get_tolerances(dtype):
    if QUANTIZATION == "fp8_cs":
        return {"rtol": 0.4, "atol": 0.25}
    elif QUANTIZATION == "nvfp4":
        if IS_HIP_EXTENSION:
Collaborator:
nit: move this NV upstream TODO to their tolerance return statement. Otherwise it looks like zhongboz added this IS_HIP_EXTENSION branch.

Contributor Author:
Thanks, moved in 6cbc4dc

            # Higher tolerance for AMDGPU to account for intermediate bf16 step in GEMM
            return {"rtol": 0.125, "atol": 0.15}
        # TODO(zhongboz): investigate why the tolerance is so large
        return {"rtol": 0.125, "atol": 0.12}
    elif QUANTIZATION is not None:
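An aside on the tolerance bump above: the AMDGPU path rounds an intermediate GEMM result to bf16, whose 8-bit significand bounds each rounding step's relative error at roughly 2**-9. A minimal sketch in plain PyTorch (nothing from this PR, just an illustration of the magnitude involved):

import torch

# bf16 keeps 8 significand bits, so one rounding step perturbs a value by at
# most ~2**-9 in relative terms; accumulated across a GEMM's K-dimension
# reduction, this is what motivates the wider ROCm tolerances above.
x = torch.randn(1 << 16, dtype=torch.float32)
roundtrip = x.to(torch.bfloat16).to(torch.float32)
rel_err = ((roundtrip - x).abs() / x.abs()).max()
print(f"max relative bf16 rounding error: {rel_err:.2e}")  # ~2e-3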
61 changes: 61 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -103,6 +103,46 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
        return tensor._transpose.device.index
    return torch.cuda.current_device()


if IS_HIP_EXTENSION:

    def _should_use_bf16_output_for_nvfp4_tn(
        A,
        B,
        layout: str,
        out_dtype: Optional[torch.dtype],
        out,
        bias,
        quantization_params,
        debug_quantizer,
        grad: bool,
        accumulate: bool,
        ub,
        extra_output,
        gelu: bool,
    ) -> bool:
        """Work around ROCm NVFP4 TN GEMM corruption when requesting FP32 output.

        FIXME: hipBLASLt BF16xBF16->FP32 GEMM algos with ALPHA_DEVICE_VECTOR
        produce incorrect results intermittently on AMDGPU. Return True for the
        narrow path where we force BF16 output, which empirically covers the
        corruption cases.
        """
        return (
            layout == "TN"
            and out_dtype == torch.float32
            and out is None
            and bias is not None
            and quantization_params is None
            and debug_quantizer is None
            and not grad
            and not accumulate
            and ub is None
            and extra_output is None
            and not gelu
            and (isinstance(A, NVFP4TensorStorage) or isinstance(B, NVFP4TensorStorage))
        )


    def _select_kernel_fp4(layout: str, grad: bool, M: int, N: int, K: int):
        """Select kernel via tuned CSV lookup, falling back to AITER heuristic."""
        from aiter.ops.gemm_op_a4w4 import get_GEMM_config
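The rest of this helper is collapsed in the diff. As a rough sketch of the lookup-then-fallback shape the docstring describes (the get_GEMM_config signature and the fallback value here are assumptions, not the collapsed body):

def _select_kernel_fp4_sketch(M: int, N: int, K: int):
    # Assumed pattern: consult aiter's tuned-config CSV for this GEMM shape,
    # and fall back to a heuristic choice when the shape is untuned.
    try:
        from aiter.ops.gemm_op_a4w4 import get_GEMM_config  # ROCm-only import
        cfg = get_GEMM_config(M, N, K)  # tuned entry for (M, N, K), if any
    except Exception:
        cfg = None
    if cfg is not None:
        return cfg  # pre-tuned kernel choice wins
    return "aiter-heuristic"  # placeholder for the AITER heuristic fallback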
@@ -371,6 +411,24 @@ def general_gemm(
        # FP8 block-scaling requires split accumulator
        use_split_accumulator = True

    if IS_HIP_EXTENSION:
        use_bf16_tn_output_workaround = _should_use_bf16_output_for_nvfp4_tn(
            A,
            B,
            layout,
            out_dtype,
            out,
            bias,
            quantization_params,
            debug_quantizer,
            grad,
            accumulate,
            ub,
            extra_output,
            gelu,
        )
        out_dtype = torch.bfloat16 if use_bf16_tn_output_workaround else out_dtype

    args = (
        A,
        transa,  # transa
@@ -400,6 +458,9 @@ def general_gemm(

    out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)

    if IS_HIP_EXTENSION and use_bf16_tn_output_workaround:
        out = cast_if_needed(out, torch.float32)

    if debug_quantizer is not None:
        out = debug_quantizer.process_gemm_output(out)

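For readers outside the TE codebase: the final upcast goes through cast_if_needed. A minimal sketch of its contract, assuming it mirrors the utility in transformer_engine/pytorch/utils.py (a no-op when the dtype already matches):

import torch

def cast_if_needed(tensor, dtype):
    # No-op if the tensor already has the target dtype; otherwise cast.
    # On the workaround path above this turns the BF16 GEMM output back into
    # the FP32 tensor the caller requested; the dtype contract is restored,
    # though the values still carry BF16 precision (hence the wider test
    # tolerances in run_numerics.py).
    return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)

Design-wise, upcasting after the GEMM keeps general_gemm's out_dtype contract intact while only the internal compute dtype changes on the gated NVFP4 TN path.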