Skip to content

block_fp8: 2D-block FP8 quantized matmul + MoE kernels (DeepSeek-V3 / MiMo)#3600

Open
yohann-bearzi wants to merge 17 commits into
ml-explore:mainfrom
yohann-bearzi:fp8-block-mvp
Open

block_fp8: 2D-block FP8 quantized matmul + MoE kernels (DeepSeek-V3 / MiMo)#3600
yohann-bearzi wants to merge 17 commits into
ml-explore:mainfrom
yohann-bearzi:fp8-block-mvp

Conversation

@yohann-bearzi
Copy link
Copy Markdown

@yohann-bearzi yohann-bearzi commented May 28, 2026

block_fp8: 2D-block FP8 quantized matmul + MoE kernels for DeepSeek-V3 / MiMo

Summary

Adds first-class support for the block_fp8 quantization format used by
DeepSeek-V3 / MiMo (and other recent MoE LLMs): 2D 128×128 weight blocks of
unpacked E4M3 (F8_E4M3) codes with a per-block float32 scale. Implements
the kernel set needed to run inference end-to-end (decode + prefill, dense +
MoE gather), along with quantize / dequantize ops so the format is
producible and consumable in-library.

Validated on the MiMo-V2.5 base (310 B, 48 layers, 256 experts top-8, hybrid
5:1 SWA/GA, head_dim 192/128) on M3 Ultra: 30.3 tok/s decode, with the
canonical greedy continuation reproduced bit-for-bit (top-1 = 151667 on the
Einstein verification prompt).

What's in this PR

Kernels (block_fp8 quantization mode)

  • block_fp8_qmv_fast — M = 1 decode vector kernel.
  • block_fp8_qmm_t — tiled prefill matmul using a new
    BlockFp8QuantizedLoader (handles unpacked E4M3 + 2D F32 block scales).
  • block_fp8_qmm_t_splitk — short-M splitk variant for shapes where the
    M·N tile count can't fill the GPU (covers the verify-style 4-token
    forwards used by speculative decoding).
  • block_fp8_gather_qmv_fast — MoE decode vector kernel (top-K routing,
    single token).
  • block_fp8_gather_qmm_t / block_fp8_gather_qmm_rhs — MoE prefill
    tiled kernels (the rhs variant matches the SwitchGLU global-sort
    contract used by mlx_lm).

Dispatcher / plumbing

  • block_fp8 mode threaded through quantized_matmul, gather_qmm,
    dequantize, and the kernel-name selection (quantized.cpp).
  • 2D-block-aware shape validator (validate_quantized_input) that accepts
    the 2D F32 block-scale layout and the fused-QKV padding convention
    (scale tensor may include trailing rows beyond ceil(w / 128), which the
    kernel never reads).
  • _skip_init plumbing for QuantizedLinear / QuantizedSwitchLinear so
    pre-quantized weights can be installed without re-allocating placeholder
    buffers.

SDPA

  • sdpa: the fused vector kernel now handles asymmetric Q/V head_dim
    (192/128) used by MiMo's SWA blocks. Previously gated by an equality
    check that excluded this shape.

