Fix scatter_prod GPU hang on NaN with contention #3492

Open
tillahoffmann wants to merge 3 commits into ml-explore:main from
tillahoffmann:fix-scatter-prod-nan-hang

Conversation

@tillahoffmann

Summary

mx.array.at[idx].multiply(val) (i.e. scatter_prod on the Metal backend) hangs forever when two or more updates target the same output position and any of those values is NaN. The hang spins inside mlx_atomic_fetch_mul_explicit in mlx/backend/metal/kernels/atomic.h and is severe enough to wedge the macOS compositor, occasionally requiring a hard reboot.

Repro (hangs without this patch):

import mlx.core as mx
mx.set_default_device(mx.gpu)
x = mx.array([1.0, 1.0, 1.0, 1.0])
out = x.at[mx.array([0, 0])].multiply(mx.array([float("nan"), 2.0]))
mx.eval(out)  # never returns

Root cause

Metal's atomic_compare_exchange_weak on float atomics performs bitwise comparison. The IEEE 754 spec leaves NaN payloads implementation-defined for arithmetic operations: on Apple Silicon, val * expected can produce a NaN bit pattern that differs from expected even when expected is itself a NaN. The proposed-new value differs bit-for-bit from expected on every retry, the CAS keeps failing, and the loop never converges.
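The spin can be illustrated with a pure-Python toy model (hypothetical helper names, not the Metal code). It emulates a CAS whose success check can never accept a NaN `expected`, which is all it takes for the retry loop to stop converging once NaN reaches memory:

```python
def cas_float_equality(mem, expected, new):
    # Toy CAS whose success test is a plain float compare. Since
    # NaN == NaN is false, it can never succeed once expected is NaN.
    observed = mem[0]
    success = observed == expected
    if success:
        mem[0] = new
    return success, observed

def fetch_mul_spin(mem, val, max_iters=64):
    # Unguarded retry loop, shaped like mlx_atomic_fetch_mul_explicit.
    expected = mem[0]
    for i in range(max_iters):
        ok, expected = cas_float_equality(mem, expected, val * expected)
        if ok:
            return i + 1
    return max_iters  # hit the cap; the real kernel would spin forever

print(fetch_mul_spin([float("nan")], 2.0))  # 64: never converges with NaN in memory
```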

Fix

Short-circuit when NaN is in play. nan * anything == nan is well-defined the moment NaN appears, so the CAS retry is unnecessary:

  • Outer guard: if val is NaN, store val and return.
  • Inner guard: if a retry observes expected (the in-memory value) as NaN, store it and return.

Both guards are wrapped in if constexpr (metal::is_floating_point_v<T>) so they compile out for integer template instantiations. The non-native (packed uint32) overload was deliberately left untouched: empirical testing on float16/bfloat16 shows that path does not hang (the packed CAS converges because the comparison key is the freshly-loaded packed bits, not a recomputed NaN-multiply result).
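As a sketch of the guard logic (pure Python with hypothetical names; the real guards are Metal code under `if constexpr`), the short-circuits terminate the loop even against a CAS that can never succeed on NaN:

```python
import math

def cas_float_equality(mem, expected, new):
    # Toy CAS whose success test is a float compare, so it never succeeds
    # when expected is NaN (modeling the observed Metal behavior).
    observed = mem[0]
    success = observed == expected
    if success:
        mem[0] = new
    return success, observed

def fetch_mul_guarded(mem, val, max_iters=64):
    # Patched loop: short-circuit once NaN is in play, since
    # nan * anything == nan regardless of update order.
    if math.isnan(val):            # outer guard: result is NaN no matter what
        mem[0] = float("nan")
        return 0
    expected = mem[0]
    for i in range(max_iters):
        if math.isnan(expected):   # inner guard: memory is already NaN
            mem[0] = float("nan")
            return i
        ok, expected = cas_float_equality(mem, expected, val * expected)
        if ok:
            return i + 1
    return max_iters

print(fetch_mul_guarded([float("nan")], 2.0))  # 0: terminates instead of spinning
```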

Scope verified

Empirically swept the obvious axes on the unfixed build with a 15s timeout to find what hangs:

| op | float32 | float16 | bfloat16 |
| --- | --- | --- | --- |
| scatter_prod | hangs | passes | passes |
| scatter_add | passes | passes | passes |
| scatter_max | passes | passes | passes |
| scatter_min | passes | passes | passes |
| mx.prod reduction (uses same atomic) | passes | passes | passes |

Only native float32 scatter_prod was affected. The fix is targeted at exactly that path.

Test plan

  • New regression test "scatter_prod with NaN does not hang" in tests/gpu_tests.cpp. Two cases (NaN-as-update; NaN-already-in-memory), each empirically confirmed to hang on the unfixed kernel and to complete in <100ms with the fix. Each call is wrapped in std::async + wait_for(10s) so a future regression surfaces as a test failure instead of a wedged runner.
  • Full GPU test suite passes (247/247 cases, 3410/3410 assertions) with DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1.
  • Python repros from the bug report complete in <1s and return [nan, 1, 1, 1].
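A Python analogue of the std::async + wait_for pattern used in the test (hypothetical `run_with_timeout` helper; the actual test is C++). Note that a thread-based timeout only surfaces the hang as a failure; a truly wedged GPU call still pins the worker thread, so a process-based timeout is the more robust variant:

```python
import concurrent.futures
import time

def run_with_timeout(fn, timeout_s):
    # Run fn on a worker thread and fail fast if it exceeds timeout_s,
    # so a hang becomes a test failure instead of wedging the runner.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return pool.submit(fn).result(timeout=timeout_s)
    finally:
        # wait=False: do not block on a worker that may never return.
        pool.shutdown(wait=False)

print(run_with_timeout(lambda: 1 + 1, 1.0))  # 2
```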

Two or more updates targeting the same output position would spin
forever in mlx_atomic_fetch_mul_explicit when any of them was NaN.
Metal's atomic_compare_exchange_weak does bitwise comparison on the
typed-float atomic, and `val * expected` is free to choose a NaN
payload that differs from `expected`, so the CAS never converges.

Short-circuit on NaN (in `val` or in memory): the result is well
defined regardless, so just store NaN and return. Add a regression
test that wraps each scatter_prod call in std::async with a 10s
timeout so a future regression surfaces as a failure rather than
wedging the test runner.
Comment thread mlx/backend/metal/kernels/atomic.h Outdated
Per review feedback: the inner-loop isnan(expected) guard already covers
the NaN-input case on the first failed CAS iteration. Keep only the
load-bearing guard and tighten the comment to call out contention.
Use break instead of store-and-return inside the guard — memory is
already NaN, no further write needed.
@tillahoffmann
Author

Thank you for the review! Yes, a simple break does the trick.

@zcbenz
Collaborator

zcbenz commented May 8, 2026

Your test clearly demonstrated the bug, and the fix works, but I think I'm still very confused about the root cause.

So if object != expected, then atomic_compare_exchange_weak_explicit would load object into expected, and I can't think of a scenario that would cause an endless loop. Unless the Metal implementation is doing value comparison rather than bitwise comparison, in which case NaN != NaN would always be true. The CUDA backend implements the op the same way as the Metal backend and does not suffer from this bug.
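The expected semantics can be modeled in Python (hypothetical names): with a spec-conforming bitwise compare that reloads expected on failure, the loop converges in one or two iterations even when NaN is involved.

```python
import struct

def f32_bits(x):
    # float32 bit pattern as an unsigned int
    return struct.unpack("<I", struct.pack("<f", x))[0]

def f32_val(b):
    return struct.unpack("<f", struct.pack("<I", b))[0]

def cas_bitwise(mem_bits, expected_bits, new_bits):
    # Spec-conforming CAS: success decided by comparing bit patterns;
    # on failure the observed bits are returned so the caller reloads expected.
    if mem_bits[0] == expected_bits:
        mem_bits[0] = new_bits
        return True, new_bits
    return False, mem_bits[0]

def fetch_mul(mem_bits, v, max_iters=64):
    expected = mem_bits[0]
    for i in range(max_iters):
        ok, expected = cas_bitwise(mem_bits, expected, f32_bits(v * f32_val(expected)))
        if ok:
            return i + 1
    return max_iters

print(fetch_mul([f32_bits(float("nan"))], 2.0))  # 1: converges under bitwise compare
```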

@tillahoffmann
Author

That makes sense and is a better explanation than the NaN bit-pattern mismatch theory (which presupposed a bitwise, memcmp-style comparison). I've had Claude dig into this more. Here's the analysis.

Looking at the AIR (Apple's LLVM IR) for atomic_compare_exchange_weak_explicit:

%48 = call fast float @air.atomic.global.cmpxchg.weak.f32(...)
%49 = fcmp fast ueq float %42, %48

The atomic builtin returns the previous memory value as a float; success is then computed by fcmp fast ueq against expected. Metal kernels are compiled with no-nans-fp-math=true/unsafe-fp-math=true, so the ueq (unordered-or-equal) predicate is treated as ordered equality, which is false when either operand is NaN, even when the bit patterns are identical.

The atomic<uint> variant emits icmp eq i32 — pure integer compare, NaN bits compare equal because they're identical i32 values, CAS converges normally.

Repro below runs the same loop both ways. Float-CAS hits the 64-iter cap; uint-CAS converges in 2.

So the bug is: Metal's compilation of atomic<float>::compare_exchange violates the C++ spec's requirement of bitwise comparison, because it computes success via a NaN-unsafe fcmp after the atomic op. The break workaround sidesteps this. A more principled fix would be to do the CAS over atomic<uint> bits directly or request an upstream fix.
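The distinction is easy to demonstrate in Python, where `==` on floats is ordered equality just like the lowered fcmp: identical NaN bit patterns compare unequal as floats but equal as bytes.

```python
import math
import struct

nan_bits = 0x7FC00000  # a quiet-NaN float32 bit pattern
a = struct.unpack("<f", struct.pack("<I", nan_bits))[0]
b = struct.unpack("<f", struct.pack("<I", nan_bits))[0]

print(struct.pack("<f", a) == struct.pack("<f", b))  # True: bit patterns identical
print(a == b)                           # False: ordered equality rejects NaN == NaN
print(math.isnan(a) and math.isnan(b))  # True
```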

Repro script
"""Probe Apple GPU NaN handling and the CAS-multiply spin.

Test 1: Does `mul * NaN_X` preserve NaN_X's bits, or canonicalize?
Test 2: Faithful reproducer of the scatter_prod scenario — 2 threads colliding
        on a single slot seeded with 1.0, vals = NaN and 2.0. Run the same
        CAS-multiply loop as mlx_atomic_fetch_mul_explicit (no guard) with an
        iteration cap, dump per-iteration bit patterns.
"""

import math
import struct
import mlx.core as mx


def fmt(b):
    f = struct.unpack("<f", struct.pack("<I", b & 0xFFFFFFFF))[0]
    label = "NaN" if math.isnan(f) else f"{f:g}"
    return f"0x{b & 0xFFFFFFFF:08x} ({label})"


# ---- Test 1: NaN payload preservation -------------------------------------
print("=" * 72)
print("TEST 1: NaN payload through `mul * NaN`")
print("=" * 72)

nan_seeds = [0x7FC00000, 0x7FC12345, 0xFFABCDEF, 0x7F800001]
multipliers = [2.0, 0.5, -1.0, 1.0001]
N = len(nan_seeds) * len(multipliers)

src1 = f"""
    uint tid = thread_position_in_grid.x;
    if (tid >= {N}u) return;
    float n = as_type<float>(seeds[tid]);
    float r = mults[tid] * n;
    atomic_store_explicit(&out_bits[tid], as_type<uint>(r), memory_order_relaxed);
"""
k1 = mx.fast.metal_kernel(
    name="nan_mul",
    input_names=["seeds", "mults"],
    output_names=["out_bits"],
    source=src1,
    atomic_outputs=True,
)
seeds = mx.array([s for s in nan_seeds for _ in multipliers], dtype=mx.uint32)
muls = mx.array([m for _ in nan_seeds for m in multipliers], dtype=mx.float32)
(out,) = k1(
    inputs=[seeds, muls],
    grid=(N, 1, 1),
    threadgroup=(min(N, 32), 1, 1),
    output_shapes=[(N,)],
    output_dtypes=[mx.uint32],
    stream=mx.gpu,
)
mx.eval(out)
op = out.tolist()
sp = seeds.tolist()
mp = muls.tolist()
for i in range(N):
    same = sp[i] == op[i]
    print(
        f"  in={fmt(sp[i])}  *  {mp[i]:<8.4g}  =  {fmt(op[i])}  "
        f"{'preserved' if same else 'CANONICALIZED'}"
    )

# ---- Test 2: 2-thread CAS spin on NaN-seeded slot -------------------------
print()
print("=" * 72)
print("TEST 2: 2 threads, mem seeded 1.0, vals = NaN and 2.0, no NaN guard")
print("=" * 72)

NTH = 2
MAX_ITERS = 64
LOG_FIELDS = 4  # exp_before, new, cas_result, exp_after

# We need mem seeded before the CAS loop. atomic_outputs=True zero-inits the
# buffer, so thread 0 stores the 1.0f bits in-kernel. Because both threads also
# read mem at the start, we must barrier after the seed write.

src2 = f"""
    uint tid = thread_position_in_grid.y;
    if (tid >= {NTH}u) return;

    // INTEGER CAS variant: cast the float buffer to uint, do CAS on bit
    // patterns. Should NOT exhibit the bug if it's an FP-equality issue.
    if (tid == 0u) {{
        atomic_store_explicit(&mem[0], as_type<uint>(1.0f), memory_order_relaxed);
    }}
    threadgroup_barrier(mem_flags::mem_device);

    float val = vals[tid];
    uint base = tid * {MAX_ITERS}u * {LOG_FIELDS}u;
    uint expected_bits = atomic_load_explicit(&mem[0], memory_order_relaxed);
    float expected = as_type<float>(expected_bits);

    uint i = 0u;
    bool ok = false;
    for (; i < {MAX_ITERS}u; ++i) {{
        uint row = base + i * {LOG_FIELDS}u;
        float new_value = val * expected;
        atomic_store_explicit(&log[row + 0u], expected_bits, memory_order_relaxed);
        atomic_store_explicit(&log[row + 1u], as_type<uint>(new_value), memory_order_relaxed);
        bool success = atomic_compare_exchange_weak_explicit(
            &mem[0],
            &expected_bits,
            as_type<uint>(new_value),
            memory_order_relaxed, memory_order_relaxed);
        expected = as_type<float>(expected_bits);
        // Encode the bool two different ways to detect bool weirdness.
        uint succ_a = success ? 0xAAAAAAAAu : 0xBBBBBBBBu;
        uint succ_b = (uint)success;  // 0 or 1 if standard bool
        atomic_store_explicit(&log[row + 2u], succ_a ^ succ_b, memory_order_relaxed);
        atomic_store_explicit(&log[row + 3u], as_type<uint>(expected), memory_order_relaxed);
        if (success) {{ ok = true; break; }}
    }}
    // iters = number of iterations actually executed (1-based count)
    atomic_store_explicit(&iters[tid], ok ? (i + 1u) : i, memory_order_relaxed);
"""

k2 = mx.fast.metal_kernel(
    name="nan_cas_two_thread",
    input_names=["vals"],
    output_names=["mem", "log", "iters"],
    source=src2,
    atomic_outputs=True,
)

vals = mx.array([float("nan"), 2.0], dtype=mx.float32)
mem, log, iters = k2(
    inputs=[vals],
    grid=(1, NTH, 1),  # match scatter: grid_dims = (upd_size=1, grid_y=2, 1)
    threadgroup=(1, NTH, 1),
    output_shapes=[(1,), (NTH * MAX_ITERS * LOG_FIELDS,), (NTH,)],
    output_dtypes=[mx.uint32, mx.uint32, mx.uint32],
    stream=mx.gpu,
)
mx.eval(mem, log, iters)

mp = mem.tolist()
lp = log.tolist()
ip = iters.tolist()

print(f"final mem: {fmt(mp[0])}")
print(f"iters: {ip}  (cap = {MAX_ITERS}, hit cap = spin)")
for tid in range(NTH):
    print(f"\nthread {tid} trace ({ip[tid]} iters):")
    base = tid * MAX_ITERS * LOG_FIELDS
    n = min(ip[tid], 16)
    for i in range(n):
        row = base + i * LOG_FIELDS
        eb, nv, casbits, ea = lp[row], lp[row + 1], lp[row + 2], lp[row + 3]
        print(
            f"  i={i:2d}  exp_before={fmt(eb)}  new={fmt(nv)}  "
            f"cas_xor=0x{casbits:08x}  exp_after={fmt(ea)}"
        )
    if ip[tid] > n:
        print(f"  ... ({ip[tid] - n} more)")

@zcbenz
Collaborator

zcbenz commented May 9, 2026

Can you update the comments in the code and test to clarify that it is working around the Metal implementation's failure to do bitwise comparison? I'm still checking with the Metal team about the details, and I think this fix is good to go.

Per review feedback: the loop spins because Metal's
atomic_compare_exchange_weak_explicit<float> does not perform bitwise
comparison as the C++ atomics spec requires. The compiler lowers the
success check to fcmp fast ueq under no-nans-fp-math, which is false
when either operand is NaN even when bit patterns match. Reword the
comment to explain this directly, and trim the test comments to point
at the single explanation in atomic.h.
@tillahoffmann
Author

I've updated the comments. Reproduction artifacts for completeness below; the .metal file can be compiled to AIR with xcrun metal -S -O0 /tmp/cas_float.metal -o /tmp/cas_float.air.

Thank you for looking into this!

.metal file illustrating non-convergence of the loop
#include <metal_stdlib>
#include <metal_atomic>
using namespace metal;

kernel void cas_float(
    device atomic<float>* mem [[buffer(0)]],
    device float* val_in [[buffer(1)]],
    device uint* out [[buffer(2)]],
    uint tid [[thread_position_in_grid]]) {
  float val = val_in[0];
  float expected = atomic_load_explicit(mem, memory_order_relaxed);
  uint i = 0;
  for (; i < 64u; ++i) {
    float new_value = val * expected;
    bool success = atomic_compare_exchange_weak_explicit(
        mem, &expected, new_value,
        memory_order_relaxed, memory_order_relaxed);
    if (success) break;
  }
  out[tid] = i;
}
Corresponding .air file with `===` emphasis
; ModuleID = '/tmp/cas_float.metal'
source_filename = "/tmp/cas_float.metal"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32"
target triple = "air64_v28-apple-macosx26.0.0"

%"struct.metal::_atomic" = type { float }

; Function Attrs: convergent mustprogress noinline nounwind optnone
define void @cas_float(%"struct.metal::_atomic" addrspace(1)* noundef "air-buffer-no-alias" %0, float addrspace(1)* noundef "air-buffer-no-alias" %1, i32 addrspace(1)* noundef "air-buffer-no-alias" %2, i32 noundef %3) #0 {
  %5 = alloca %"struct.metal::_atomic" addrspace(1)*, align 8
  %6 = alloca float*, align 8
  %7 = alloca float, align 4
  %8 = alloca i32, align 4
  %9 = alloca i32, align 4
  %10 = alloca float, align 4
  %11 = alloca i8, align 1
  %12 = alloca %"struct.metal::_atomic" addrspace(1)*, align 8
  %13 = alloca i32, align 4
  %14 = alloca %"struct.metal::_atomic" addrspace(1)*, align 8
  %15 = alloca float addrspace(1)*, align 8
  %16 = alloca i32 addrspace(1)*, align 8
  %17 = alloca i32, align 4
  %18 = alloca float, align 4
  %19 = alloca float, align 4
  %20 = alloca i32, align 4
  %21 = alloca float, align 4
  %22 = alloca i8, align 1
  store %"struct.metal::_atomic" addrspace(1)* %0, %"struct.metal::_atomic" addrspace(1)** %14, align 8
  store float addrspace(1)* %1, float addrspace(1)** %15, align 8
  store i32 addrspace(1)* %2, i32 addrspace(1)** %16, align 8
  store i32 %3, i32* %17, align 4
  %23 = load float addrspace(1)*, float addrspace(1)** %15, align 8
  %24 = getelementptr inbounds float, float addrspace(1)* %23, i64 0
  %25 = load float, float addrspace(1)* %24, align 4
  store float %25, float* %18, align 4
  %26 = load %"struct.metal::_atomic" addrspace(1)*, %"struct.metal::_atomic" addrspace(1)** %14, align 8
  store %"struct.metal::_atomic" addrspace(1)* %26, %"struct.metal::_atomic" addrspace(1)** %12, align 8
  store i32 0, i32* %13, align 4
  %27 = load %"struct.metal::_atomic" addrspace(1)*, %"struct.metal::_atomic" addrspace(1)** %12, align 8
  %28 = getelementptr inbounds %"struct.metal::_atomic", %"struct.metal::_atomic" addrspace(1)* %27, i32 0, i32 0
  %29 = load i32, i32* %13, align 4
  %30 = call fast float @air.atomic.global.load.f32(float addrspace(1)* nocapture %28, i32 %29, i32 2, i1 true) #1
  store float %30, float* %19, align 4
  store i32 0, i32* %20, align 4
  br label %31

31:                                               ; preds = %60, %4
  %32 = load i32, i32* %20, align 4
  %33 = icmp ult i32 %32, 64
  br i1 %33, label %34, label %63

34:                                               ; preds = %31
  %35 = load float, float* %18, align 4
  %36 = load float, float* %19, align 4
  %37 = fmul fast float %35, %36
  store float %37, float* %21, align 4
  %38 = load %"struct.metal::_atomic" addrspace(1)*, %"struct.metal::_atomic" addrspace(1)** %14, align 8
  %39 = load float, float* %21, align 4
  store %"struct.metal::_atomic" addrspace(1)* %38, %"struct.metal::_atomic" addrspace(1)** %5, align 8
  store float* %19, float** %6, align 8
  store float %39, float* %7, align 4
  store i32 0, i32* %8, align 4
  store i32 0, i32* %9, align 4
  %40 = load float*, float** %6, align 8
  %41 = load float, float* %40, align 4
  store float %41, float* %10, align 4
  %42 = load float, float* %10, align 4
  %43 = load %"struct.metal::_atomic" addrspace(1)*, %"struct.metal::_atomic" addrspace(1)** %5, align 8
  %44 = getelementptr inbounds %"struct.metal::_atomic", %"struct.metal::_atomic" addrspace(1)* %43, i32 0, i32 0
  %45 = load float, float* %7, align 4
  %46 = load i32, i32* %8, align 4
  %47 = load i32, i32* %9, align 4
================================================================================
  %48 = call fast float @air.atomic.global.cmpxchg.weak.f32(float addrspace(1)* nocapture %44, float* nocapture %10, float %45, i32 %46, i32 %47, i32 2, i1 true) #1
  %49 = fcmp fast ueq float %42, %48
================================================================================
  %50 = zext i1 %49 to i8
  store i8 %50, i8* %11, align 1
  %51 = load float, float* %10, align 4
  %52 = load float*, float** %6, align 8
  store float %51, float* %52, align 4
  %53 = load i8, i8* %11, align 1
  %54 = trunc i8 %53 to i1
  %55 = zext i1 %54 to i8
  store i8 %55, i8* %22, align 1
  %56 = load i8, i8* %22, align 1
  %57 = trunc i8 %56 to i1
  br i1 %57, label %58, label %59

58:                                               ; preds = %34
  br label %63

59:                                               ; preds = %34
  br label %60

60:                                               ; preds = %59
  %61 = load i32, i32* %20, align 4
  %62 = add i32 %61, 1
  store i32 %62, i32* %20, align 4
  br label %31, !llvm.loop !24

63:                                               ; preds = %58, %31
  %64 = load i32, i32* %20, align 4
  %65 = load i32 addrspace(1)*, i32 addrspace(1)** %16, align 8
  %66 = load i32, i32* %17, align 4
  %67 = zext i32 %66 to i64
  %68 = getelementptr inbounds i32, i32 addrspace(1)* %65, i64 %67
  store i32 %64, i32 addrspace(1)* %68, align 4
  ret void
}

; Function Attrs: nounwind willreturn
declare float @air.atomic.global.load.f32(float addrspace(1)* nocapture, i32, i32, i1) #1

; Function Attrs: nounwind willreturn
declare float @air.atomic.global.cmpxchg.weak.f32(float addrspace(1)* nocapture, float* nocapture, float, i32, i32, i32, i1) #1

attributes #0 = { convergent mustprogress noinline nounwind optnone "approx-func-fp-math"="true" "frame-pointer"="all" "min-legal-vector-width"="0" "no-builtins" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="true" }
attributes #1 = { nounwind willreturn }

!llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7, !8}
!air.kernel = !{!9}
!air.compile_options = !{!17, !18, !19}
!llvm.ident = !{!20}
!air.version = !{!21}
!air.language_version = !{!22}
!air.source_file_name = !{!23}

!0 = !{i32 2, !"SDK Version", [2 x i32] [i32 26, i32 4]}
!1 = !{i32 1, !"wchar_size", i32 4}
!2 = !{i32 7, !"frame-pointer", i32 2}
!3 = !{i32 7, !"air.max_device_buffers", i32 31}
!4 = !{i32 7, !"air.max_constant_buffers", i32 31}
!5 = !{i32 7, !"air.max_threadgroup_buffers", i32 31}
!6 = !{i32 7, !"air.max_textures", i32 128}
!7 = !{i32 7, !"air.max_read_write_textures", i32 8}
!8 = !{i32 7, !"air.max_samplers", i32 16}
!9 = !{void (%"struct.metal::_atomic" addrspace(1)*, float addrspace(1)*, i32 addrspace(1)*, i32)* @cas_float, !10, !11}
!10 = !{}
!11 = !{!12, !14, !15, !16}
!12 = !{i32 0, !"air.buffer", !"air.location_index", i32 0, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.struct_type_info", !13, !"air.arg_type_size", i32 4, !"air.arg_type_align_size", i32 4, !"air.arg_type_name", !"metal::_atomic", !"air.arg_name", !"mem"}
!13 = !{i32 0, i32 4, i32 0, !"float", !"__s"}
!14 = !{i32 1, !"air.buffer", !"air.location_index", i32 1, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_size", i32 4, !"air.arg_type_align_size", i32 4, !"air.arg_type_name", !"float", !"air.arg_name", !"val_in"}
!15 = !{i32 2, !"air.buffer", !"air.location_index", i32 2, i32 1, !"air.read_write", !"air.address_space", i32 1, !"air.arg_type_size", i32 4, !"air.arg_type_align_size", i32 4, !"air.arg_type_name", !"uint", !"air.arg_name", !"out"}
!16 = !{i32 3, !"air.thread_position_in_grid", !"air.arg_type_name", !"uint", !"air.arg_name", !"tid"}
!17 = !{!"air.compile.denorms_disable"}
!18 = !{!"air.compile.fast_math_enable"}
!19 = !{!"air.compile.framebuffer_fetch_enable"}
!20 = !{!"Apple metal version 32023.830 (metalfe-32023.830.2)"}
!21 = !{i32 2, i32 8, i32 0}
!22 = !{!"Metal", i32 4, i32 0, i32 0}
!23 = !{!"/private/tmp/cas_float.metal"}
!24 = distinct !{!24, !25}
!25 = !{!"llvm.loop.mustprogress"}

Collaborator

@zcbenz zcbenz left a comment


Thanks!
