[Metal] NVFP4: support 3-tier global_scale, fix 2-tier simdgroup-lane bug#3558

Open
yohann-bearzi wants to merge 11 commits into ml-explore:main from yohann-bearzi:nvfp4-metal-3tier

Conversation

@yohann-bearzi

What this PR does

Brings the Metal backend's NVFP4 implementation to spec-compliance with NVIDIA's published format and the CUDA/CPU backends in MLX. Fixes #3557.

Two distinct issues addressed:

  1. 3-tier (per-tensor global_scale) NVFP4 was unsupported on Metal. Now supported across all NVFP4 ops and nn layers.
  2. A latent simdgroup-lane bug in the existing 2-tier path caused per-block scales to be computed over 32 elements instead of 16. Fixed.

Changes by commit

The PR is 11 bisectable commits: the ten described below plus a final tests commit. The simdgroup fix (commit 4) is independent of 3-tier support and could be cherry-picked in isolation.

  1. Feature flag MLX_METAL_NVFP4_3TIER — gates the 3-tier path. Default OFF preserves bit-identical behavior vs main; ON removes the "Global scale not supported" rejection in ops.cpp. The PR turns it ON.
  2. Metal kernels — add fp_quantize_3tier and fp_dequantize_3tier alongside the existing 2-tier kernels via a shared _impl helper. Instantiated for nvfp4 mode only (mxfp4/mxfp8 don't have a spec-compliant 3-tier).
  3. Host dispatch — fast::Quantize::eval_gpu routes to the _g1 kernel variants when inputs[] carries a global_scale tensor.
  4. Simdgroup-lane fix — replace tidx.x with simd_lane_id for block-half bisection in fp_quantize_impl and fp_quantize_dequantize_impl. This is a fix to pre-existing 2-tier behavior. Output for NVFP4 mode now differs from main but matches the spec-compliant numpy reference. Other modes (affine, mxfp4, mxfp8) are bit-identical.
  5. QQMatmul wiring — extends QQMatmul::eval_gpu to pass global_scale_x and global_scale_w through to quantize_dequantize, which absorbs both scales into the activation pre-quantization step. No matmul kernel changes needed.
  6. quantized_matmul + global_scale_w — extend the user-facing inference op with the new parameter. Pre-multiplies x by gs_w (avoiding fp16 overflow that would happen if we post-multiplied the output by 1/gs_w).
  7. gather_qmm + global_scale_w — same extension for the MoE path.
  8. Pre-scaling correction + nn.QuantizedLinear integration — switches from post-multiply to pre-multiply for the matmul-side gs_w application (fixes a latent fp16 overflow when gs_w is small). Adds global_scale attribute and passes it through __call__ and from_linear.
  9. nn.QuantizedEmbedding integration — same pattern for embedding lookup and as_linear.
  10. nn.QQLinear integration — computes activation gs_x on the fly, threads both globals to mx.qqmm.
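
For reference, the scheme these kernels implement can be sketched in numpy. This is a simplified model, not MLX code — the function names are illustrative, the real kernels store block scales in FP8 E4M3 (rounding omitted here) and pack E2M1 codes into nibbles — but the constants follow the NVFP4 spec as used in the tests commit (global_scale = max|W| / (E2M1_MAX * E4M3_MAX)):

```python
import numpy as np

E2M1_MAX = 6.0     # largest FP4 E2M1 magnitude
E4M3_MAX = 448.0   # largest FP8 E4M3 magnitude
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # non-negative E2M1 values

def quantize_nvfp4_3tier(w, group_size=16):
    # Tier 3: per-tensor global scale, per the NVFP4 spec.
    gs = np.abs(w).max() / (E2M1_MAX * E4M3_MAX)
    blocks = (w / gs).reshape(-1, group_size)
    # Tier 2: per-16-element block scale (kept in fp32 here; the spec stores E4M3).
    block_scale = np.abs(blocks).max(axis=1, keepdims=True) / E2M1_MAX
    block_scale[block_scale == 0] = 1.0
    # Tier 1: round each normalized element to the nearest E2M1 code.
    normed = blocks / block_scale
    codes = np.abs(np.abs(normed)[..., None] - E2M1_GRID).argmin(axis=-1)
    return np.sign(normed) * E2M1_GRID[codes], block_scale, gs

def dequantize_nvfp4_3tier(q, block_scale, gs, shape):
    # Reconstruction multiplies all three tiers back together.
    return (q * block_scale * gs).reshape(shape)
```

On Gaussian data this roundtrip lands around 0.09-0.10 relative error, consistent with the validation tables in this description.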

API changes

  • mx.quantize(W, mode='nvfp4', global_scale=...) — now accepted on Metal.
  • mx.dequantize(Wq, sc, mode='nvfp4', global_scale=...) — now accepted on Metal.
  • mx.quantized_matmul(x, Wq, sc, ..., global_scale_w=) — new optional kwarg.
  • mx.gather_qmm(x, Wq, sc, ..., global_scale_w=) — new optional kwarg.
  • mx.qqmm already accepted global_scale_x and global_scale_w; now actually honored on Metal.

All new parameters default to std::nullopt/None — adding them is backward-compatible. CPU and CUDA backends accept but ignore the new quantized_matmul/gather_qmm parameters; the CUDA 3-tier path remains via qqmm/QQMatmul, unchanged.

Validation

Tested with a separate validation harness covering pack invariance, matmul invariance, new-feature smoke test, roundtrip vs numpy spec reference, quality vs fp16 baseline, and perf benchmark. Harness can be contributed as a separate PR if useful; kept out of this PR to limit scope.

Pack/dequant roundtrip matches numpy spec reference to 4-5 decimal places:

| shape | MLX 2-tier | numpy 2-tier ref | MLX 3-tier | numpy 3-tier ref |
|---|---|---|---|---|
| small (128×256) | 0.1029 | 0.1029 | 0.0942 | 0.0942 |
| decode (14336×4096, heavy-tailed) | 0.0940 | 0.0940 | 0.0913 | 0.0913 |
| outproj (4096×8192, channel outliers) | 0.0946 | 0.0946 | 0.0930 | 0.0930 |
| extreme (512×2048, extreme tail) | 0.0767 | 0.0767 | 0.0738 | 0.0738 |

End-to-end matmul quality (mx.quantized_matmul with 3-tier global_scale_w):

| shape | fp16 ref | mlx 2-tier (main) | mlx 2-tier (this PR) | mlx 3-tier (this PR) |
|---|---|---|---|---|
| decode | 0 (baseline) | 0.1175 | 0.0935 | 0.0906 |
| outproj | 0 (baseline) | 0.1235 | 0.0944 | 0.0926 |
| extreme | 0 (baseline) | 0.1021 | 0.0816 | 0.0772 |

The simdgroup fix in commit 4 delivers the bulk of the improvement for 2-tier users (~20% rel_err reduction). The 3-tier path adds another 2-5% on top, and is necessary for ecosystem checkpoint interop.

End-to-end nn.quantize integration on a 4-layer transformer-style model:

  • All quantized layers carry the global_scale attribute
  • Forward pass produces finite output (no inf/nan)
  • rel_err 0.042 vs fp16 baseline

Trade-off context

On Apple M3 Ultra (decode shape 14336×4096, heavy-tailed weight), NVFP4 2-tier with the simdgroup fix is Pareto-dominant among 4-bit options:

| format | bits/val | rel_err | time_ms | speedup | size_MB |
|---|---|---|---|---|---|
| fp16 | 16.00 | 0.000 | 0.42 | 1.00× | 112 |
| int8 affine g128 | 8.25 | 0.012 | 0.28 | 1.46× | 58 |
| nvfp4 2-tier (this PR) | 4.50 | 0.094 | 0.24 | 1.76× | 31 |
| int4 affine g32 | 5.00 | 0.102 | 0.25 | 1.67× | 35 |
| mxfp8 | 8.25 | 0.104 | 0.31 | 1.35× | 58 |
| int4 affine g64 | 4.50 | 0.137 | 0.25 | 1.68× | 31 |
| mxfp4 | 4.25 | 0.171 | 0.23 | 1.78× | 30 |

NVFP4 strictly dominates int4 affine g32 in all three dimensions (quality, speed, memory) and is the natural default for 4-bit quantization on Metal.

Scope notes

  • CPU and CUDA quantized_matmul/gather_qmm accept the new global_scale_w parameter but currently ignore it. CUDA's spec-compliant 3-tier path remains via qqmm/QQMatmul (unchanged). Extending CPU and CUDA quantized_matmul is future work.
  • Metal QQMatmul and GatherQMM only handle the M=1 decode path today. M>1 (prefill) still throws NYI for the general case — that's an existing limitation, unchanged here.

Testing

Pre-commit hooks (clang-format, black) run on all changed files. Commit 11 updates tests/: it drops the CPU/CUDA-only conditional in test_nvfp4_quantize_dequantize so the 3-tier path is exercised on Metal, and adds test_nvfp4_block_scales_spec_compliant to pin spec-compliant block scales. The broader validation harness is kept out of this PR to limit scope; happy to contribute it as a separate PR.

homecorpstudio and others added 11 commits May 16, 2026 09:14
Add compile-time CMake option (default OFF) that, when enabled,
allows global_scale to reach Metal in quantize() and dequantize().
The flag is a no-op when unset: behavior is bit-identical to main.

When set via:
  CMAKE_ARGS=-DMLX_METAL_NVFP4_3TIER=ON pip install . --no-build-isolation

mx.quantize(W, mode='nvfp4', global_scale=...) no longer throws on
Metal — it produces a packed weight + scale tuple. The packed output
is currently 2-tier in numerical content (kernel discards the
global_scale), but the storage and API plumbing is in place.

Subsequent commits will:
  - Make the Metal quantize kernel actually consume global_scale
    (commit 2) — producing spec-compliant 3-tier packed weights.
  - Add 'global_scale_w' to quantized_matmul (commit 3) — exposes the
    W4-with-per-tensor-scale × fp16-activations inference path.
  - Make qqmm's Metal path honor the global_scale_x/w it accepts
    (commit 4) — currently the params are silently dropped.

Validation:
  - flag OFF: identical to main (all invariance tests pass)
  - flag ON: gate removed, quantize accepts global_scale on Metal
…nels

Refactor fp_quantize and fp_dequantize into inline impl helpers
that accept a global_scale parameter (default 1.0). The existing
kernels wrap the helper with global_scale=1.0 — bit-identical
behavior to current main.

Add two new entry points for nvfp4:
  - fp_quantize_3tier: reads global_scale from buffer(4), divides
    input by it before computing the per-block FP8 scale
  - fp_dequantize_3tier: multiplies per-block scale by global_scale
    when computing the output value

Instantiate _g1 variants in fp_quantized.metal only for the nvfp4
mode (mxfp4/mxfp8 don't have a spec-compliant 3-tier).

Kernels exist in the metallib but no host dispatch routes to them
yet — commit 3 wires Metal/quantized.cpp::fast::Quantize::eval_gpu
to select these when inputs contain a global_scale.

Validation:
  - pack invariance ✓ PASS (refactor is a math no-op)
  - matmul invariance ✓ PASS
  - mlx_nvfp4_2tier rel_err unchanged vs c1 (bit-identical)
  - mlx_nvfp4_3tier still produces 2-tier output (host dispatch pending)
…al_scale present

When mx.quantize() or mx.dequantize() is called with mode='nvfp4' and
a non-null global_scale, the primitive's inputs vector now carries the
extra array. This commit detects that in fast::Quantize::eval_gpu and:

  - binds inputs[1] (for quantize) or inputs[2] (for dequantize) at
    Metal buffer slot 4
  - appends '_g1' to the kernel name to select the fp_quantize_3tier /
    fp_dequantize_3tier variants from commit 2
  - passes 'quantize_3tier' / 'dequantize_3tier' as the func name so
    get_quantized_kernel_wrapped builds the right template_def

The 3-tier path is bit-identical to numpy reference for the spec-compliant
case, modulo a known issue with simdgroup-lane indexing in the shared
impl helper that affects both 2-tier and 3-tier paths (fixed in commit 4).

Validation:
  - flag OFF: identical to main (all invariance tests pass)
  - flag ON, no global_scale: identical to main (2-tier path unchanged)
  - flag ON, with global_scale: routes to _g1 kernel, produces 3-tier output
    matching numpy reference to within fp32 round-off (~1e-6 rel_err)
  - matmul-side via mx.qqmm still ignores global_scales — separate commit
The block-scale computation for group_size=16 (NVFP4) used tidx.x to
bisect each simdgroup into two 16-element halves:

    float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0);
    float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0);
    scale = tidx.x < 16 ? w_max_l : w_max_r;

However tidx.x is the global thread_position_in_grid, not the position
within the simdgroup. For all simdgroups beyond the first (i.e. when
tidx.x >= 32), the predicate 'tidx.x >= 16' is uniformly true, so the
right-half simd_max reduces over all 32 lanes instead of 16 — taking
the max across two adjacent 16-element blocks rather than one.

The result is that every other block (the right-half blocks) stored a
block scale derived from max(|w|) over 32 elements instead of 16.
When the two adjacent blocks had similar maxima this had negligible
effect; when they had different maxima the right-half block got a
coarser quantization, costing ~5-8 percent additional rel_err per
affected block.
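
A minimal numpy model of the reduction described above (values are illustrative; simd_max is modeled as a masked max):

```python
import numpy as np

# Second simdgroup in the grid: global thread ids (tidx.x) run 32..63,
# while in-simdgroup lane ids (simd_lane_id) run 0..31.
tidx = np.arange(32, 64)
lane = np.arange(32)

# Two adjacent 16-element blocks with very different maxima.
w = np.concatenate([np.full(16, 8.0), np.full(16, 0.5)])

# Buggy bisection: 'tidx >= 16' is uniformly true for this simdgroup, so
# the right-half simd_max reduces over all 32 lanes and absorbs the left
# block's max.
buggy_right = np.where(tidx >= 16, np.abs(w), 0.0).max()   # 8.0, wrong
# Fixed bisection on the lane id: reduces over lanes 16..31 only.
fixed_right = np.where(lane >= 16, np.abs(w), 0.0).max()   # 0.5, correct
```

With the buggy value, the right block's scale is derived from 8.0 instead of 0.5 — a 16x coarser quantization step for every element of that block.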

Fix: use simd_lane_id (thread_index_in_simdgroup, 0..31 within each
simdgroup) for the bisection. Add it as a kernel attribute and pass
through to the impl helper. The 2-tier wrapper (fp_quantize) and the
3-tier wrapper (fp_quantize_3tier) both gain the attribute.

Empirical impact on 'decode' shape (14336x4096, heavy-tailed):
  before:  mlx_nvfp4_2tier rel_err = 0.118
  after:   mlx_nvfp4_2tier rel_err = 0.094  (20% improvement)

  before:  3-tier roundtrip rel_err = 0.117 (kernel buggy)
  after:   3-tier roundtrip rel_err = 0.091 (matches numpy reference)

Invariance vs main:
  - affine, mxfp4, mxfp8 modes: bit-identical (those paths use group_size=32,
    which takes the use_mx_scale branch with whole-simdgroup simd_max — not
    affected by the bug)
  - nvfp4 mode: numerically different from main — output is improved (lower
    rel_err vs fp16 baseline), matching spec-compliant numpy reference within
    fp32 round-off

This is a fix to existing 2-tier behavior, not a feature addition. NVFP4
checkpoints produced with main vs this branch will have different packed
bytes; consumers reading both must dequantize and compare in floating point.
Refactor fp_quantize_dequantize into impl helper + 2-tier wrapper + 3-tier
wrapper. Apply the simdgroup-lane fix (commit 4) to the impl helper.

The 3-tier kernel takes two FP32 scalars:
  - gs_x: per-tensor scale of the activation. Used to normalize the
    activation before computing the FP4 block scale (same role as
    gs_w in the weight pack path).
  - gs_w: per-tensor scale of the weight tensor. Multiplied into
    the reconstructed activation value at output.

The 'absorption trick': the activation reconstruction at kernel output
becomes xhat = (dequantized_FP4_value) * gs_x * gs_w, which equals
x * gs_w (with FP4 quantization noise). When the downstream matmul
runs (xhat @ Wq_dequant.T), and Wq_dequant decodes the 3-tier-packed
weights as W/gs_w, the product (x*gs_w) @ (W/gs_w).T = x @ W.T exactly.

This avoids modifying any matmul kernel — the gs_w absorption happens
entirely inside the activation quantize_dequantize step. Only one
kernel needs to know about NVFP4 3-tier in this path.
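
Setting FP4 quantization noise aside, the identity the absorption relies on can be checked directly (a sketch with illustrative shapes and scale):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 64)).astype(np.float32)
W = rng.standard_normal((32, 64)).astype(np.float32)
gs_w = np.float32(2e-3)  # illustrative per-tensor weight scale

# 3-tier-packed weights dequantize to W / gs_w; absorbing gs_w into the
# activation restores the unscaled product.
ref = x @ W.T
absorbed = (x * gs_w) @ (W / gs_w).T
```

The two products agree to fp32 round-off, which is why no matmul kernel needs to change.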

Host-side changes:
  - quantize_dequantize() helper now takes optional global_scale_x
    and global_scale_w, binding them at buffer slots 4 and 5 when
    both are present, and selecting the _3tier kernel variant.
  - QQMatmul::eval_gpu detects NVFP4 + extra trailing inputs,
    extracts global_scale_x and global_scale_w from inputs[-2]
    and inputs[-1], passes both to quantize_dequantize.

Validation:
  - qqmm 2-tier (no global_scale): bit-identical to main
  - qqmm 3-tier rel_err: matches numpy_3tier_reference modulo FP4
    activation noise (~0.13 on synthetic shapes vs 0.09 for
    quantized_matmul on same data — the gap is intrinsic to W4A4
    vs W4A16; activation FP4 quantization adds noise that 3-tier
    weight scaling cannot recover).
  - Decode-only (M=1) path; prefill still throws NYI

The user-facing inference path (quantized_matmul) doesn't go through
QQMatmul. Commit 6 extends quantized_matmul to accept global_scale_w
for the W4A16 inference case that real LLM serving uses.
The user-facing inference path (mx.quantized_matmul) now accepts an
optional global_scale_w parameter for mode='nvfp4'. This completes
3-tier NVFP4 support for the W4xfp16 case used by LLM serving.

API change (mlx/ops.h, mlx/ops.cpp):
  - quantized_matmul gains std::optional<array> global_scale_w before
    the trailing StreamOrDevice parameter (matches qqmm convention)
  - Validates global_scale_w is only provided for mode='nvfp4'
  - Validates global_scale_w dtype is float32
  - When provided, pushed as trailing element of the inputs vector
  - Three internal callers (vjp, jvp, gather_qmm dispatch) updated to
    pass std::nullopt for the new parameter

Python binding (python/src/ops.cpp + .pyi):
  - global_scale_w added as a keyword-positional argument
  - Documentation updated

Metal backend (mlx/backend/metal/quantized.cpp):
  - QuantizedMatmul::eval_gpu detects NVFP4 + inputs.size()==4 and
    treats inputs[3] as global_scale_w instead of biases
  - Each matmul-dispatch return site is followed by apply_global_scale_w()
    which dispatches the mul_inplace_scalar kernel over the output array
  - Five paths covered: qmm_splitk, qmm, dispatch_qmv, qvm, qvm_split_k

New kernel (mlx/backend/metal/kernels/fp_quantized.h + .metal):
  - mul_inplace_scalar: array *= *scalar_buffer, instantiated for
    float / float16_t / bfloat16_t

CPU / CUDA: signature accepts the parameter but ignores it. Callers
that need NVFP4 3-tier matmul correctness on CPU/CUDA must continue
to use mx.qqmm (CPU/CUDA already wire global_scale_w there).

Validation (on Apple M3 Ultra, heavy-tailed weights):
  shape          2-tier   3-tier   improvement
  (1024,1024)    0.092    0.089    3.8 percent
  (4096,4096)    0.094    0.093    1.6 percent
  (4096,8192)    0.094    0.093    1.9 percent  (channel outliers)
  (14336,4096)   0.094    0.091    3.1 percent

The headline 3-tier quality numbers from the NVFP4 spec (10-25 percent
improvement on outlier-heavy distributions) are partially provided by
commit 4's simdgroup-lane fix to the existing 2-tier path, with this
commit closing the remaining 2-4 percent for spec-compliant 3-tier
output. Both are required to match the format spec exactly.
Mirror of commit 6 for the gather_qmm op, which handles MoE-style
batched-expert matmul. Without this, NVFP4-quantized MoE models
(Mixtral-style, MiMo experts, DeepSeek MoE) silently produce 2-tier
output even when 3-tier weights are stored.

API change (mlx/ops.h, mlx/ops.cpp):
  - gather_qmm gains std::optional<array> global_scale_w before the
    trailing StreamOrDevice parameter (matches the qqmm /
    quantized_matmul convention)
  - Validates global_scale_w is only provided for mode='nvfp4'
  - Validates global_scale_w dtype is float32
  - Non-Affine inputs vector: [x, w, scales, (gs_w?), lhs_idx, rhs_idx]
  - One internal caller (VJP) updated to pass std::nullopt

Python binding (python/src/ops.cpp):
  - global_scale_w added as a keyword-positional argument
  - .pyi auto-regenerates from nb::sig

Metal backend (mlx/backend/metal/quantized.cpp):
  - GatherQMM::eval_gpu detects NVFP4 + inputs.size()==6 and treats
    inputs[3] as global_scale_w (not biases — biases never present
    for NVFP4)
  - apply_global_scale_w() lambda applies the gs_w post-multiply via
    mul_inplace_scalar kernel (reused from commit 6)
  - Inserted at all four matmul-dispatch paths:
      gather_qmm_rhs, gather_qmm, gather_qmv, gather_qvm

CPU / CUDA: signature accepts the parameter but currently ignores it.
Validation on synthetic 4-expert MoE batch:
  shape (batch=8, expert outdim=1024, indim=1024, n_experts=4):
    2-tier rel_err:  0.094
    3-tier rel_err:  0.092 (2.3 percent improvement)

The improvement is modest on synthetic uniform-expert data. On real
MoE checkpoints with channel-outlier weight distributions in specific
expert layers, the improvement is comparable to the non-MoE case
(typically 2-5 percent additional rel_err reduction).
…ation

Two related changes in one commit:

== 1. Fix latent fp16 overflow in QuantizedMatmul / GatherQMM ==

Commits 6 and 7 applied global_scale_w by post-multiplying the matmul
output. For typical NVFP4 weights with gs_w ~ 1e-5, the matmul's
intermediate output (x @ W.T)/gs_w can reach ~1e5 times the final
magnitude. With fp16 outputs this overflows fp16's max of 65504,
producing inf values that the post-multiply by gs_w cannot recover.

Fix: PRE-multiply x by gs_w before the matmul. The matmul then
computes (x*gs_w) @ (W/gs_w).T = x @ W.T directly. Since gs_w is
small, x*gs_w stays well within fp16 range.
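
A small fp16 model of the failure mode (illustrative magnitudes; fp16 accumulation is modeled with an explicit cast per step, and the decoded-weight range 2688 = E2M1_MAX * E4M3_MAX follows the spec):

```python
import numpy as np

gs_w = np.float32(1e-5)
# Dequantized 3-tier weights span magnitudes up to E2M1_MAX * E4M3_MAX = 2688.
w_dec = np.full(128, 2688.0, dtype=np.float16)
x = np.full(128, 1.0, dtype=np.float16)

# Post-multiply: accumulate the dot product in fp16, then rescale by gs_w.
acc = np.float16(0.0)
for v in x * w_dec:
    acc = np.float16(acc + v)      # blows past fp16 max (65504) -> inf
post = acc * np.float16(gs_w)      # inf cannot be rescaled back down

# Pre-multiply: scale x by gs_w first; every partial sum stays tiny.
acc2 = np.float16(0.0)
for v in (x * np.float16(gs_w)) * w_dec:
    acc2 = np.float16(acc2 + v)    # finite throughout
```

The post-multiply accumulator saturates to inf after a few dozen terms, while the pre-multiplied path never leaves fp16 range.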

Implementation: new mul_scalar_copy kernel (separate input/output
arrays) instantiated for float/float16_t/bfloat16_t. The pre-scaling
allocates a temporary x_scaled, runs the kernel, then substitutes
x_scaled for x in the matmul dispatch.

Applied to:
  - QuantizedMatmul::eval_gpu (5 dispatch paths: qmm, qmm_splitk,
    dispatch_qmv, qvm, qvm_split_k)
  - GatherQMM::eval_gpu (4 dispatch paths: gather_qmm_rhs, gather_qmm,
    gather_qmv, gather_qvm)

The post-multiply apply_global_scale_w lambda becomes a no-op (kept
as named symbol so the dispatch return sites remain readable). The
mul_inplace_scalar kernel from commit 6 is now unused but kept in
the metallib (harmless).

== 2. nn.QuantizedLinear integration ==

  - __init__ computes global_scale from random init weight for
    mode='nvfp4', passes to mx.quantize, stores as self.global_scale
  - from_linear computes global_scale from the source Linear weight
  - __call__ passes self.get('global_scale') to quantized_matmul as
    global_scale_w (None for non-nvfp4 modes -> 2-tier behavior)

Validation (heavy-tailed decode shape 14336x4096):
  before: nn.QuantizedLinear with nvfp4 -> ~18 percent inf in output
  after:  rel_err 0.091 vs fp16 reference, matches numpy_3tier_reference
          to 4 decimal places

All harness tests pass invariance for non-nvfp4 modes.
When mode='nvfp4', compute a per-tensor FP32 global_scale in __init__
and from_embedding, store it as self.global_scale, and pass it to both
mx.dequantize (in __call__ for embedding lookup) and mx.quantized_matmul
(in as_linear for tied output projection).

Without this, NVFP4-quantized embedding tables produced 2-tier output
even when stored with 3-tier packing. The embedding lookup path uses
mx.dequantize which already accepted global_scale (commits 1-4). The
as_linear path uses mx.quantized_matmul which now accepts global_scale_w
(commits 6, 8).

For non-nvfp4 modes, behavior unchanged.

Validation:
  - embedding lookup: rel_err 0.097 vs fp16 (normal FP4 noise)
  - as_linear: rel_err 0.091 vs fp16 (matches numpy_3tier_reference)
  - no inf/nan in either path
For mode='nvfp4', compute and store the per-tensor weight global_scale
during quantize(), threaded through dequantize() to round-trip cleanly,
and compute the activation global_scale on the fly in __call__ for the
qqmm dispatch.

Three sites updated:
  - quantize(): for nvfp4, computes global_scale_w from self.weight
    and passes to mx.quantize. Stores as self.global_scale_w.
  - dequantize(): for nvfp4, passes self.global_scale_w to mx.dequantize
    so the inverse round-trip recovers the original weight.
  - __call__: for nvfp4, computes global_scale_x on the fly from |x|
    amax, passes both gs_x and gs_w to mx.qqmm.

For non-nvfp4 modes (mxfp8), behavior unchanged: no global_scale_w
attribute, no global_scale kwargs passed.

Validation (M=1 decode shape, the only path Metal's QQMatmul currently
supports — M>1 throws NYI which is a pre-existing limitation):
  - NVFP4 QQLinear M=1:  rel_err 0.136 vs fp16 linear (normal W4A4 noise)
  - mxfp8 QQLinear M=1:  rel_err 0.094 (unchanged, no global_scale_w)
  - no inf/nan in either path
…le codes

- Drop CPU/CUDA-only conditional in test_nvfp4_quantize_dequantize so the
  3-tier path is now exercised on Metal too. Use spec-compliant global_scale
  (max|W| / (E2M1_MAX * E4M3_MAX)) per NVIDIA NVFP4 specification.
- Add test_nvfp4_block_scales_spec_compliant pinning a 4-block reference
  output that would have caught the simdgroup-lane indexing bug. Locks in
  spec-compliance against future regression.
Successfully merging this pull request may close these issues.

[Metal] NVFP4: 3-tier global_scale unsupported + 2-tier simdgroup-lane indexing bug
