Add 1-bit affine quantization support (Metal) #3161

Open
khosravipasha wants to merge 2 commits into ml-explore:main from khosravipasha:1bit-affine-quantization

Conversation

@khosravipasha

Add 1-bit affine quantization support (Metal)

Proposed changes

This PR adds 1-bit support to MLX's affine quantization mode, extending the supported bit-widths from {2, 3, 4, 5, 6, 8} to {1, 2, 3, 4, 5, 6, 8}.

MLX already supports affine quantization at 2, 3, 4, 5, 6, and 8 bits via w_hat = scale * w_q + bias. This PR extends that same framework to 1-bit, adding full kernel support for 1-bit affine dequantization and quantized matmul across CPU and Metal backends.

This assumes the model has already been quantized externally (e.g. during training) — the contribution here is efficient packing and inference on Apple Silicon. It supports packing for both affine and symmetric 1-bit weights:

Affine 1-bit — weights have arbitrary per-group min/max:

scale = w_max - w_min,  bias = w_min
bit 0 → w_min,  bit 1 → w_max

Symmetric 1-bit — weights are {-d, +d} per group, automatically handled by the affine formula above since w_min = -d, w_max = +d:

scale = w_max - w_min = 2d,  bias = w_min = -d
bit 0 → 0·(2d) + (-d) = -d
bit 1 → 1·(2d) + (-d) = +d

A dedicated symmetric 1-bit mode (scale only, no bias) could save memory and skip the bias addition in the matmul kernels, but for now both cases run through the same affine path.
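
As a rough illustration of the affine path above, here is a NumPy sketch of 1-bit quantize/dequantize (the function names and the midpoint rounding rule are illustrative, not the actual MLX kernels):

```python
import numpy as np

def quantize_1bit(w, group_size=128, eps=1e-6):
    # scale = w_max - w_min, bias = w_min, per group of `group_size` weights
    g = w.reshape(-1, group_size)
    w_min = g.min(axis=1, keepdims=True)
    w_max = g.max(axis=1, keepdims=True)
    scale = np.maximum(w_max - w_min, eps)  # floor at eps for all-zero groups
    bias = w_min
    # round to the nearer endpoint: bit 0 -> w_min, bit 1 -> w_max
    q = (g >= (w_min + w_max) / 2).astype(np.uint8)
    return q, scale, bias

def dequantize_1bit(q, scale, bias):
    # w_hat = scale * w_q + bias
    return scale * q + bias

# Symmetric weights {-d, +d} round-trip exactly: scale = 2d, bias = -d
w = np.tile(np.float32([0.5, -0.5]), 256)  # 512 values, 4 groups of 128
q, scale, bias = quantize_1bit(w)
assert np.allclose(dequantize_1bit(q, scale, bias).reshape(-1), w)
```

Any weight set with exactly two values per group round-trips losslessly through this formula, which is why the symmetric case needs no dedicated path.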

What's included

  • CPU backend: 1-bit quantize, dequantize, and quantized matmul (qmm dispatch)
  • Metal backend: Full 1-bit support in all quantized kernels — both non-NAX (quantized.h) and NAX (quantized_nax.h, quantized_nax.metal) paths, plus quantize/dequantize kernels
  • Python bindings: Updated mx.quantize(w, bits=1), mx.dequantize(...), and mx.quantized_matmul(...) documentation
  • Unit tests: Added 1-bit to test_quantize_dequantize, test_qmm, and a dedicated test_1bit_quantize_dequantize covering round-trip accuracy, zero handling, and quantized matmul correctness. Full test suite passes (672 tests, 0 failures).
  • No CUDA support: 1-bit is not yet supported on the CUDA backend. The CUDA dispatch_bits does not include a case 1: path. The new 1-bit test is added to cuda_skip.py.
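
The 1-bit packing mentioned above can be sketched as 32 weights per uint32 word (the LSB-first layout here is a hypothetical illustration; the actual kernel packing may order bits differently):

```python
import numpy as np

def pack_1bit(q):
    # pack 32 one-bit values (LSB-first) into each uint32 word
    q = np.asarray(q, dtype=np.uint32).reshape(-1, 32)
    shifts = np.arange(32, dtype=np.uint32)
    return np.bitwise_or.reduce(q << shifts, axis=1)

def unpack_1bit(packed):
    # recover the original 0/1 values from the packed words
    shifts = np.arange(32, dtype=np.uint32)
    return ((packed[:, None] >> shifts) & 1).astype(np.uint8).reshape(-1)

q = np.random.randint(0, 2, size=4096).astype(np.uint8)
packed = pack_1bit(q)
assert packed.nbytes == q.size // 8          # 8 weights per byte
assert np.array_equal(unpack_1bit(packed), q)
```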

Expected model-level performance (hypothetical, 8B parameter model, Apple M4 Pro 48 GB)

Based on the kernel-level benchmarks below, a hypothetical 8B parameter model at 1-bit would see roughly (varying by group size due to scale/bias metadata overhead):

| Configuration | Memory | Expected throughput |
| --- | --- | --- |
| FP16 (baseline) | ~15.3 GB | ~15 tok/s |
| 1-bit (group size 128) | ~1.3 GB | ~100–130 tok/s |
| 1-bit (group size 64) | ~1.6 GB | ~90–115 tok/s |
| 1-bit (group size 32) | ~2.0 GB | ~80–100 tok/s |

We verified one scenario (group size 128, all weights quantized) and observed throughput in the ballpark of the estimates above. The primary purpose of this table is to give a sense of the runtime speed that 1-bit quantization enables. These are back-of-the-envelope numbers — actual end-to-end performance will vary depending on which layers are quantized, group size, attention overhead, and other non-quantized computation.
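
The memory column follows from a simple bits-per-weight calculation, sketched below under the assumption of one fp16 scale and one fp16 bias per group (the metadata layout is an assumption for illustration):

```python
def effective_bits(bits, group_size, meta_bits=16):
    # each group carries one scale and one bias alongside the packed weights
    return bits + 2 * meta_bits / group_size

params = 8.2e9  # hypothetical 8B-class parameter count
for gs in (128, 64, 32):
    bpw = effective_bits(1, gs)
    print(f"group_size={gs}: {bpw:.2f} bits/weight, ~{params * bpw / 8 / 2**30:.1f} GiB")
```

This lands in the same ballpark as the table; the exact table values additionally depend on which tensors stay unquantized and on the real metadata layout.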

Kernel correctness validation (KL divergence, 8B parameter model, WikiText-2)

To validate matmul kernel correctness, we compared two runs of the same 1-bit quantized model: one using the quantized matmul kernels (weights stay packed in 1-bit), and the other with the 1-bit weights dequantized to FP16 first and run through standard FP16 matmul. This is not a comparison between an FP16 model and its quantized version — both sides use identical weight values, so any divergence would indicate a kernel bug. Both the prompt processing (qmm) and token generation (qmv) paths were tested.

Prompt processing (qmm path) — 20 WikiText-2 chunks:

| Metric | Value |
| --- | --- |
| Forward KL(P\|\|Q) mean | 0.000024 |
| Reverse KL(Q\|\|P) mean | 0.000017 |
| Mean top-1 agreement | 99.85% |
| Min top-1 agreement | 99.29% |

Token generation (qmv path) — 113 autoregressive steps (single-token qmv) across 5 prompts:

| Metric | Value |
| --- | --- |
| Forward KL(P\|\|Q) mean | 0.000067 |
| Reverse KL(Q\|\|P) mean | -0.000038 |
| Mean top-1 agreement | 100.0% |
| Min top-1 agreement | 100% |

Both forward and reverse KL are near-zero, confirming the quantized kernels produce results consistent with the dequantized FP16 reference in both qmm and qmv code paths.
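
The comparison itself can be sketched as follows — a generic NumPy implementation of forward KL and top-1 agreement over logits, not the actual evaluation harness used above:

```python
import numpy as np

def forward_kl(p_logits, q_logits):
    # KL(P || Q) per row, computed from raw logits via log-softmax
    def log_softmax(x):
        x = x - x.max(axis=-1, keepdims=True)
        return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))
    log_p, log_q = log_softmax(p_logits), log_softmax(q_logits)
    return (np.exp(log_p) * (log_p - log_q)).sum(axis=-1)

def top1_agreement(p_logits, q_logits):
    return (p_logits.argmax(axis=-1) == q_logits.argmax(axis=-1)).mean()

# identical logits -> KL of exactly zero and 100% agreement
x = np.random.randn(20, 32000)
assert np.allclose(forward_kl(x, x), 0.0)
assert top1_agreement(x, x) == 1.0
```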

Changes

  • mlx/backend/cpu/quantized.cpp - 1-bit quantization logic and qmm dispatch
  • mlx/backend/metal/kernels/quantized.h - Metal 1-bit load_vector, qdot, qdot_safe, qouter, dequantize
  • mlx/backend/metal/kernels/quantized_nax.h - Same for NAX kernels
  • mlx/ops.cpp - Validation to accept bits=1
  • python/src/ops.cpp - Updated docstring table
  • python/tests/test_quantized.py - Added 1-bit to existing tests + dedicated 1-bit test
  • python/tests/cuda_skip.py - Skip 1-bit test on CUDA
  • benchmarks/python/comparative/bench_mlx.py - Added 1-bit entries to quant_matmul dict; auto-quantizes weight from --size args
  • benchmarks/python/comparative/compare.py - Added quant_matmul benchmark entries comparing 1/2/4/8-bit across qmv and qmm paths

Notes

  1. The Metal qmv_quad_impl kernel has a minor edge case with 1-bit when the inner dimension is < 128. In practice this should never come up — virtually all models have dimensions well above 128.

  2. If all weights in a group are exactly 0, the affine 1-bit quantization computes scale = eps (floored) and bias = 0, which dequantizes all values to near-zero (correct behavior).

  3. Kernel-level quantized_matmul benchmarks (Apple M4 Pro 48 GB, GPU, NAX path, group_size=128, 1000 calls, weight shape in parentheses):

    qmv path (M=1, single-token generation, memory-bandwidth bound):

    | Layer | FP16 | 1-bit | 2-bit | 4-bit | 1-bit speedup vs FP16 |
    | --- | --- | --- | --- | --- | --- |
    | attn_proj (4096×4096) | 167 µs | 27 µs | 39 µs | 38 µs | 6.2× |
    | ffn_gate (11008×4096) | 391 µs | 46 µs | 72 µs | 110 µs | 8.6× |
    | ffn_down (4096×11008) | 409 µs | 59 µs | 75 µs | 98 µs | 6.9× |

    qmm path (M=32, prompt processing, more compute-bound):

    | Layer | FP16 | 1-bit | 2-bit | 4-bit | 1-bit speedup vs FP16 |
    | --- | --- | --- | --- | --- | --- |
    | attn_proj (4096×4096) | 308 µs | 178 µs | 179 µs | 182 µs | 1.7× |
    | ffn_gate (11008×4096) | 841 µs | 430 µs | 438 µs | 433 µs | 2.0× |
    | ffn_down (4096×11008) | 649 µs | 441 µs | 435 µs | 440 µs | 1.5× |

    1-bit entries have been added to benchmarks/python/comparative/bench_mlx.py and compare.py. To reproduce (from repo root):

    ```shell
    # run all quant_matmul benchmarks (1/2/4/8-bit, qmv M=1, qmm M=32 & M=512)
    python benchmarks/python/comparative/compare.py --filter quant_matmul

    # or run individual benchmarks
    python benchmarks/python/comparative/bench_mlx.py quant_matmul_t_128_1 --size 1x4096 --size 4096x4096
    python benchmarks/python/comparative/bench_mlx.py quant_matmul_t_128_4 --size 1x4096 --size 4096x4096

    # unit tests
    python -m pytest python/tests/test_quantized.py::TestQuantized::test_1bit_quantize_dequantize -v
    python -m pytest python/tests/test_quantized.py::TestQuantized::test_quantize_dequantize -v
    python -m pytest python/tests/test_quantized.py::TestQuantized::test_qmm -v
    ```

Questions for reviewers

  1. NAX vs non-NAX testing: All benchmarks and the full test suite were run on macOS 26.2 (M4 Pro 48 GB), where NAX is active. The non-NAX path was partially validated by rebuilding with -DCMAKE_CXX_FLAGS=-DMLX_METAL_NO_NAX — unit tests pass, but full benchmarking was only done on the NAX path. We only have access to an M4. Is the -DMLX_METAL_NO_NAX build flag sufficient to validate the non-NAX path, or would you recommend testing on actual older hardware (M1/M2/M3)?

  2. Test coverage: The full test suite passes (672 tests, 0 failures), including dedicated 1-bit tests for both symmetric and asymmetric weight round-trip accuracy, zero handling, and quantized matmul correctness across both qmm and qmv paths. Is there any additional testing you'd like to see before merging?

Future work

  • Dedicated symmetric 1-bit mode
  • CUDA support

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Copilot AI review requested due to automatic review settings February 24, 2026 04:57

Copilot AI left a comment

Pull request overview

This PR adds 1-bit affine quantization support to MLX, extending the existing quantization framework from {2, 3, 4, 5, 6, 8} bits to include 1-bit. The implementation provides efficient packing and inference for 1-bit quantized weights on Apple Silicon (Metal) and CPU backends.

Changes:

  • Added 1-bit support to affine quantization with the formula scale = w_max - w_min, bias = w_min, where bit 0 → w_min and bit 1 → w_max
  • Implemented full Metal kernel support for 1-bit in both NAX and non-NAX paths across all quantized operations (quantize, dequantize, qmm, qmv)
  • Extended CPU backend with 1-bit quantization, dequantization, and quantized matmul dispatch
  • Added comprehensive test coverage for 1-bit symmetric/asymmetric weights, zero handling, and quantized matmul correctness
  • Updated Python bindings documentation and validation to accept bits=1
  • Added 1-bit benchmark entries for performance comparison across different group sizes
  • Excluded CUDA backend from 1-bit support (added to cuda_skip.py)

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| python/tests/test_quantized.py | Added 1-bit to existing parameterized tests and a new dedicated test_1bit_quantize_dequantize with symmetric/asymmetric weights, zero handling, and qmm/qmv correctness tests |
| python/tests/cuda_skip.py | Added the 1-bit test to the CUDA skip list since the CUDA backend doesn't support 1-bit |
| python/src/ops.cpp | Updated the quantization mode documentation table to include 1-bit in the supported bits |
| mlx/ops.cpp | Modified validation to accept bits >= 1 and added the 1-bit quantization formula (scale = w_max - w_min, bias = w_min) |
| mlx/backend/metal/kernels/quantized_nax.metal | Added a 1-bit kernel instantiation macro for the NAX path |
| mlx/backend/metal/kernels/quantized_nax.h | Implemented 1-bit versions of load_vector, load_vector_safe, qdot, qdot_safe, qouter, and dequantize for the NAX optimized kernels |
| mlx/backend/metal/kernels/quantized.metal | Added a 1-bit kernel instantiation macro and quantize/dequantize logic for the non-NAX path |
| mlx/backend/metal/kernels/quantized.h | Implemented 1-bit versions of all quantization primitives for the non-NAX kernels |
| mlx/backend/cpu/quantized.cpp | Added a 1-bit case to the qmm dispatch and quantization logic matching the Metal implementation |
| benchmarks/python/comparative/compare.py | Added a compare_mlx_quant function and 1-bit benchmark entries for the qmv and qmm paths |
| benchmarks/python/comparative/bench_mlx.py | Added 1-bit entries to the quant_matmul dictionary for all group sizes and transpose modes, plus auto-quantization logic |


@angeloskath (Member)

Hi @khosravipasha, that is pretty cool. I am not sure we want to support 1-bit quants natively in MLX; even 2-bit quants are not really used out there.

The options I see are:

  1. You could leave it as an open PR for people to chime in if they want this or think it would be useful in any way
  2. You could make it an extension that we'd be happy to link to at MLX Community Projects #654, and anybody could simply pip install it and use it.