quantize / dequantize ops (this commit's diff)

Prior to this PR, mx.quantize / mx.dequantize with mode="block_fp8"
inherited the 1D-group / E8M0-scale fallback from mxfp8, which does not
match the 2D 128×128 / F32-scale format the kernels consume; dequantize
additionally rejected uint8 weights via a hardcoded uint32 guard
unrelated to the mode validator.

Three minimal edits in mlx/ops.cpp:

  • fp_quantize: adds a block_fp8 branch performing per-block round-to-
    nearest quantization (scale = max(|W_block|) / 448, codes via
    to_fp8(W/scale)). Returns uint8 codes [..., N, K] plus float32
    scales [..., N/gs, K/gs], identical to the on-disk format the kernels
    already consume.
  • fp_dequantize: adds a block_fp8 branch that reconstructs the weight
    as from_fp8(codes) * scale with the per-block scale broadcast over the
    block grid.
  • dequantize: the post-validator weight-dtype guard is made mode-aware
    (uint8 for block_fp8, uint32 otherwise). The behavior of every
    non-block_fp8 mode is byte-identical.

Forward-quantize uses only to_fp8 and shape ops already present in
ops.cpp — no new Metal kernel is required (the format is producible from
the existing primitives).

Validation

Kernel coverage test (new)

python/tests/test_quantized.py::TestQuantized::test_block_fp8_kernel_coverage
is a single test with one subtest per kernel; a regression in any kernel
fails the corresponding subTest:

subTest exercises
block_fp8_qmv_fast M = 1 decode vector
block_fp8_qmm_t M = 64 tiled prefill
block_fp8_qmm_t_splitk M = 4, K = 4096 short-M splitk path
block_fp8_qmm_t_large M = 256 large tiled
block_fp8_gather_qmv_fast T = 1, KP = 4 MoE decode
block_fp8_gather_qmm_rhs T = 32, KP = 4 MoE prefill, global-sort contract
sdpa_fused_asymmetric_192_128 fused vector SDPA at MiMo head dims
padded_scale_validator scales tensor with trailing rows beyond ceil(w/128)

All subtests compare against mx.dequantize-then-reference-matmul or the
unfused SDPA reference; threshold cos > 0.99.

Round-trip / idempotency

  • test_block_fp8_quantize_dequantize — quantize/dequantize round-trip
    across several (N, K) shapes.
  • test_block_fp8_quantize_idempotent — re-quantizing a dequantized tensor
    reproduces identical codes (with scales matching to float-rounding
    tolerance). This is the stability property iterative quantizers (e.g.
    GPTQ) require to converge.
  • test_block_fp8_outlier_block_scaling — a single outlier block must not
    corrupt the precision of unrelated blocks (the rationale for per-block
    vs per-tensor scaling).
  • test_block_fp8_invalid_shapes — non-128-multiple dims rejected by
    quantize; uint32 weights rejected by dequantize(block_fp8).

Numerics on real weights

quantized_matmul(codes, scales) agrees with
x @ dequantize(codes, scales).T at cos > 0.99999 on real MiMo
projections, across the decode (M = 1), prefill (M = 256), and MoE gather
paths. The numpy-RTN E4M3 reference (computed offline) matches the
in-library quantize codes 100 % on the same inputs.

End-to-end model

On the MiMo-V2.5 (310 B, block_fp8) checkpoint, M3 Ultra 512 GB, wired
memory limit 300 GB:

  • decode 30.3 tok/s (single-stream, batch = 1, no speculative decoding)
  • prefill ≈ 525 tok/s at L = 2048
  • greedy continuation reproduces the bf16-reference continuation token-for-
    token on the Einstein verification prompt (top-1 = 151667 at position 29,
    with the expected continuation [785, 785, 66622, 54052, 320, 23, 22, 24, 4142, …]).

No regressions in existing modes

  • Existing affine / mxfp4 / mxfp8 / nvfp4 quantize/dequantize paths
    are byte-identical (the block_fp8 branches are entered first via mode
    comparison and return before touching the existing code).
  • Pre-existing test failures in test_quantized.py (e.g. test_qvm /
    test_qmm at bits=6) are independent of this PR; they reproduce on
    the merge-base without these changes, on M3 Ultra (arch_gen=15).
    Likely M4+/NAX-gated paths that need an arch decorator in a separate
    commit.

Hardware notes (M3 Ultra)

NAX is applegpu_g17+ (M4+ on macOS 26.2+); on M3 Ultra is_nax_available()
returns false, and this PR's dispatcher explicitly skips NAX routing for
block_fp8. The MLX_TRACE_KERNELS env var (added earlier in the branch) was
useful during development for confirming kernel selection at runtime.

Add BlockFp8 to QuantizationMode enum, string<->enum conversions,
default group_size=128/bits=8. Extend validate_quantized_input to
recognize uint8 weights + 2D fp32 scales with 128x128 block tiling
(DeepSeek-V3 / MiMo convention). No kernels yet — host plumbing
only. mx.quantized_matmul with mode='block_fp8' passes validation
and fails at kernel dispatch.
Add BlockFp8 to QuantizationMode enum, string<->enum conversions,
default group_size=128/bits=8. Extend validate_quantized_input and
extract_quantized_matmul_dims to recognize uint8 weights with 2D
fp32 block scales (DeepSeek-V3 / MiMo convention). No kernels yet
— host plumbing only. mx.quantized_matmul with mode='block_fp8'
now reaches kernel dispatch:
  Unable to load kernel block_fp8_qmv_quad_bfloat16_t_gs_128_b_8_d_128_batch_0
Add fp_block_fp8_qmv_fast_impl + entry point in fp_quantized.h with
hardcoded group_size=128, bits=8 (DeepSeek-V3 / MiMo convention).
Reuses qdot<U, 8, 8> for E4M3 decode; only the scale geometry is
new — 2D fp32 scales indexed by (out_row/128, k/128). Non-batched
only; B>1 silently incorrect until adjust_matrix_offsets gets a
fp32-scales overload.

Instantiate for float, bfloat16_t, float16_t with batched and
non-batched variants.

Validation: 1x4096 @ 4096x256 with random-quantized weights
matches fp32 reference at cos=0.999997, rel_err=0.0023 — at the
bf16+fp8 round-off floor.
Add block_fp8_gather_qmv_fast kernel (host gather_qmv dispatch path,
M=1 transposed B>=2). Per-token rhs_indices offset-adjusts w/scales
to expert slice, then calls block_fp8_qmv_fast_impl unchanged.

Also adds block_fp8_adjust_matrix_offsets helper — fp32-scale overload
of the existing uint8-scale function. s_strides interpreted in fp32
elements via float*& parameter type, not bytes.

Validation: 2-expert decode with random-quantized weights matches
per-expert fp32 reference at cos=0.999997, rel_err=0.0023 — same
numerical floor as qmv_fast (bf16 + fp8 round-off).
Add _skip_init kwarg to QuantizedLinear.__init__ (in our mlx fork) and
QuantizedSwitchLinear.__init__ (in mlx_lm). When True, skip the
random-uniform + mx.quantize bootstrap; caller assigns self.weight
and self.scales directly from pre-quantized tensors on disk.

Avoids a wasteful test-quant + re-quant for native-format checkpoints
(block-FP8, etc) and lets loaders skip implementing mx.quantize for
modes that have no encode-side kernel.

NOTE: QuantizedSwitchLinear lives in /opt/homebrew/.../mlx_lm/, not
this repo. Tracking that change as a separate fork-internal edit.

Validation: nn.QuantizedLinear(_skip_init=True, mode='block_fp8') with
hand-assigned codes+scales forwards through __call__ at cos=0.999997
vs fp32 reference — same numerical floor as direct mx.quantized_matmul.
Xiaomi's MiMo-V2.5 fused qkv_proj weight has actual_rows=13568 but
scales has 108 rows (= 128*108 = 13824 weight-equivalent rows), with
the extra 2 scale rows being trailing padding from the original
tensor-parallel distribution. Kernel correctly ignores these (never
indexes past out_row < w.shape[-2]), but the strict validator
rejected them.

Change: scales.shape >= ceil(w.shape / 128) in both inner dims,
instead of equality. Trailing scale-padding is allowed.

End-to-end result: MiMo-V2.5 loads at 269 GB resident on M3 Ultra
(vs 580 GB for the bf16-upcast path that OOMs), forward pass
succeeds, generates one valid token for prompt 'Hello'.
qmm_t_nax kernel doesn't exist for block_fp8 yet. Route block_fp8
through regular qmm path until nax variants are written.
Adds block_fp8_qmm_t kernel (M1: qmv-math repeated per M row, suboptimal
but correct). Routes block_fp8 around nax and splitk paths until those
have native kernels.

Verified end-to-end: chat-templated MiMo-V2.5 prompt produces correct
top-1 token '<think>' (42.75) matching affine-4bit baseline.
Set MLX_TRACE_KERNELS=1 to print each kernel name as it's resolved in
the quantized dispatcher wrappers. Useful for debugging which kernels
get dispatched in a given forward pass and detecting missing/duplicate
launches.

Zero cost when MLX_TRACE_KERNELS is unset.
Adds block_fp8_gather_qmm_t (M-row qmv math + expert index lookup),
matching the structure of block_fp8_qmm_t. Dispatcher overrides grid
geometry for block_fp8 mode in gather_qmm().

Bit-exact vs M1 ground truth (cos=1.0000 at every captured layer).

MiMo-V2.5 perf on M3 Ultra 512GB:
  prefill: 6.4 -> 96.4 tok/s  (15x)
  decode:  28.6 tok/s unchanged
Replaces the M3.1 per-row qmm_t kernel with a proper tiled implementation
that uses BlockMMA + a new BlockFp8QuantizedLoader.

The loader is a 2D-scale variant of the existing fp QuantizedBlockLoader
in fp_quantized.h, sibling to the affine and fp variants. Drops into the
standard qmm_t template, gaining:
  - 32x weight tile reuse across M rows (vs per-row reads)
  - Apple simdgroup_matrix hardware matmul via BlockMMA
  - fp32 register accumulation (canonical pattern)

Removes the block_fp8-specific dispatcher grid override; tiled kernel uses
the standard (N/32, M/32, B) geometry with (32, 2, 2) group_dims.

Verified end-to-end on MiMo-V2.5:
  prefill: 96.4 -> 117.6 tok/s  (+22%)
  decode:  29.0 tok/s unchanged (decode path uses qmv not qmm)
  top-1 stable at 151667 ('<think>')

Numerical drift vs the per-row kernel is small (cos > 0.99 at all captured
layers, top-1 stable). Both are fp32-faithful implementations; difference
is purely K-loop ordering. Ground truth refreshed to mimo_ground_truth_m4
to lock in the new canonical numerics.
Replaces M3.1 per-row block_fp8_gather_qmm_t with a tiled implementation
using block_fp8_adjust_matrix_offsets + block_fp8_qmm_t_impl. Adds the
sorted-indices block_fp8_gather_qmm_rhs kernel that was previously missing
(prefill > ~256 tokens used to crash with a missing-kernel error).

Removes the block_fp8-specific geometry override in the gather_qmm
dispatcher; tiled kernels use the standard (N/32, M/32, B) geometry.

Verified on MiMo-V2.5 (M3 Ultra 512GB):
  prefill 29 tokens:    117 tok/s (gather_qmv path, unchanged)
  prefill 1024 tokens:  615 tok/s (gather_qmm_rhs path, new)
  prefill 4096 tokens:  493 tok/s (gather_qmm_rhs path, new)

Isolated kernel test: cos=0.999997 vs fp32 reference across 5 experts
spanning the routed expert range. Deterministic. No NaN.

Bit-exact precision profile unchanged for short prefill (gather_qmv path)
and decode path. Long prefill now correct and fast.
The sdpa_vector kernel template already supported a separate value
head_dim (template <typename T, int D, int V = D>), but no (D, V) pairs
with D != V were instantiated, and use_fallback required
query_head_dim == value_head_dim.

MiMo-V2.5 uses head_dim=192 for Q/K and v_head_dim=128 for V, falling
through to a compiled-graph decomposition (multiple GatherAxis dispatches
per attention layer) instead of the fused kernel.

Adds instantiate_sdpa_vector(type, 192, 128) and relaxes use_fallback to
allow this specific asymmetric case. Other head dims remain unchanged.

Verified on MiMo-V2.5: decode 28.7 -> 29.8 tok/s (+4%). Top-1 stable.
The remaining decode bottleneck is MoE routing, not attention.
Adds block_fp8_qmm_t_splitk, which splits the K dimension across
threadgroups when the M*N tile count is too low to saturate the GPU.
Each partition computes a partial product into an intermediate buffer;
qmm_splitk then sums across the split_k axis.

Required refactoring block_fp8_qmm_t_impl to take a separate
k_loop_limit parameter distinct from K: K remains the row stride for
x/w/scales, while k_loop_limit bounds the K-accumulation loop. The two
normal entry points (block_fp8_qmm_t, block_fp8_gather_qmm_t) pass
k_loop_limit == K, an exact no-op. The splitk entry passes
k_partition_size.

Removes the early-return guard in qmm_splitk that previously forced
block_fp8 through the non-split path.

The MiMo-V2.5 workload never triggers splitk (its prefill matmuls have
M*N tile counts that already fill the GPU, and decode goes through the
vector path), so this is parity work for the block_fp8 mode rather than
a MiMo speedup. Verified correct on a forced-splitk shape (M=32, N=128,
K=4096, split_k=8): cos=0.999994 vs the validated vector path, with
max-diff consistent with fp32 accumulation-order drift. MiMo top-1
stable at 151667.
…ale format

The block_fp8 codepaths in fp_quantize/fp_dequantize previously inherited
the 1D-group / E8M0-scale layout from mxfp8, which does not match the
DeepSeek-V3 / MiMo convention this PR's kernels consume (2D 128x128 blocks
with per-block float32 scales). dequantize() additionally had a hardcoded
uint32 weight-dtype guard that gated block_fp8 (unpacked uint8 E4M3 codes)
out of its own path.

