Fix scatter_prod GPU hang on NaN with contention (#3492)

tillahoffmann wants to merge 3 commits into ml-explore:main
Conversation
Two or more updates targeting the same output position would spin forever in mlx_atomic_fetch_mul_explicit when any of them was NaN. Metal's atomic_compare_exchange_weak does bitwise comparison on the typed-float atomic, and `val * expected` is free to choose a NaN payload that differs from `expected`, so the CAS never converges. Short-circuit on NaN (in `val` or in memory): the result is well defined regardless, so just store NaN and return. Add a regression test that wraps each scatter_prod call in std::async with a 10s timeout so a future regression surfaces as a failure rather than wedging the test runner.
Per review feedback: the inner-loop isnan(expected) guard already covers the NaN-input case on the first failed CAS iteration. Keep only the load-bearing guard and tighten the comment to call out contention. Use break instead of store-and-return inside the guard — memory is already NaN, no further write needed.
|
Thank you for the review! Yes, a simple |
|
Your test clearly demonstrated the bug, and the fix works, but I think I'm still very confused about the root cause. So if |
|
That makes sense and is a better explanation than NaN bit pattern mismatch (which would explain the issue with memcmp). I've had Claude dig into this more. Here's the analysis.
Repro script

"""Probe Apple GPU NaN handling and the CAS-multiply spin.
Test 1: Does `mul * NaN_X` preserve NaN_X's bits, or canonicalize?
Test 2: Faithful reproducer of the scatter_prod scenario — 2 threads colliding
on a single slot pre-seeded with NaN, vals = 2.0 and 3.0. Run the same
CAS-multiply loop as mlx_atomic_fetch_mul_explicit (no guard) with an
iteration cap, dump per-iteration bit patterns.
"""
import math
import struct
import mlx.core as mx
def fmt(b):
    """Render a 32-bit value as its hex bit pattern plus the float it encodes."""
    bits = b & 0xFFFFFFFF
    (value,) = struct.unpack("<f", struct.pack("<I", bits))
    if math.isnan(value):
        label = "NaN"
    else:
        label = f"{value:g}"
    return f"0x{bits:08x} ({label})"
# ---- Test 1: NaN payload preservation -------------------------------------
# Multiply several NaN bit patterns by ordinary floats on the GPU and check
# whether the product preserves the input NaN's exact bits or canonicalizes.
print("=" * 72)
print("TEST 1: NaN payload through `mul * NaN`")
print("=" * 72)

# Quiet NaN, two payload-carrying NaNs (one with the sign bit set), and a
# signaling-NaN pattern (0x7F800001).
nan_seeds = [0x7FC00000, 0x7FC12345, 0xFFABCDEF, 0x7F800001]
multipliers = [2.0, 0.5, -1.0, 1.0001]
N = len(nan_seeds) * len(multipliers)  # one GPU thread per (seed, multiplier) pair

# Kernel: reinterpret the seed bits as float, multiply, store the product's
# raw bits so the host can compare them against the input bits.
src1 = f"""
uint tid = thread_position_in_grid.x;
if (tid >= {N}u) return;
float n = as_type<float>(seeds[tid]);
float r = mults[tid] * n;
atomic_store_explicit(&out_bits[tid], as_type<uint>(r), memory_order_relaxed);
"""
k1 = mx.fast.metal_kernel(
    name="nan_mul",
    input_names=["seeds", "mults"],
    output_names=["out_bits"],
    source=src1,
    atomic_outputs=True,  # out_bits is declared as an atomic buffer by mlx
)
# seeds repeats each NaN once per multiplier; muls cycles the multipliers,
# so flat index i pairs nan_seeds[i // len(multipliers)] with multipliers[i % len(multipliers)].
seeds = mx.array([s for s in nan_seeds for _ in multipliers], dtype=mx.uint32)
muls = mx.array([m for _ in nan_seeds for m in multipliers], dtype=mx.float32)
(out,) = k1(
    inputs=[seeds, muls],
    grid=(N, 1, 1),
    threadgroup=(min(N, 32), 1, 1),
    output_shapes=[(N,)],
    output_dtypes=[mx.uint32],
    stream=mx.gpu,  # must run on the GPU: the behavior under study is Metal-specific
)
mx.eval(out)
op = out.tolist()
sp = seeds.tolist()
mp = muls.tolist()
for i in range(N):
    # "preserved" means the product kept the exact input NaN bit pattern.
    same = sp[i] == op[i]
    print(
        f" in={fmt(sp[i])} * {mp[i]:<8.4g} = {fmt(op[i])} "
        f"{'preserved' if same else 'CANONICALIZED'}"
    )
# ---- Test 2: 2-thread CAS spin on NaN-seeded slot -------------------------
print()
print("=" * 72)
print("TEST 2: 2 threads, mem seeded NaN, vals = 2.0 and 3.0, no NaN guard")
print("=" * 72)

NTH = 2  # two colliding threads, matching the minimal scatter_prod repro
MAX_ITERS = 64  # iteration cap so a spin terminates instead of hanging the GPU
LOG_FIELDS = 4  # exp_before, new, cas_result, exp_after

# We need mem seeded with NaN bits before the CAS loop. atomic_outputs=True
# zero-inits the buffer; we seed in-kernel. But because both threads also read
# mem at the start, we must barrier after the seed write.
# NOTE(review): despite the comment and banner, the kernel below seeds mem
# with as_type<uint>(1.0f), not NaN — the NaN enters via vals[0] on the host
# side. Confirm which variant this trace was captured with.
src2 = f"""
uint tid = thread_position_in_grid.y;
if (tid >= {NTH}u) return;
// INTEGER CAS variant: cast the float buffer to uint, do CAS on bit
// patterns. Should NOT exhibit the bug if it's an FP-equality issue.
if (tid == 0u) {{
atomic_store_explicit(&mem[0], as_type<uint>(1.0f), memory_order_relaxed);
}}
threadgroup_barrier(mem_flags::mem_device);
float val = vals[tid];
uint base = tid * {MAX_ITERS}u * {LOG_FIELDS}u;
uint expected_bits = atomic_load_explicit(&mem[0], memory_order_relaxed);
float expected = as_type<float>(expected_bits);
uint i = 0u;
bool ok = false;
for (; i < {MAX_ITERS}u; ++i) {{
uint row = base + i * {LOG_FIELDS}u;
float new_value = val * expected;
atomic_store_explicit(&log[row + 0u], expected_bits, memory_order_relaxed);
atomic_store_explicit(&log[row + 1u], as_type<uint>(new_value), memory_order_relaxed);
bool success = atomic_compare_exchange_weak_explicit(
&mem[0],
&expected_bits,
as_type<uint>(new_value),
memory_order_relaxed, memory_order_relaxed);
expected = as_type<float>(expected_bits);
// Encode the bool two different ways to detect bool weirdness.
uint succ_a = success ? 0xAAAAAAAAu : 0xBBBBBBBBu;
uint succ_b = (uint)success; // 0 or 1 if standard bool
atomic_store_explicit(&log[row + 2u], succ_a ^ succ_b, memory_order_relaxed);
atomic_store_explicit(&log[row + 3u], as_type<uint>(expected), memory_order_relaxed);
if (success) {{ ok = true; break; }}
}}
// iters = number of iterations actually executed (1-based count)
atomic_store_explicit(&iters[tid], ok ? (i + 1u) : i, memory_order_relaxed);
"""
k2 = mx.fast.metal_kernel(
    name="nan_cas_two_thread",
    input_names=["vals"],
    output_names=["mem", "log", "iters"],
    source=src2,
    atomic_outputs=True,
)
# NOTE(review): banner says "vals = 2.0 and 3.0", but the actual inputs here
# are [nan, 2.0] — thread 0 carries the NaN update.
vals = mx.array([float("nan"), 2.0], dtype=mx.float32)
mem, log, iters = k2(
    inputs=[vals],
    grid=(1, NTH, 1),  # match scatter: grid_dims = (upd_size=1, grid_y=2, 1)
    threadgroup=(1, NTH, 1),
    output_shapes=[(1,), (NTH * MAX_ITERS * LOG_FIELDS,), (NTH,)],
    output_dtypes=[mx.uint32, mx.uint32, mx.uint32],
    stream=mx.gpu,
)
mx.eval(mem, log, iters)
mp = mem.tolist()
lp = log.tolist()
ip = iters.tolist()
print(f"final mem: {fmt(mp[0])}")
# A thread reporting iters == MAX_ITERS never saw its CAS succeed: the spin.
print(f"iters: {ip} (cap = {MAX_ITERS}, hit cap = spin)")
for tid in range(NTH):
    print(f"\nthread {tid} trace ({ip[tid]} iters):")
    base = tid * MAX_ITERS * LOG_FIELDS
    n = min(ip[tid], 16)  # cap the per-thread dump at 16 logged iterations
    for i in range(n):
        row = base + i * LOG_FIELDS
        eb, nv, casbits, ea = lp[row], lp[row + 1], lp[row + 2], lp[row + 3]
        print(
            f" i={i:2d} exp_before={fmt(eb)} new={fmt(nv)} "
            f"cas_xor=0x{casbits:08x} exp_after={fmt(ea)}"
        )
    if ip[tid] > n:
        print(f" ... ({ip[tid] - n} more)")
|
Can you update the comments of the code and the test to clarify that this works around the Metal implementation's behavior of not doing a bitwise comparison? I'm still checking with the Metal team about the details, and I think this fix is good to go.
Per review feedback: the loop spins because Metal's atomic_compare_exchange_weak_explicit<float> does not perform bitwise comparison as the C++ atomics spec requires. The compiler lowers the success check to fcmp fast ueq under no-nans-fp-math, which is false when either operand is NaN even when bit patterns match. Reword the comment to explain this directly, and trim the test comments to point at the single explanation in atomic.h.
|
I've updated the comments. Reproduction scripts for completeness below. Thank you for looking into this!

.metal file illustrating non-convergence of the loop:

#include <metal_stdlib>
#include <metal_atomic>
using namespace metal;
// Minimal standalone reproducer: the same CAS-multiply retry loop as
// mlx_atomic_fetch_mul_explicit, with a 64-iteration cap so a spin
// terminates and is observable via the reported iteration count.
kernel void cas_float(
    device atomic<float>* mem [[buffer(0)]],
    device float* val_in [[buffer(1)]],
    device uint* out [[buffer(2)]],
    uint tid [[thread_position_in_grid]]) {
  float val = val_in[0];
  // Snapshot the current value; a failed CAS refreshes `expected` in place.
  float expected = atomic_load_explicit(mem, memory_order_relaxed);
  uint i = 0;
  for (; i < 64u; ++i) {
    float new_value = val * expected;
    bool success = atomic_compare_exchange_weak_explicit(
        mem, &expected, new_value,
        memory_order_relaxed, memory_order_relaxed);
    if (success) break;
  }
  // i == 64 means the loop hit the cap without a successful CAS (the spin).
  out[tid] = i;
}

Corresponding .air file with `===` emphasis
Summary
`mx.array.at[idx].multiply(val)` (i.e. `scatter_prod` on the Metal backend) hangs forever when two or more updates target the same output position and any of those values is NaN. The hang spins inside `mlx_atomic_fetch_mul_explicit` in `mlx/backend/metal/kernels/atomic.h` and is severe enough to wedge the macOS compositor, occasionally requiring a hard reboot.

Repro (hangs without this patch):
Root cause
Metal's `atomic_compare_exchange_weak` on float atomics performs bitwise comparison. The IEEE 754 spec leaves NaN payloads implementation-defined for arithmetic operations: on Apple Silicon, `val * expected` can produce a NaN bit pattern that differs from `expected` even when `expected` is itself a NaN. The proposed new value differs bit-for-bit from `expected` on every retry, the CAS keeps failing, and the loop never converges.

Fix
Short-circuit when NaN is in play. `nan * anything == nan` is well-defined the moment NaN appears, so the CAS retry is unnecessary:

- If `val` is NaN, store `val` and return.
- If we observe `expected` (the in-memory value) as NaN, store it and return.

Both guards are wrapped in
`if constexpr (metal::is_floating_point_v<T>)` so they compile out for integer template instantiations. The non-native (packed `uint32`) overload was deliberately left untouched: empirical testing on `float16`/`bfloat16` shows that path does not hang (the packed CAS converges because the comparison key is the freshly-loaded packed bits, not a recomputed NaN-multiply result).

Scope verified
Empirically swept the obvious axes on the unfixed build with a 15s timeout to find what hangs: `scatter_prod`, `scatter_add`, `scatter_max`, `scatter_min`, and the `mx.prod` reduction (uses the same atomic). Only native
`float32` `scatter_prod` was affected. The fix is targeted at exactly that path.

Test plan

Added `test scatter_prod with NaN does not hang` in `tests/gpu_tests.cpp`. Two cases (NaN-as-update; NaN-already-in-memory), each empirically confirmed to hang on the unfixed kernel and to complete in <100ms with the fix. Each call is wrapped in `std::async` + `wait_for(10s)` so a future regression surfaces as a test failure instead of a wedged runner. Run with `DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1`. Update values: `[nan, 1, 1, 1]`.