[Metal] NVFP4: support 3-tier global_scale, fix 2-tier simdgroup-lane bug #3558
Open
yohann-bearzi wants to merge 11 commits into
Conversation
Add compile-time CMake option (default OFF) that, when enabled,
allows global_scale to reach Metal in quantize() and dequantize().
The flag is a no-op when unset: behavior is bit-identical to main.
When set via:
CMAKE_ARGS=-DMLX_METAL_NVFP4_3TIER=ON pip install . --no-build-isolation
mx.quantize(W, mode='nvfp4', global_scale=...) no longer throws on
Metal — it produces a packed weight + scale tuple. The packed output
is currently 2-tier in numerical content (kernel discards the
global_scale), but the storage and API plumbing is in place.
Subsequent commits will:
- Make the Metal quantize kernel actually consume global_scale
(commit 2) — producing spec-compliant 3-tier packed weights.
- Add 'global_scale_w' to quantized_matmul (commit 3) — exposes the
W4-with-per-tensor-scale × fp16-activations inference path.
- Make qqmm's Metal path honor the global_scale_x/w it accepts
(commit 4) — currently the params are silently dropped.
Validation:
- flag OFF: identical to main (all invariance tests pass)
- flag ON: gate removed, quantize accepts global_scale on Metal
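A usage sketch of what the flag enables (illustrative only: shapes are made up, the return arity follows the "packed weight + scale tuple" wording above, and the per-tensor scale uses the spec-compliant formula adopted later in this series):

```python
import mlx.core as mx

# Requires a build with -DMLX_METAL_NVFP4_3TIER=ON; a default build still throws.
W = mx.random.normal((256, 256)).astype(mx.float16)

# Spec-compliant per-tensor scale: max|W| / (E2M1_MAX * E4M3_MAX)
gs = (mx.abs(W).max() / (6.0 * 448.0)).astype(mx.float32)

Wq, scales = mx.quantize(W, mode="nvfp4", global_scale=gs)
W_back = mx.dequantize(Wq, scales, mode="nvfp4", global_scale=gs)
```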
…nels
Refactor fp_quantize and fp_dequantize into inline impl helpers
that accept a global_scale parameter (default 1.0). The existing
kernels wrap the helper with global_scale=1.0 — bit-identical
behavior to current main.
Add two new entry points for nvfp4:
- fp_quantize_3tier: reads global_scale from buffer(4), divides
input by it before computing the per-block FP8 scale
- fp_dequantize_3tier: multiplies per-block scale by global_scale
when computing the output value
Instantiate _g1 variants in fp_quantized.metal only for the nvfp4
mode (mxfp4/mxfp8 don't have a spec-compliant 3-tier).
Kernels exist in the metallib but no host dispatch routes to them
yet — commit 3 wires Metal/quantized.cpp::fast::Quantize::eval_gpu
to select these when inputs contain a global_scale.
Validation:
- pack invariance ✓ PASS (refactor is a math no-op)
- matmul invariance ✓ PASS
- mlx_nvfp4_2tier rel_err unchanged vs c1 (bit-identical)
- mlx_nvfp4_3tier still produces 2-tier output (host dispatch pending)
…al_scale present
When mx.quantize() or mx.dequantize() is called with mode='nvfp4' and
a non-null global_scale, the primitive's inputs vector now carries the
extra array. This commit detects that in fast::Quantize::eval_gpu and:
- binds inputs[1] (for quantize) or inputs[2] (for dequantize) at
Metal buffer slot 4
- appends '_g1' to the kernel name to select the fp_quantize_3tier /
fp_dequantize_3tier variants from commit 2
- passes 'quantize_3tier' / 'dequantize_3tier' as the func name so
get_quantized_kernel_wrapped builds the right template_def
The 3-tier path matches the numpy reference to within fp32 round-off for
the spec-compliant case, modulo a known issue with simdgroup-lane indexing
in the shared impl helper that affects both 2-tier and 3-tier paths
(fixed in commit 4).
Validation:
- flag OFF: identical to main (all invariance tests pass)
- flag ON, no global_scale: identical to main (2-tier path unchanged)
- flag ON, with global_scale: routes to _g1 kernel, produces 3-tier output
matching numpy reference to within fp32 round-off (~1e-6 rel_err)
- matmul-side via mx.qqmm still ignores global_scales — separate commit
The block-scale computation for group_size=16 (NVFP4) used tidx.x to
bisect each simdgroup into two 16-element halves:
float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0);
float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0);
scale = tidx.x < 16 ? w_max_l : w_max_r;
However tidx.x is the global thread_position_in_grid, not the position
within the simdgroup. For all simdgroups beyond the first (i.e. when
tidx.x >= 32), the predicate 'tidx.x >= 16' is uniformly true, so the
right-half simd_max reduces over all 32 lanes instead of 16 — taking
the max across two adjacent 16-element blocks rather than one.
The result is that every other block (the right-half blocks) stored a
block scale derived from max(|w|) over 32 elements instead of 16.
When the two adjacent blocks had similar maxima this had negligible
effect; when they had different maxima the right-half block got a
coarser quantization, costing ~5-8 percent additional rel_err per
affected block.
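A toy numpy illustration of the failure mode (not kernel code): when the right half's reduction spans all 32 lanes, an outlier in the left block inflates the right block's scale.

```python
import numpy as np

left  = np.array([8.0] + [1.0] * 15)  # block 2k   (lanes 0-15): has an outlier
right = np.array([1.0] * 16)          # block 2k+1 (lanes 16-31): small values

scale_right_correct = np.abs(right).max()                        # 1.0
scale_right_buggy = np.abs(np.concatenate([left, right])).max()  # 8.0

# With the inflated scale the right block's values land in a much coarser
# region of the FP4 grid, which is where the extra rel_err comes from.
```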
Fix: use simd_lane_id (thread_index_in_simdgroup, 0..31 within each
simdgroup) for the bisection. Add it as a kernel attribute and pass
through to the impl helper. The 2-tier wrapper (fp_quantize) and the
3-tier wrapper (fp_quantize_3tier) both gain the attribute.
Empirical impact on 'decode' shape (14336x4096, heavy-tailed):
before: mlx_nvfp4_2tier rel_err = 0.118
after: mlx_nvfp4_2tier rel_err = 0.094 (20% improvement)
before: 3-tier roundtrip rel_err = 0.117 (kernel buggy)
after: 3-tier roundtrip rel_err = 0.091 (matches numpy reference)
Invariance vs main:
- affine, mxfp4, mxfp8 modes: bit-identical (those paths use group_size=32,
which takes the use_mx_scale branch with whole-simdgroup simd_max — not
affected by the bug)
- nvfp4 mode: numerically different from main — output is improved (lower
rel_err vs fp16 baseline), matching spec-compliant numpy reference within
fp32 round-off
This is a fix to existing 2-tier behavior, not a feature addition. NVFP4
checkpoints produced with main vs this branch will have different packed
bytes; consumers reading both must dequantize and compare in floating point.
Refactor fp_quantize_dequantize into impl helper + 2-tier wrapper + 3-tier
wrapper. Apply the simdgroup-lane fix (commit 4) to the impl helper.
The 3-tier kernel takes two FP32 scalars:
- gs_x: per-tensor scale of the activation. Used to normalize the
activation before computing the FP4 block scale (same role as
gs_w in the weight pack path).
- gs_w: per-tensor scale of the weight tensor. Multiplied into
the reconstructed activation value at output.
The 'absorption trick': the activation reconstruction at kernel output
becomes xhat = (dequantized_FP4_value) * gs_x * gs_w, which equals
x * gs_w (with FP4 quantization noise). When the downstream matmul
runs (xhat @ Wq_dequant.T), and Wq_dequant decodes the 3-tier-packed
weights as W/gs_w, the product (x*gs_w) @ (W/gs_w).T = x @ W.T exactly.
This avoids modifying any matmul kernel — the gs_w absorption happens
entirely inside the activation quantize_dequantize step. Only one
kernel needs to know about NVFP4 3-tier in this path.
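A numpy sketch of the identity behind the absorption trick (the FP4 round-trip is omitted so the equality is exact; the per-tensor scale formula is the spec-compliant one assumed throughout this series):

```python
import numpy as np

x = np.random.randn(1, 64)
W = np.random.randn(128, 64)
gs_x = np.abs(x).max() / (6.0 * 448.0)   # per-tensor activation scale
gs_w = np.abs(W).max() / (6.0 * 448.0)   # per-tensor weight scale

# Activation quantize_dequantize with both scales absorbed:
# dequant(quant(x / gs_x)) * gs_x * gs_w == x * gs_w, up to FP4 noise
# (the FP4 round-trip itself is left out here).
xhat = (x / gs_x) * gs_x * gs_w

# 3-tier-packed weights decode as W / gs_w in the unmodified matmul path.
W_dequant = W / gs_w

assert np.allclose(xhat @ W_dequant.T, x @ W.T)
```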
Host-side changes:
- quantize_dequantize() helper now takes optional global_scale_x
and global_scale_w, binding them at buffer slots 4 and 5 when
both are present, and selecting the _3tier kernel variant.
- QQMatmul::eval_gpu detects NVFP4 + extra trailing inputs,
extracts global_scale_x and global_scale_w from inputs[-2]
and inputs[-1], passes both to quantize_dequantize.
Validation:
- qqmm 2-tier (no global_scale): bit-identical to main
- qqmm 3-tier rel_err: matches numpy_3tier_reference modulo FP4
activation noise (~0.13 on synthetic shapes vs 0.09 for
quantized_matmul on same data — the gap is intrinsic to W4A4
vs W4A16; activation FP4 quantization adds noise that 3-tier
weight scaling cannot recover).
- Decode-only (M=1) path; prefill still throws NYI
The user-facing inference path (quantized_matmul) doesn't go through
QQMatmul. Commit 6 extends quantized_matmul to accept global_scale_w
for the W4A16 inference case that real LLM serving uses.
The user-facing inference path (mx.quantized_matmul) now accepts an
optional global_scale_w parameter for mode='nvfp4'. This completes
3-tier NVFP4 support for the W4xfp16 case used by LLM serving.
API change (mlx/ops.h, mlx/ops.cpp):
- quantized_matmul gains std::optional<array> global_scale_w before
the trailing StreamOrDevice parameter (matches qqmm convention)
- Validates global_scale_w is only provided for mode='nvfp4'
- Validates global_scale_w dtype is float32
- When provided, pushed as trailing element of the inputs vector
- Three internal callers (vjp, jvp, gather_qmm dispatch) updated to
pass std::nullopt for the new parameter
Python binding (python/src/ops.cpp + .pyi):
- global_scale_w added as a keyword-positional argument
- Documentation updated
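A usage sketch of the extended op (shapes illustrative; keyword spelling beyond what this PR quotes, e.g. whether mode must be passed explicitly, is assumed):

```python
import mlx.core as mx

x = mx.random.normal((1, 4096)).astype(mx.float16)      # fp16 activations
W = mx.random.normal((14336, 4096)).astype(mx.float16)  # weight to quantize

gs_w = (mx.abs(W).max() / (6.0 * 448.0)).astype(mx.float32)
Wq, scales = mx.quantize(W, mode="nvfp4", global_scale=gs_w)

# W4A16 inference path: 3-tier NVFP4 weights, fp16 activations
y = mx.quantized_matmul(x, Wq, scales, mode="nvfp4", global_scale_w=gs_w)
```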
Metal backend (mlx/backend/metal/quantized.cpp):
- QuantizedMatmul::eval_gpu detects NVFP4 + inputs.size()==4 and
treats inputs[3] as global_scale_w instead of biases
- Each matmul-dispatch return site is followed by apply_global_scale_w()
which dispatches the mul_inplace_scalar kernel over the output array
- Five paths covered: qmm_splitk, qmm, dispatch_qmv, qvm, qvm_split_k
New kernel (mlx/backend/metal/kernels/fp_quantized.h + .metal):
- mul_inplace_scalar: array *= *scalar_buffer, instantiated for
float / float16_t / bfloat16_t
CPU / CUDA: signature accepts the parameter but ignores it. Callers
that need NVFP4 3-tier matmul correctness on CPU/CUDA must continue
to use mx.qqmm (CPU/CUDA already wire global_scale_w there).
Validation (on Apple M3 Ultra, heavy-tailed weights):
shape 2-tier 3-tier improvement
(1024,1024) 0.092 0.089 3.8 percent
(4096,4096) 0.094 0.093 1.6 percent
(4096,8192) 0.094 0.093 1.9 percent (channel outliers)
(14336,4096) 0.094 0.091 3.1 percent
The headline 3-tier quality numbers from the NVFP4 spec (10-25 percent
improvement on outlier-heavy distributions) are partially provided by
commit 4's simdgroup-lane fix to the existing 2-tier path, with this
commit closing the remaining 2-4 percent for spec-compliant 3-tier
output. Both are required to match the format spec exactly.
Mirror of commit 6 for the gather_qmm op, which handles MoE-style
batched-expert matmul. Without this, NVFP4-quantized MoE models
(Mixtral-style, MiMo experts, DeepSeek MoE) silently produce 2-tier
output even when 3-tier weights are stored.
API change (mlx/ops.h, mlx/ops.cpp):
- gather_qmm gains std::optional<array> global_scale_w before the
trailing StreamOrDevice parameter (matches the qqmm /
quantized_matmul convention)
- Validates global_scale_w is only provided for mode='nvfp4'
- Validates global_scale_w dtype is float32
- Non-Affine inputs vector: [x, w, scales, (gs_w?), lhs_idx, rhs_idx]
- One internal caller (VJP) updated to pass std::nullopt
Python binding (python/src/ops.cpp):
- global_scale_w added as a keyword-positional argument
- .pyi auto-regenerates from nb::sig
Metal backend (mlx/backend/metal/quantized.cpp):
- GatherQMM::eval_gpu detects NVFP4 + inputs.size()==6 and treats
inputs[3] as global_scale_w (not biases — biases never present
for NVFP4)
- apply_global_scale_w() lambda applies the gs_w post-multiply via
mul_inplace_scalar kernel (reused from commit 6)
- Inserted at all four matmul-dispatch paths:
gather_qmm_rhs, gather_qmm, gather_qmv, gather_qvm
CPU / CUDA: signature accepts the parameter but currently ignores it.
Validation on synthetic 4-expert MoE batch:
shape (batch=8, expert outdim=1024, indim=1024, n_experts=4):
2-tier rel_err: 0.094
3-tier rel_err: 0.092 (2.3 percent improvement)
The improvement is modest on synthetic uniform-expert data. On real
MoE checkpoints with channel-outlier weight distributions in specific
expert layers, the improvement is comparable to the non-MoE case
(typically 2-5 percent additional rel_err reduction).
…ation
Two related changes in one commit:
== 1. Fix latent fp16 overflow in QuantizedMatmul / GatherQMM ==
Commits 6 and 7 applied global_scale_w by post-multiplying the matmul
output. For typical NVFP4 weights with gs_w ~ 1e-5, the matmul's
intermediate output (x @ W.T)/gs_w can reach ~1e5 times the final
magnitude. With fp16 outputs this overflows fp16's max of 65504,
producing inf values that the post-multiply by gs_w cannot recover.
Fix: PRE-multiply x by gs_w before the matmul. The matmul then
computes (x*gs_w) @ (W/gs_w).T = x @ W.T directly. Since gs_w is
small, x*gs_w stays well within fp16 range.
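A toy fp16 illustration of the two orderings (numbers illustrative; numpy stands in for the kernels):

```python
import numpy as np

gs_w = np.float16(1e-5)  # typical NVFP4 per-tensor weight scale
xw = 3.0                 # a plausible element of x @ W.T

# Post-multiply (commits 6/7): the matmul sees W / gs_w, so its raw output
# element is xw / gs_w ~ 3e5, beyond the fp16 max of 65504.
out_post = np.float16(xw / float(gs_w))  # inf
out_post = out_post * gs_w               # still inf; gs_w cannot undo it

# Pre-multiply (this commit): x is scaled by gs_w before the matmul, so the
# output element is (x * gs_w) @ (W / gs_w).T = x @ W.T, already in range.
out_pre = np.float16(xw)                 # ~3.0, representable
```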
Implementation: new mul_scalar_copy kernel (separate input/output
arrays) instantiated for float/float16_t/bfloat16_t. The pre-scaling
allocates a temporary x_scaled, runs the kernel, then substitutes
x_scaled for x in the matmul dispatch.
Applied to:
- QuantizedMatmul::eval_gpu (5 dispatch paths: qmm, qmm_splitk,
dispatch_qmv, qvm, qvm_split_k)
- GatherQMM::eval_gpu (4 dispatch paths: gather_qmm_rhs, gather_qmm,
gather_qmv, gather_qvm)
The post-multiply apply_global_scale_w lambda becomes a no-op (kept
as named symbol so the dispatch return sites remain readable). The
mul_inplace_scalar kernel from commit 6 is now unused but kept in
the metallib (harmless).
== 2. nn.QuantizedLinear integration ==
- __init__ computes global_scale from random init weight for
mode='nvfp4', passes to mx.quantize, stores as self.global_scale
- from_linear computes global_scale from the source Linear weight
- __call__ passes self.get('global_scale') to quantized_matmul as
global_scale_w (None for non-nvfp4 modes -> 2-tier behavior)
Validation (heavy-tailed decode shape 14336x4096):
before: nn.QuantizedLinear with nvfp4 -> ~18 percent inf in output
after: rel_err 0.091 vs fp16 reference, matches numpy_3tier_reference
to 4 decimal places
All harness tests pass invariance for non-nvfp4 modes.
When mode='nvfp4', compute a per-tensor FP32 global_scale in __init__ and
from_embedding, store it as self.global_scale, and pass it to both
mx.dequantize (in __call__ for embedding lookup) and mx.quantized_matmul
(in as_linear for tied output projection).
Without this, NVFP4-quantized embedding tables produced 2-tier output even
when stored with 3-tier packing.
The embedding lookup path uses mx.dequantize which already accepted
global_scale (commits 1-4). The as_linear path uses mx.quantized_matmul
which now accepts global_scale_w (commits 6, 8).
For non-nvfp4 modes, behavior unchanged.
Validation:
- embedding lookup: rel_err 0.097 vs fp16 (normal FP4 noise)
- as_linear: rel_err 0.091 vs fp16 (matches numpy_3tier_reference)
- no inf/nan in either path
For mode='nvfp4', compute and store the per-tensor weight global_scale
during quantize(), thread it through dequantize() to round-trip cleanly,
and compute the activation global_scale on the fly in __call__ for the
qqmm dispatch.
Three sites updated:
- quantize(): for nvfp4, computes global_scale_w from self.weight
and passes to mx.quantize. Stores as self.global_scale_w.
- dequantize(): for nvfp4, passes self.global_scale_w to mx.dequantize
so the inverse round-trip recovers the original weight.
- __call__: for nvfp4, computes global_scale_x on the fly from |x|
amax, passes both gs_x and gs_w to mx.qqmm.
For non-nvfp4 modes (mxfp8), behavior unchanged: no global_scale_w
attribute, no global_scale kwargs passed.
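A sketch of the on-the-fly activation scale, using the spec-compliant formula (max|t| / (E2M1_MAX * E4M3_MAX)) referenced in the test updates later in this PR; the qqmm call is shown only as a comment since its full argument list isn't spelled out here:

```python
import mlx.core as mx

E2M1_MAX, E4M3_MAX = 6.0, 448.0

def nvfp4_global_scale(t):
    # per-tensor FP32 scale so that t / scale fits the FP4 x FP8 dynamic range
    return (mx.abs(t).max() / (E2M1_MAX * E4M3_MAX)).astype(mx.float32)

# Inside __call__ (illustrative; attribute names follow the text above):
#   gs_x = nvfp4_global_scale(x)
#   y = mx.qqmm(x, ..., global_scale_x=gs_x, global_scale_w=self.global_scale_w)
```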
Validation (M=1 decode shape, the only path Metal's QQMatmul currently
supports — M>1 throws NYI which is a pre-existing limitation):
- NVFP4 QQLinear M=1: rel_err 0.136 vs fp16 linear (normal W4A4 noise)
- mxfp8 QQLinear M=1: rel_err 0.094 (unchanged, no global_scale_w)
- no inf/nan in either path
…le codes
- Drop CPU/CUDA-only conditional in test_nvfp4_quantize_dequantize so the
  3-tier path is now exercised on Metal too. Use spec-compliant global_scale
  (max|W| / (E2M1_MAX * E4M3_MAX)) per NVIDIA NVFP4 specification.
- Add test_nvfp4_block_scales_spec_compliant pinning a 4-block reference
  output that would have caught the simdgroup-lane indexing bug. Locks in
  spec-compliance against future regression.
What this PR does
Brings the Metal backend's NVFP4 implementation to spec compliance with NVIDIA's published format and into parity with the CUDA/CPU backends in MLX. Fixes #3557.
Two distinct issues addressed:
- 3-tier (`global_scale`) NVFP4 was unsupported on Metal. Now supported across all NVFP4 ops and `nn` layers.
- The existing 2-tier NVFP4 block-scale computation indexed simdgroup halves with the global thread index instead of the lane index, so some blocks got a scale computed over 32 elements instead of 16 (fixed in commit 4).

Changes by commit
The PR is 10 bisectable commits. The simdgroup fix (commit 4) is independent of 3-tier support and could be cherry-picked in isolation.
1. `MLX_METAL_NVFP4_3TIER` — gates the 3-tier path. Default OFF preserves bit-identical behavior vs main; ON removes the "Global scale not supported" rejection in `ops.cpp`. The PR turns it ON.
2. `fp_quantize_3tier` and `fp_dequantize_3tier` alongside the existing 2-tier kernels, via a shared `_impl` helper. Instantiated for `nvfp4` mode only (mxfp4/mxfp8 don't have a spec-compliant 3-tier).
3. `fast::Quantize::eval_gpu` routes to the `_g1` kernel variants when `inputs[]` carries a `global_scale` tensor.
4. Replace `tidx.x` with `simd_lane_id` for block-half bisection in `fp_quantize_impl` and `fp_quantize_dequantize_impl`. This is a fix to pre-existing 2-tier behavior. Output for NVFP4 mode now differs from main but matches the spec-compliant numpy reference. Other modes (affine, mxfp4, mxfp8) are bit-identical.
5. `QQMatmul::eval_gpu` passes `global_scale_x` and `global_scale_w` through to `quantize_dequantize`, which absorbs both scales into the activation pre-quantization step. No matmul kernel changes needed.
6. `quantized_matmul` + `global_scale_w` — extend the user-facing inference op with the new parameter. Pre-multiplies `x` by `gs_w`, avoiding the fp16 overflow a post-multiply of the output would hit (the raw matmul output is ~1/gs_w times the final magnitude).
7. `gather_qmm` + `global_scale_w` — same extension for the MoE path.
8. `nn.QuantizedLinear` integration — switches from post-multiply to pre-multiply for the matmul-side `gs_w` application (fixes a latent fp16 overflow when `gs_w` is small). Adds a `global_scale` attribute and passes it through `__call__` and `from_linear`.
9. `nn.QuantizedEmbedding` integration — same pattern for embedding lookup and `as_linear`.
10. `nn.QQLinear` integration — computes the activation `gs_x` on the fly, threads both globals to `mx.qqmm`.

API changes
- `mx.quantize(W, mode='nvfp4', global_scale=...)` — now accepted on Metal.
- `mx.dequantize(Wq, sc, mode='nvfp4', global_scale=...)` — now accepted on Metal.
- `mx.quantized_matmul(x, Wq, sc, ..., global_scale_w=)` — new optional kwarg.
- `mx.gather_qmm(x, Wq, sc, ..., global_scale_w=)` — new optional kwarg.
- `mx.qqmm` already accepted `global_scale_x` and `global_scale_w`; they are now actually honored on Metal.

All new parameters default to `std::nullopt` / `None` — adding them is backward-compatible. CPU and CUDA backends accept but ignore the new `quantized_matmul` / `gather_qmm` parameters; the CUDA 3-tier path remains via `qqmm` / `QQMatmul`, unchanged.

Validation
Tested with a separate validation harness covering pack invariance, matmul invariance, new-feature smoke test, roundtrip vs numpy spec reference, quality vs fp16 baseline, and perf benchmark. Harness can be contributed as a separate PR if useful; kept out of this PR to limit scope.
Pack/dequant roundtrip matches the numpy spec reference to 4-5 decimal places.
End-to-end matmul quality (`mx.quantized_matmul` with 3-tier `global_scale_w`): the simdgroup fix in commit 4 delivers the bulk of the improvement for 2-tier users (~20% rel_err reduction). The 3-tier path adds another 2-5% on top, and is necessary for ecosystem checkpoint interop.
End-to-end `nn.quantize` integration on a 4-layer transformer-style model: NVFP4 layers pick up the new `global_scale` attribute.

Trade-off context
On Apple M3 Ultra (decode shape 14336×4096, heavy-tailed weight), NVFP4 2-tier with the simdgroup fix is Pareto-dominant among 4-bit options:
NVFP4 strictly dominates int4 affine g32 in all three dimensions (quality, speed, memory) and is the natural default for 4-bit quantization on Metal.
Scope notes
- CPU and CUDA `quantized_matmul` / `gather_qmm` accept the new `global_scale_w` parameter but currently ignore it. CUDA's spec-compliant 3-tier path remains via `qqmm` / `QQMatmul` (unchanged). Extending CPU and CUDA `quantized_matmul` is future work.
- `QQMatmul` and `GatherQMM` only handle the M=1 decode path today. M>1 (prefill) still throws NYI for the general case — that's an existing limitation, unchanged here.

Testing
Pre-commit hooks (`clang-format`, `black`) run on all changed files. No new tests added to `tests/` to keep the diff focused — happy to add a focused NVFP4 3-tier roundtrip test, or contribute the validation harness as a separate PR.