This commit:

  * fp_quantize: adds a block_fp8 branch that round-to-nearest quantizes
    weights using per-block max-abs scaling (scale = amax / 448) and casts
    to E4M3 via to_fp8. Output is uint8 codes [..., N, K] + float32 scales
    [..., N/gs, K/gs], matching the on-disk format the inference kernels
    already consume.

  * fp_dequantize: adds a block_fp8 branch that decodes from_fp8(codes)
    broadcast-multiplied by the per-block scale, returning the original
    floating dtype.

  * dequantize: makes the weight-dtype guard mode-aware so uint8 codes
    reach the block_fp8 branch (uint32 still required for all other modes).

Round-trip cos > 0.999 on real MiMo weights; quantized_matmul agrees with
dequantize-then-matmul at cos > 0.999996 across M=1/4/32/256 and the MoE
gather paths. Re-quantizing a dequantized tensor produces identical codes
(stable quantization decision, required by iterative quantizers like GPTQ).
Five test methods covering the block_fp8 work:

  * test_block_fp8_quantize_dequantize: round-trip cos > 0.99 across
    several N,K shapes; verifies output dtypes/shapes are correct.

  * test_block_fp8_quantize_idempotent: re-quantizing a dequantized
    tensor must reproduce identical codes (with scales agreeing to
    float-rounding tolerance). This is the stability property iterative
    quantizers depend on.

  * test_block_fp8_outlier_block_scaling: a single outlier block must
    not corrupt precision in unrelated blocks - the point of per-block
    (vs per-tensor) scaling.

  * test_block_fp8_invalid_shapes: non-128-multiple dims rejected by
    quantize; uint32 weight rejected by dequantize for block_fp8.

  * test_block_fp8_kernel_coverage: single test exercising every kernel
    path shipped by this PR against the dequantize reference. A regression
    in any kernel makes the corresponding subTest fail. Subtests:
      - block_fp8_qmv_fast       (M=1 decode vector)
      - block_fp8_qmm_t          (M=64 tiled prefill)
      - block_fp8_qmm_t_splitk   (short-M, large-K splitk path)
      - block_fp8_qmm_t_large    (M=256 deep tiled)
      - block_fp8_gather_qmv_fast (T=1 MoE decode vector)
      - block_fp8_gather_qmm_rhs  (T=32 MoE prefill, global-sort contract)
      - sdpa_fused_asymmetric_192_128 (fused vector SDPA at MiMo head dims)
      - padded_scale_validator        (relaxed validator for fused-QKV)
