Add 1-bit affine quantization support (Metal) #3161
Open
khosravipasha wants to merge 2 commits into ml-explore:main
Conversation
Pull request overview
This PR adds 1-bit affine quantization support to MLX, extending the existing quantization framework from {2, 3, 4, 5, 6, 8} bits to include 1-bit. The implementation provides efficient packing and inference for 1-bit quantized weights on Apple Silicon (Metal) and CPU backends.
Changes:
- Added 1-bit support to affine quantization with formula `scale = w_max - w_min`, `bias = w_min`, where bit 0 → `w_min` and bit 1 → `w_max`
- Implemented full Metal kernel support for 1-bit in both NAX and non-NAX paths across all quantized operations (quantize, dequantize, qmm, qmv)
- Extended CPU backend with 1-bit quantization, dequantization, and quantized matmul dispatch
- Added comprehensive test coverage for 1-bit symmetric/asymmetric weights, zero handling, and quantized matmul correctness
- Updated Python bindings documentation and validation to accept bits=1
- Added 1-bit benchmark entries for performance comparison across different group sizes
- Excluded CUDA backend from 1-bit support (added to cuda_skip.py)
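For illustration, the byte-level effect of 1-bit packing can be sketched with NumPy. This is a simplification for intuition only; the actual Metal and CPU kernels in this PR pack codes into 32-bit words with their own kernel-specific layout.

```python
import numpy as np

def pack_1bit(q):
    """Pack an array of 0/1 codes into bytes, 8 codes per byte.

    Illustrative only: the real MLX layout packs codes into uint32
    words, and the ordering inside a word is kernel-specific.
    """
    return np.packbits(q.astype(np.uint8))

def unpack_1bit(packed, n):
    """Recover the first n codes from a packed byte array."""
    return np.unpackbits(packed)[:n]
```

A group of 32 one-bit codes therefore occupies 4 bytes before adding the per-group scale/bias metadata.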
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `python/tests/test_quantized.py` | Added 1-bit to existing parameterized tests and a new dedicated `test_1bit_quantize_dequantize` with symmetric/asymmetric weights, zero handling, and qmm/qmv correctness tests |
| `python/tests/cuda_skip.py` | Added 1-bit test to CUDA skip list since the CUDA backend doesn't support 1-bit |
| `python/src/ops.cpp` | Updated quantization mode documentation table to include 1-bit in supported bits |
| `mlx/ops.cpp` | Modified validation to accept `bits >= 1` and added the 1-bit quantization formula (`scale = w_max - w_min`, `bias = w_min`) |
| `mlx/backend/metal/kernels/quantized_nax.metal` | Added 1-bit kernel instantiation macro for the NAX path |
| `mlx/backend/metal/kernels/quantized_nax.h` | Implemented 1-bit versions of `load_vector`, `load_vector_safe`, `qdot`, `qdot_safe`, `qouter`, and `dequantize` for NAX optimized kernels |
| `mlx/backend/metal/kernels/quantized.metal` | Added 1-bit kernel instantiation macro and quantize/dequantize logic for the non-NAX path |
| `mlx/backend/metal/kernels/quantized.h` | Implemented 1-bit versions of all quantization primitives for non-NAX kernels |
| `mlx/backend/cpu/quantized.cpp` | Added 1-bit case to qmm dispatch and quantization logic matching the Metal implementation |
| `benchmarks/python/comparative/compare.py` | Added `compare_mlx_quant` function and 1-bit benchmark entries for qmv and qmm paths |
| `benchmarks/python/comparative/bench_mlx.py` | Added 1-bit entries to `quant_matmul` dictionary for all group sizes and transpose modes, plus auto-quantization logic |
Member
Hi @khosravipasha that is pretty cool. I am not sure we want to support 1-bit quants natively in MLX, even 2 bits are not really used out there. The options I see are:
Add 1-bit affine quantization support (Metal)
Proposed changes
This PR adds 1-bit support to MLX's affine quantization mode, extending the supported bit-widths from `{2, 3, 4, 5, 6, 8}` to `{1, 2, 3, 4, 5, 6, 8}`.

MLX already supports affine quantization at 2, 3, 4, 5, 6, and 8 bits via `w_hat = scale * w_q + bias`. This PR extends that same framework to 1-bit, adding full kernel support for 1-bit affine dequantization and quantized matmul across CPU and Metal backends. This assumes the model has already been quantized externally (e.g. during training); the contribution here is efficient packing and inference on Apple Silicon. It supports packing for both affine and symmetric 1-bit weights:

- Affine 1-bit: weights have arbitrary per-group min/max.
- Symmetric 1-bit: weights are `{-d, +d}` per group, automatically handled by the affine formula above since `w_min = -d`, `w_max = +d`.

A dedicated symmetric 1-bit mode (scale only, no bias) could save memory and skip the bias addition in the matmul kernels, but for now both cases run through the same affine path.
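The formula described above can be sketched as a NumPy reference. This is illustrative only; the group handling, rounding rule, and eps floor below are assumptions reconstructed from the PR description, not the actual MLX kernel code.

```python
import numpy as np

def quantize_1bit(w, group_size=32):
    """Sketch of 1-bit affine quantization per group.

    scale = w_max - w_min, bias = w_min; bit 0 decodes to w_min
    and bit 1 to w_max. Assumes w.size is divisible by group_size.
    """
    g = w.reshape(-1, group_size)
    w_min = g.min(axis=1, keepdims=True)
    w_max = g.max(axis=1, keepdims=True)
    # Floor the scale at a small eps so all-zero groups stay well defined.
    scale = np.maximum(w_max - w_min, 1e-7)
    bias = w_min
    # Values closer to w_max than w_min get code 1, others code 0.
    q = ((g - bias) / scale >= 0.5).astype(np.uint8)
    return q, scale, bias

def dequantize_1bit(q, scale, bias):
    # Same affine reconstruction as the 2-8 bit modes: w_hat = scale * w_q + bias
    return scale * q.astype(scale.dtype) + bias
```

For a symmetric group `{-d, +d}` this round-trips exactly, since `scale = 2d` and `bias = -d` reproduce both values.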
What's included

- CPU backend: 1-bit quantization, dequantization, and `qmm` dispatch
- Metal kernels: 1-bit support in both the non-NAX (`quantized.h`) and NAX (`quantized_nax.h`, `quantized_nax.metal`) paths, plus quantize/dequantize kernels
- Python API: `mx.quantize(w, bits=1)`, `mx.dequantize(...)`, and `mx.quantized_matmul(...)` documentation
- Tests: 1-bit added to `test_quantize_dequantize`, `test_qmm`, and a dedicated `test_1bit_quantize_dequantize` covering round-trip accuracy, zero handling, and quantized matmul correctness. Full test suite passes (672 tests, 0 failures).
- CUDA: not included; `dispatch_bits` does not have a `case 1:` path. The new 1-bit test is added to `cuda_skip.py`.

Expected model-level performance (hypothetical, 8B parameter model, Apple M4 Pro 48 GB)
Based on the kernel-level benchmarks below, a hypothetical 8B parameter model at 1-bit would see roughly (varying by group size due to scale/bias metadata overhead):
We verified one scenario (group size 128, all weights quantized) and observed throughput in the ballpark of the estimates above. The primary purpose of this table is to give a sense of the runtime speed that 1-bit quantization enables. These are back-of-the-envelope numbers — actual end-to-end performance will vary depending on which layers are quantized, group size, attention overhead, and other non-quantized computation.
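The group-size dependence mentioned above comes from the per-group scale/bias metadata. A quick back-of-the-envelope helper makes this concrete (assuming fp16 scales and biases, one pair per group; the exact metadata dtype depends on how the model was quantized):

```python
def effective_bits(bits=1, group_size=128, meta_bits=16):
    # Each group of `group_size` weights stores one scale and one bias,
    # each `meta_bits` wide, on top of `bits` per weight.
    return bits + 2 * meta_bits / group_size

# 1-bit weights: group_size 32 -> 2.0 effective bits, 64 -> 1.5, 128 -> 1.25.
# An 8B-parameter model at group size 128 would need roughly
# 8e9 * 1.25 / 8 bytes, about 1.25 GB, for the quantized weights alone.
```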
Kernel correctness validation (KL divergence, 8B parameter model, WikiText-2)
To validate matmul kernel correctness, we compared two runs of the same 1-bit quantized model: one using the quantized matmul kernels (weights stay packed in 1-bit), and the other with the 1-bit weights dequantized to FP16 first and run through standard FP16 matmul. This is not a comparison between an FP16 model and its quantized version — both sides use identical weight values, so any divergence would indicate a kernel bug. Both the prompt processing (qmm) and token generation (qmv) paths were tested.
Prompt processing (qmm path) — 20 WikiText-2 chunks:
Token generation (qmv path) — 113 autoregressive steps (single-token qmv) across 5 prompts:
Both forward and reverse KL are near-zero, confirming the quantized kernels produce results consistent with the dequantized FP16 reference in both qmm and qmv code paths.
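The comparison described above can be sketched as follows. The helper below is hypothetical scaffolding: `logits_a` and `logits_b` stand in for the output logits of the two runs (packed-kernel vs. dequantize-then-FP16-matmul), and the softmax/KL math is standard rather than MLX-specific.

```python
import numpy as np

def softmax(x):
    # Shift by the row max for numerical stability.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def forward_reverse_kl(logits_a, logits_b, eps=1e-12):
    """Return mean forward KL(P_a || P_b) and reverse KL(P_b || P_a)
    over all positions. Near-zero values indicate the two compute
    paths produce consistent output distributions."""
    p, q = softmax(logits_a), softmax(logits_b)
    fwd = np.mean(np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1))
    rev = np.mean(np.sum(q * (np.log(q + eps) - np.log(p + eps)), axis=-1))
    return fwd, rev
```

Since both runs use identical weight values, any KL meaningfully above zero would point to a kernel bug rather than quantization error.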
Changes
- `mlx/backend/cpu/quantized.cpp` - 1-bit quantization logic and `qmm` dispatch
- `mlx/backend/metal/kernels/quantized.h` - Metal 1-bit `load_vector`, `qdot`, `qdot_safe`, `qouter`, `dequantize`
- `mlx/backend/metal/kernels/quantized_nax.h` - Same for NAX kernels
- `mlx/ops.cpp` - Validation to accept `bits=1`
- `python/src/ops.cpp` - Updated docstring table
- `python/tests/test_quantized.py` - Added 1-bit to existing tests + dedicated 1-bit test
- `python/tests/cuda_skip.py` - Skip 1-bit test on CUDA
- `benchmarks/python/comparative/bench_mlx.py` - Added 1-bit entries to `quant_matmul` dict; auto-quantizes weight from `--size` args
- `benchmarks/python/comparative/compare.py` - Added `quant_matmul` benchmark entries comparing 1/2/4/8-bit across qmv and qmm paths

Notes
- The Metal `qmv_quad_impl` kernel has a minor edge case with 1-bit when the inner dimension is < 128. In practice this should never come up; virtually all models have dimensions well above 128.
- If all weights in a group are exactly 0, the affine 1-bit quantization computes `scale = eps` (floored) and `bias = 0`, which dequantizes all values to near-zero (correct behavior).
- Kernel-level `quantized_matmul` benchmarks (Apple M4 Pro 48 GB, GPU, NAX path, `group_size=128`, 1000 calls, weight shape in parentheses):
  - qmv path (M=1, single-token generation, memory-bandwidth bound):
  - qmm path (M=32, prompt processing, more compute-bound):
- 1-bit entries have been added to `benchmarks/python/comparative/bench_mlx.py` and `compare.py`. To reproduce (from repo root):

Questions for reviewers
- NAX vs non-NAX testing: All benchmarks and the full test suite were run on macOS 26.2 (M4 Pro 48 GB), where NAX is active. The non-NAX path was partially validated by rebuilding with `-DCMAKE_CXX_FLAGS=-DMLX_METAL_NO_NAX`; unit tests pass, but full benchmarking was only done on the NAX path. We only have access to an M4. Is the `-DMLX_METAL_NO_NAX` build flag sufficient to validate the non-NAX path, or would you recommend testing on actual older hardware (M1/M2/M3)?
- Test coverage: The full test suite passes (672 tests, 0 failures), including dedicated 1-bit tests for both symmetric and asymmetric weight round-trip accuracy, zero handling, and quantized matmul correctness across both qmm and qmv paths. Is there any additional testing you'd like to see before merging?
Future work
Checklist
- I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes