
NVFP4: Work around intermittent incorrect results for backward GEMMs #580

Open
matthiasdiener wants to merge 5 commits into dev from mdiener/nvfp4-backward-workaround

Conversation

@matthiasdiener
Contributor

@matthiasdiener matthiasdiener commented May 7, 2026

Description

Main failing test: torchrun --nproc_per_node=4 /dockerx/TransformerEngine/tests/pytorch/distributed/run_numerics.py --quantization nvfp4

Observed on gfx942 and gfx950.

Smaller reproducer
import sys
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling, QParams

def nvfp4_vanilla():
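    # Plain ("vanilla") NVFP4: default-constructed QParams for the input,
    # weight, and grad quantizers, i.e. no optional quantization features.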
    r = NVFP4BlockScaling()
    r.fp4_quant_fwd_inp = QParams()
    r.fp4_quant_fwd_weight = QParams()
    r.fp4_quant_bwd_grad = QParams()
    return r

def check_output(name, out):
    finite_mask = torch.isfinite(out)
    has_nonfinite = not finite_mask.all().item()
    if finite_mask.any().item():
        absmax = out[finite_mask].abs().max().item()
    else:
        absmax = float("nan")
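    # Heuristic: with randn inputs and H=128, healthy outputs stay well below
    # 1000, so larger magnitudes (or any non-finite values) indicate corruption.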
    garbage = has_nonfinite or absmax > 1000
    info = f"  {name}: absmax={absmax:.6g}"
    if garbage:
        mask = (~finite_mask) | (out.abs() > 1000)
        cols = mask.any(dim=0)
        first = cols.nonzero()[0].item() if cols.any() else -1
        last = cols.nonzero()[-1].item() if cols.any() else -1
        info += f" GARBAGE cols=[{first},{last}]"
        if has_nonfinite:
            info += f" nonfinite={(~finite_mask).sum().item()}"
    else:
        info += " OK"
    print(info)
    return garbage


def run_test(params_dtype, label):
    print(f"\n{'='*60}")
    print(f"Test: {label} (params_dtype={params_dtype})")
    print(f"{'='*60}")

    device = torch.device("cuda:0")
    H = 128
    B = 32
    recipe = nvfp4_vanilla()

    # Create model used in "Test 4" (forward+backward)
    model_fwdbwd = te.Linear(H, H, params_dtype=params_dtype).to(device)

    # Step 1: Forward only (should pass)
    inp1 = torch.randn((B, H), device=device, dtype=params_dtype)
    with te.autocast(enabled=True, recipe=recipe):
        out1 = model_fwdbwd(inp1)
    torch.cuda.synchronize()
    g1 = check_output("Step1 fwd-only", out1)

    # Step 2: Forward + Backward (this is what triggers the bug)
    inp2 = torch.randn((B, H), device=device, dtype=params_dtype, requires_grad=True)
    with te.autocast(enabled=True, recipe=recipe):
        out2 = model_fwdbwd(inp2)
    loss = out2.sum()
    loss.backward()
    torch.cuda.synchronize()
    g2 = check_output("Step2 fwd+bwd (fwd out)", out2)

    # Step 3: Forward again with a NEW model (this is what fails)
    model_new = te.Linear(H, H, params_dtype=params_dtype).to(device)
    inp3 = torch.randn((B, H), device=device, dtype=params_dtype)
    with te.autocast(enabled=True, recipe=recipe):
        out3 = model_new(inp3)
    torch.cuda.synchronize()
    g3 = check_output("Step3 new-model fwd", out3)

    # Step 4: Forward with a SMALL model, then the original model again
    model_small = te.Linear(H, H // 4, params_dtype=params_dtype).to(device)
    inp4a = torch.randn((B, H), device=device, dtype=params_dtype)
    inp4b = torch.randn((B, H), device=device, dtype=params_dtype)
    with te.autocast(enabled=True, recipe=recipe):
        out4s = model_small(inp4a)  # small model first
        out4a = model_new(inp4b)  # big model again
    torch.cuda.synchronize()
    g4s = check_output("Step4 small-model fwd", out4s)
    g4a = check_output("Step4 big-model fwd (post-backward)", out4a)

    # Step 5: Run small model then big model (simulates distributed test)
    model_big2 = te.Linear(H, H, params_dtype=params_dtype).to(device)
    model_sm2 = te.Linear(H, H // 4, params_dtype=params_dtype).to(device)
    inp5 = torch.randn((B, H), device=device, dtype=params_dtype)
    with te.autocast(enabled=True, recipe=recipe):
        out5s = model_sm2(inp5)  # small model first
        out5 = model_big2(inp5)
    torch.cuda.synchronize()
    g5s = check_output("Step5 small-model fwd", out5s)
    g5 = check_output("Step5 fresh-big-model fwd", out5)

    any_fail = g1 or g2 or g3 or g4s or g4a or g5s or g5
    print(f"  RESULT: {'FAIL' if any_fail else 'PASS'}")
    return any_fail


def run_test_no_backward(params_dtype, label):
    """Same as run_test but skip the backward to confirm it's the trigger."""
    print(f"\n{'='*60}")
    print(f"Test: {label} (no backward, params_dtype={params_dtype})")
    print(f"{'='*60}")

    device = torch.device("cuda:0")
    H = 128
    B = 32
    recipe = nvfp4_vanilla()

    model = te.Linear(H, H, params_dtype=params_dtype).to(device)

    # Three forward passes (no backward)
    for i in range(3):
        inp = torch.randn((B, H), device=device, dtype=params_dtype)
        with te.autocast(enabled=True, recipe=recipe):
            out = model(inp)
        torch.cuda.synchronize()
        g = check_output(f"Fwd-only pass {i+1}", out)
        if g:
            print(f"  RESULT: FAIL")
            return True

    # Now a new model
    model2 = te.Linear(H, H, params_dtype=params_dtype).to(device)
    inp = torch.randn((B, H), device=device, dtype=params_dtype)
    with te.autocast(enabled=True, recipe=recipe):
        out = model2(inp)
    torch.cuda.synchronize()
    g = check_output("New model after 3 fwds", out)
    print(f"  RESULT: {'FAIL' if g else 'PASS'}")
    return g


def main():
    print(f"Device: {torch.cuda.get_device_name(0)}")

    results = {}

    # Test 1: FP32 output, no backward (should pass)
    results['fp32_no_bwd'] = run_test_no_backward(torch.float32, "FP32 no backward")

    # Test 2: FP32 output, with backward (expected to fail)
    results['fp32_with_bwd'] = run_test(torch.float32, "FP32 with backward")

    # Test 3: BF16 output, with backward (should pass)
    results['bf16_with_bwd'] = run_test(torch.bfloat16, "BF16 with backward")

    # Test 4: BF16 output, no backward (should pass)
    results['bf16_no_bwd'] = run_test_no_backward(torch.bfloat16, "BF16 no backward")

    print(f"\n{'='*60}")
    print("SUMMARY:")
    for k, v in results.items():
        print(f"  {k}: {'FAIL' if v else 'PASS'}")

    any_fail = any(results.values())
    print(f"\nOverall: {'FAIL' if any_fail else 'PASS'}")
    return 1 if any_fail else 0


if __name__ == "__main__":
    sys.exit(main())
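The reproducer needs only a single GPU and exits nonzero if any sub-test fails, so it can be scripted directly (the filename here is arbitrary):

python nvfp4_backward_repro.py || echo "bug reproduced"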

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@matthiasdiener matthiasdiener self-assigned this May 7, 2026
@matthiasdiener matthiasdiener added the ci-level 3 CI test level 3 label May 7, 2026
# FIXME: hipBLASLt BF16xBF16->FP32 GEMM algos with ALPHA_DEVICE_VECTOR produce
# incorrect results intermittently on AMDGPU. Skip backward-containing sub-tests for
# nvfp4.
if IS_HIP_EXTENSION and QUANTIZATION == "nvfp4":
Collaborator

Is there a ticket for that?
nit: keep the original test_dict and add the conditional override after it, so the workaround is easier to remove in the future and causes fewer merge conflicts
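A minimal sketch of that structure (the entry names are hypothetical; only the override pattern is the point):

# Upstream test list stays byte-identical, easing future merges.
test_dict = [test_fwd_only, test_fwd_bwd, test_grad_accumulation]

# FIXME: hipBLASLt BF16xBF16->FP32 GEMM algos with ALPHA_DEVICE_VECTOR produce
# incorrect results intermittently on AMDGPU. Skip backward-containing sub-tests
# for nvfp4; delete this block once hipBLASLt gains native NVFP4 GEMM support.
if IS_HIP_EXTENSION and QUANTIZATION == "nvfp4":
    test_dict = [t for t in test_dict if t is test_fwd_only]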

Collaborator

Also, should the GPU family be checked here?

Contributor Author

Is there a ticket for that?

I did not create one; I found it impossible to reproduce this issue without TE. Since the current implementation with the BF16 dequant+GEMM is only a stop-gap until we have NVFP4 support in hipBLASLt, it may not be necessary to create a ticket.

nit: keep original test_dict and add conditional override after that so it will be easier to remove the w/a in the future and less merging conflict

The workaround has been restructured in b609614.

Also, should gpu family be checked here?

No, I observed the same issue on both gfx942 and gfx950.

@matthiasdiener matthiasdiener marked this pull request as ready for review May 8, 2026 20:52
@matthiasdiener matthiasdiener requested a review from ipanfilo May 8, 2026 20:52
@matthiasdiener matthiasdiener changed the title NVFP4: Workaround intermittent incorrect results for backward GEMMs NVFP4: Work around intermittent incorrect results for backward GEMMs May 8, 2026
return {"rtol": 0.4, "atol": 0.25}
elif QUANTIZATION == "nvfp4":
# TODO(zhongboz): investigate why the tolerance is so large
if IS_HIP_EXTENSION:
Collaborator

nit: move this NV upstream TODO to their tolerance return statement; otherwise it looks like zhongboz added this IS_HIP_EXTENSION branch

Contributor Author

Thanks, moved in 6cbc4dc
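
After the move, the branch reads roughly as follows (tolerance values elided; see the PR diff for the actual numbers):

elif QUANTIZATION == "nvfp4":
    if IS_HIP_EXTENSION:
        return {"rtol": ..., "atol": ...}  # ROCm-specific tolerances from this PR
    # TODO(zhongboz): investigate why the tolerance is so large
    return {"rtol": ..., "atol": ...}  # upstream tolerances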