@yohann-bearzi
Copy link
Copy Markdown
Author

For anyone who wants to try this end-to-end on a real model, I've published a ready-to-run MiMo-V2.5 checkpoint:

https://huggingface.co/bearzi/MiMo-V2.5-MLX

It's the text-only path of XiaomiMiMo/MiMo-V2.5 (310B MoE, 256 experts top-8, hybrid 5:1 SWA/GA, block_fp8) repacked for MLX. Weight values are bit-identical to upstream — the only transform is concatenating the 32 source shards and stacking the per-expert MoE tensors along a new leading axis (required by MLX MoE loaders). The repo also ships the MLX model class (mimo_v2_block_fp8.py) and the converter script, so anyone with Xiaomi's release can reproduce the file in ~10 minutes.

Note that running it end-to-end also requires the unmerged mlx-lm#1219 (base mimo_v2 model class) — the README in the HF repo documents the full prerequisite chain.

Xiaomi MiMo-V2.5 ships a fused qkv_proj whose block_fp8 scales are padded
independently per TP shard (tp=4): GA-layer scales have tp*ceil((N/tp)/128)
rows, not ceil(N/128). The qmv_fast/qmm_t kernels indexed scales flat
(scale_row = out_row/128), so every shard past the first read the wrong
scale block, corrupting K/V and causing generation to drift (English ->
Russian/Chinese) on long/hard prompts while passing short high-margin ones.

Fix: detect per-shard padding from scales.shape(-2) and recover tp, then
index scale_row = shard*shard_srows + (row%shard_rows)/128. Detection
defaults to flat (shard_rows=N), so all non-TP-padded tensors (SWA qkv,
o_proj, MoE, lm_head) are bit-identical. scale_rows passed by-value into
the two impls; entry kernels read it from a buffer via the dispatcher.

Validated on MiMo-V2.5: GA-layer prefill-vs-decode logit diff dropped
2.71875 -> fp8 floor on early layers; HumanEval prompt 0 now decodes
clean English/Python.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant