block_fp8: 2D-block FP8 quantized matmul + MoE kernels (DeepSeek-V3 / MiMo)#3600
block_fp8: 2D-block FP8 quantized matmul + MoE kernels (DeepSeek-V3 / MiMo)#3600yohann-bearzi wants to merge 17 commits into
Conversation
Add BlockFp8 to QuantizationMode enum, string<->enum conversions, default group_size=128/bits=8. Extend validate_quantized_input to recognize uint8 weights + 2D fp32 scales with 128x128 block tiling (DeepSeek-V3 / MiMo convention). No kernels yet — host plumbing only. mx.quantized_matmul with mode='block_fp8' passes validation and fails at kernel dispatch.
Add BlockFp8 to QuantizationMode enum, string<->enum conversions, default group_size=128/bits=8. Extend validate_quantized_input and extract_quantized_matmul_dims to recognize uint8 weights with 2D fp32 block scales (DeepSeek-V3 / MiMo convention). No kernels yet — host plumbing only. mx.quantized_matmul with mode='block_fp8' now reaches kernel dispatch: Unable to load kernel block_fp8_qmv_quad_bfloat16_t_gs_128_b_8_d_128_batch_0
Add fp_block_fp8_qmv_fast_impl + entry point in fp_quantized.h with hardcoded group_size=128, bits=8 (DeepSeek-V3 / MiMo convention). Reuses qdot<U, 8, 8> for E4M3 decode; only the scale geometry is new — 2D fp32 scales indexed by (out_row/128, k/128). Non-batched only; B>1 silently incorrect until adjust_matrix_offsets gets a fp32-scales overload. Instantiate for float, bfloat16_t, float16_t with batched and non-batched variants. Validation: 1x4096 @ 4096x256 with random-quantized weights matches fp32 reference at cos=0.999997, rel_err=0.0023 — at the bf16+fp8 round-off floor.
Add block_fp8_gather_qmv_fast kernel (host gather_qmv dispatch path, M=1 transposed B>=2). Per-token rhs_indices offset-adjusts w/scales to expert slice, then calls block_fp8_qmv_fast_impl unchanged. Also adds block_fp8_adjust_matrix_offsets helper — fp32-scale overload of the existing uint8-scale function. s_strides interpreted in fp32 elements via float*& parameter type, not bytes. Validation: 2-expert decode with random-quantized weights matches per-expert fp32 reference at cos=0.999997, rel_err=0.0023 — same numerical floor as qmv_fast (bf16 + fp8 round-off).
Add _skip_init kwarg to QuantizedLinear.__init__ (in our mlx fork) and QuantizedSwitchLinear.__init__ (in mlx_lm). When True, skip the random-uniform + mx.quantize bootstrap; caller assigns self.weight and self.scales directly from pre-quantized tensors on disk. Avoids a wasteful test-quant + re-quant for native-format checkpoints (block-FP8, etc) and lets loaders skip implementing mx.quantize for modes that have no encode-side kernel. NOTE: QuantizedSwitchLinear lives in /opt/homebrew/.../mlx_lm/, not this repo. Tracking that change as a separate fork-internal edit. Validation: nn.QuantizedLinear(_skip_init=True, mode='block_fp8') with hand-assigned codes+scales forwards through __call__ at cos=0.999997 vs fp32 reference — same numerical floor as direct mx.quantized_matmul.
Xiaomi's MiMo-V2.5 fused qkv_proj weight has actual_rows=13568 but scales has 108 rows (= 128*108 = 13824 weight-equivalent rows), with the extra 2 scale rows being trailing padding from the original tensor-parallel distribution. Kernel correctly ignores these (never indexes past out_row < w.shape[-2]), but the strict validator rejected them. Change: scales.shape >= ceil(w.shape / 128) in both inner dims, instead of equality. Trailing scale-padding is allowed. End-to-end result: MiMo-V2.5 loads at 269 GB resident on M3 Ultra (vs 580 GB for the bf16-upcast path that OOMs), forward pass succeeds, generates one valid token for prompt 'Hello'.
qmm_t_nax kernel doesn't exist for block_fp8 yet. Route block_fp8 through regular qmm path until nax variants are written.
Adds block_fp8_qmm_t kernel (M1: qmv-math repeated per M row, suboptimal but correct). Routes block_fp8 around nax and splitk paths until those have native kernels. Verified end-to-end: chat-templated MiMo-V2.5 prompt produces correct top-1 token '<think>' (42.75) matching affine-4bit baseline.
Set MLX_TRACE_KERNELS=1 to print each kernel name as it's resolved in the quantized dispatcher wrappers. Useful for debugging which kernels get dispatched in a given forward pass and detecting missing/duplicate launches. Zero cost when MLX_TRACE_KERNELS is unset.
Adds block_fp8_gather_qmm_t (M-row qmv math + expert index lookup), matching the structure of block_fp8_qmm_t. Dispatcher overrides grid geometry for block_fp8 mode in gather_qmm(). Bit-exact vs M1 ground truth (cos=1.0000 at every captured layer). MiMo-V2.5 perf on M3 Ultra 512GB: prefill: 6.4 -> 96.4 tok/s (15x) decode: 28.6 tok/s unchanged
Replaces the M3.1 per-row qmm_t kernel with a proper tiled implementation
that uses BlockMMA + a new BlockFp8QuantizedLoader.
The loader is a 2D-scale variant of the existing fp QuantizedBlockLoader
in fp_quantized.h, sibling to the affine and fp variants. Drops into the
standard qmm_t template, gaining:
- 32x weight tile reuse across M rows (vs per-row reads)
- Apple simdgroup_matrix hardware matmul via BlockMMA
- fp32 register accumulation (canonical pattern)
Removes the block_fp8-specific dispatcher grid override; tiled kernel uses
the standard (N/32, M/32, B) geometry with (32, 2, 2) group_dims.
Verified end-to-end on MiMo-V2.5:
prefill: 96.4 -> 117.6 tok/s (+22%)
decode: 29.0 tok/s unchanged (decode path uses qmv not qmm)
top-1 stable at 151667 ('<think>')
Numerical drift vs the per-row kernel is small (cos > 0.99 at all captured
layers, top-1 stable). Both are fp32-faithful implementations; difference
is purely K-loop ordering. Ground truth refreshed to mimo_ground_truth_m4
to lock in the new canonical numerics.
Replaces M3.1 per-row block_fp8_gather_qmm_t with a tiled implementation using block_fp8_adjust_matrix_offsets + block_fp8_qmm_t_impl. Adds the sorted-indices block_fp8_gather_qmm_rhs kernel that was previously missing (prefill > ~256 tokens used to crash with a missing-kernel error). Removes the block_fp8-specific geometry override in the gather_qmm dispatcher; tiled kernels use the standard (N/32, M/32, B) geometry. Verified on MiMo-V2.5 (M3 Ultra 512GB): prefill 29 tokens: 117 tok/s (gather_qmv path, unchanged) prefill 1024 tokens: 615 tok/s (gather_qmm_rhs path, new) prefill 4096 tokens: 493 tok/s (gather_qmm_rhs path, new) Isolated kernel test: cos=0.999997 vs fp32 reference across 5 experts spanning the routed expert range. Deterministic. No NaN. Bit-exact precision profile unchanged for short prefill (gather_qmv path) and decode path. Long prefill now correct and fast.
The sdpa_vector kernel template already supported a separate value head_dim (template <typename T, int D, int V = D>), but no (D, V) pairs with D != V were instantiated, and use_fallback required query_head_dim == value_head_dim. MiMo-V2.5 uses head_dim=192 for Q/K and v_head_dim=128 for V, falling through to a compiled-graph decomposition (multiple GatherAxis dispatches per attention layer) instead of the fused kernel. Adds instantiate_sdpa_vector(type, 192, 128) and relaxes use_fallback to allow this specific asymmetric case. Other head dims remain unchanged. Verified on MiMo-V2.5: decode 28.7 -> 29.8 tok/s (+4%). Top-1 stable. The remaining decode bottleneck is MoE routing, not attention.
Adds block_fp8_qmm_t_splitk, which splits the K dimension across threadgroups when the M*N tile count is too low to saturate the GPU. Each partition computes a partial product into an intermediate buffer; qmm_splitk then sums across the split_k axis. Required refactoring block_fp8_qmm_t_impl to take a separate k_loop_limit parameter distinct from K: K remains the row stride for x/w/scales, while k_loop_limit bounds the K-accumulation loop. The two normal entry points (block_fp8_qmm_t, block_fp8_gather_qmm_t) pass k_loop_limit == K, an exact no-op. The splitk entry passes k_partition_size. Removes the early-return guard in qmm_splitk that previously forced block_fp8 through the non-split path. The MiMo-V2.5 workload never triggers splitk (its prefill matmuls have M*N tile counts that already fill the GPU, and decode goes through the vector path), so this is parity work for the block_fp8 mode rather than a MiMo speedup. Verified correct on a forced-splitk shape (M=32, N=128, K=4096, split_k=8): cos=0.999994 vs the validated vector path, with max-diff consistent with fp32 accumulation-order drift. MiMo top-1 stable at 151667.
…ale format
The block_fp8 codepaths in fp_quantize/fp_dequantize previously inherited
the 1D-group / E8M0-scale layout from mxfp8, which does not match the
DeepSeek-V3 / MiMo convention this PR's kernels consume (2D 128x128 blocks
with per-block float32 scales). dequantize() additionally had a hardcoded
uint32 weight-dtype guard that gated block_fp8 (unpacked uint8 E4M3 codes)
out of its own path.
This commit:
* fp_quantize: adds a block_fp8 branch that round-to-nearest quantizes
weights using per-block max-abs scaling (scale = amax / 448) and casts
to E4M3 via to_fp8. Output is uint8 codes [..., N, K] + float32 scales
[..., N/gs, K/gs], matching the on-disk format the inference kernels
already consume.
* fp_dequantize: adds a block_fp8 branch that decodes from_fp8(codes)
broadcast-multiplied by the per-block scale, returning the original
floating dtype.
* dequantize: makes the weight-dtype guard mode-aware so uint8 codes
reach the block_fp8 branch (uint32 still required for all other modes).
Round-trip cos > 0.999 on real MiMo weights; quantized_matmul agrees with
dequantize-then-matmul at cos > 0.999996 across M=1/4/32/256 and the MoE
gather paths. Re-quantizing a dequantized tensor produces identical codes
(stable quantization decision, required by iterative quantizers like GPTQ).
Five test methods covering the block_fp8 work:
* test_block_fp8_quantize_dequantize: round-trip cos > 0.99 across
several N,K shapes; verifies output dtypes/shapes are correct.
* test_block_fp8_quantize_idempotent: re-quantizing a dequantized
tensor must reproduce identical codes (with scales agreeing to
float-rounding tolerance). This is the stability property iterative
quantizers depend on.
* test_block_fp8_outlier_block_scaling: a single outlier block must
not corrupt precision in unrelated blocks - the point of per-block
(vs per-tensor) scaling.
* test_block_fp8_invalid_shapes: non-128-multiple dims rejected by
quantize; uint32 weight rejected by dequantize for block_fp8.
* test_block_fp8_kernel_coverage: single test exercising every kernel
path shipped by this PR against the dequantize reference. A regression
in any kernel makes the corresponding subTest fail. Subtests:
- block_fp8_qmv_fast (M=1 decode vector)
- block_fp8_qmm_t (M=64 tiled prefill)
- block_fp8_qmm_t_splitk (short-M, large-K splitk path)
- block_fp8_qmm_t_large (M=256 deep tiled)
- block_fp8_gather_qmv_fast (T=1 MoE decode vector)
- block_fp8_gather_qmm_rhs (T=32 MoE prefill, global-sort contract)
- sdpa_fused_asymmetric_192_128 (fused vector SDPA at MiMo head dims)
- padded_scale_validator (relaxed validator for fused-QKV)
|
For anyone who wants to try this end-to-end on a real model, I've published a ready-to-run MiMo-V2.5 checkpoint: https://huggingface.co/bearzi/MiMo-V2.5-MLX It's the text-only path of XiaomiMiMo/MiMo-V2.5 (310B MoE, 256 experts top-8, hybrid 5:1 SWA/GA, block_fp8) repacked for MLX. Weight values are bit-identical to upstream — the only transform is concatenating the 32 source shards and stacking the per-expert MoE tensors along a new leading axis (required by MLX MoE loaders). The repo also ships the MLX model class ( Note that running it end-to-end also requires the unmerged mlx-lm#1219 (base |
Xiaomi MiMo-V2.5 ships a fused qkv_proj whose block_fp8 scales are padded independently per TP shard (tp=4): GA-layer scales have tp*ceil((N/tp)/128) rows, not ceil(N/128). The qmv_fast/qmm_t kernels indexed scales flat (scale_row = out_row/128), so every shard past the first read the wrong scale block, corrupting K/V and causing generation to drift (English -> Russian/Chinese) on long/hard prompts while passing short high-margin ones. Fix: detect per-shard padding from scales.shape(-2) and recover tp, then index scale_row = shard*shard_srows + (row%shard_rows)/128. Detection defaults to flat (shard_rows=N), so all non-TP-padded tensors (SWA qkv, o_proj, MoE, lm_head) are bit-identical. scale_rows passed by-value into the two impls; entry kernels read it from a buffer via the dispatcher. Validated on MiMo-V2.5: GA-layer prefill-vs-decode logit diff dropped 2.71875 -> fp8 floor on early layers; HumanEval prompt 0 now decodes clean English/Python.
block_fp8: 2D-block FP8 quantized matmul + MoE kernels for DeepSeek-V3 / MiMo
Summary
Adds first-class support for the block_fp8 quantization format used by
DeepSeek-V3 / MiMo (and other recent MoE LLMs): 2D 128×128 weight blocks of
unpacked E4M3 (
F8_E4M3) codes with a per-blockfloat32scale. Implementsthe kernel set needed to run inference end-to-end (decode + prefill, dense +
MoE gather), along with
quantize/dequantizeops so the format isproducible and consumable in-library.
Validated on the MiMo-V2.5 base (310 B, 48 layers, 256 experts top-8, hybrid
5:1 SWA/GA, head_dim 192/128) on M3 Ultra: 30.3 tok/s decode, with the
canonical greedy continuation reproduced bit-for-bit (top-1 = 151667 on the
Einstein verification prompt).
What's in this PR
Kernels (block_fp8 quantization mode)
block_fp8_qmv_fast— M = 1 decode vector kernel.block_fp8_qmm_t— tiled prefill matmul using a newBlockFp8QuantizedLoader(handles unpacked E4M3 + 2DF32block scales).block_fp8_qmm_t_splitk— short-M splitk variant for shapes where theM·N tile count can't fill the GPU (covers the verify-style 4-token
forwards used by speculative decoding).
block_fp8_gather_qmv_fast— MoE decode vector kernel (top-K routing,single token).
block_fp8_gather_qmm_t/block_fp8_gather_qmm_rhs— MoE prefilltiled kernels (the
rhsvariant matches the SwitchGLU global-sortcontract used by
mlx_lm).Dispatcher / plumbing
block_fp8mode threaded throughquantized_matmul,gather_qmm,dequantize, and the kernel-name selection (quantized.cpp).validate_quantized_input) that acceptsthe 2D
F32block-scale layout and the fused-QKV padding convention(scale tensor may include trailing rows beyond
ceil(w / 128), which thekernel never reads).
_skip_initplumbing forQuantizedLinear/QuantizedSwitchLinearsopre-quantized weights can be installed without re-allocating placeholder
buffers.
SDPA
sdpa: the fused vector kernel now handles asymmetric Q/Vhead_dim(192/128) used by MiMo's SWA blocks. Previously gated by an equality
check that excluded this shape.
quantize/dequantizeops (this commit's diff)Prior to this PR,
mx.quantize/mx.dequantizewithmode="block_fp8"inherited the 1D-group / E8M0-scale fallback from
mxfp8, which does notmatch the 2D 128×128 /
F32-scale format the kernels consume;dequantizeadditionally rejected
uint8weights via a hardcodeduint32guardunrelated to the mode validator.
Three minimal edits in
mlx/ops.cpp:fp_quantize: adds ablock_fp8branch performing per-block round-to-nearest quantization (
scale = max(|W_block|) / 448, codes viato_fp8(W/scale)). Returnsuint8codes[..., N, K]plusfloat32scales
[..., N/gs, K/gs], identical to the on-disk format the kernelsalready consume.
fp_dequantize: adds ablock_fp8branch that reconstructs the weightas
from_fp8(codes) * scalewith the per-block scale broadcast over theblock grid.
dequantize: the post-validator weight-dtype guard is made mode-aware(
uint8for block_fp8,uint32otherwise). The behavior of everynon-block_fp8 mode is byte-identical.
Forward-quantize uses only
to_fp8and shape ops already present inops.cpp— no new Metal kernel is required (the format is producible fromthe existing primitives).
Validation
Kernel coverage test (new)
python/tests/test_quantized.py::TestQuantized::test_block_fp8_kernel_coverageis a single test with one subtest per kernel; a regression in any kernel
fails the corresponding subTest:
block_fp8_qmv_fastblock_fp8_qmm_tblock_fp8_qmm_t_splitkblock_fp8_qmm_t_largeblock_fp8_gather_qmv_fastblock_fp8_gather_qmm_rhssdpa_fused_asymmetric_192_128padded_scale_validatorceil(w/128)All subtests compare against
mx.dequantize-then-reference-matmul or theunfused SDPA reference; threshold
cos > 0.99.Round-trip / idempotency
test_block_fp8_quantize_dequantize— quantize/dequantize round-tripacross several
(N, K)shapes.test_block_fp8_quantize_idempotent— re-quantizing a dequantized tensorreproduces identical codes (with scales matching to float-rounding
tolerance). This is the stability property iterative quantizers (e.g.
GPTQ) require to converge.
test_block_fp8_outlier_block_scaling— a single outlier block must notcorrupt the precision of unrelated blocks (the rationale for per-block
vs per-tensor scaling).
test_block_fp8_invalid_shapes— non-128-multiple dims rejected byquantize;uint32weights rejected bydequantize(block_fp8).Numerics on real weights
quantized_matmul(codes, scales)agrees withx @ dequantize(codes, scales).Tatcos > 0.99999on real MiMoprojections, across the decode (M = 1), prefill (M = 256), and MoE gather
paths. The numpy-RTN E4M3 reference (computed offline) matches the
in-library
quantizecodes 100 % on the same inputs.End-to-end model
On the MiMo-V2.5 (310 B, block_fp8) checkpoint, M3 Ultra 512 GB, wired
memory limit 300 GB:
token on the Einstein verification prompt (top-1 = 151667 at position 29,
with the expected continuation
[785, 785, 66622, 54052, 320, 23, 22, 24, 4142, …]).No regressions in existing modes
affine/mxfp4/mxfp8/nvfp4quantize/dequantize pathsare byte-identical (the
block_fp8branches are entered first via modecomparison and
returnbefore touching the existing code).test_quantized.py(e.g.test_qvm/test_qmmatbits=6) are independent of this PR; they reproduce onthe merge-base without these changes, on M3 Ultra (
arch_gen=15).Likely M4+/NAX-gated paths that need an arch decorator in a separate
commit.
Hardware notes (M3 Ultra)
NAX is
applegpu_g17+(M4+ on macOS 26.2+); on M3 Ultrais_nax_available()returns false, and this PR's dispatcher explicitly skips NAX routing for
block_fp8. The MLX_TRACE_KERNELS env var (added earlier in the branch) was
useful during development for confirming kernel selection at runtime.