Fix U32/U16/I8/I16 weight loading for quantized models #5
Open
duncanita wants to merge 2 commits into
_tensor_to_mlx declared U32 and other integer dtypes in DTYPE_UNPACK but never branched on them in the if/elsif chain. Packed 4-bit quantized weights (stored as uint32 in mlx-community safetensors) fell through to the F32 fallback and were decoded as garbage floats, causing `[dequantize] The matrix should be given as a uint32` on the first QuantizedEmbedding forward. Reproduces on mlx-community/Llama-3.2-1B-Instruct-4bit and presumably every 4-bit model.
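To illustrate why the fall-through produces garbage, here is a minimal plain-Ruby sketch (no MLX involved) that reads the same four bytes once as a `uint32` and once as a little-endian `float32`. The nibble layout used to build the packed word is an assumption made for illustration only; it is not claimed to be MLX's actual 4-bit packing order.

```ruby
# Hypothetical packed word: eight 4-bit values stored in one uint32.
# The low-nibble-first layout is an assumption for this sketch.
nibbles = [0, 1, 2, 3, 4, 5, 6, 7]
word = nibbles.each_with_index.sum { |n, i| n << (4 * i) }

# The four raw bytes as they would sit in a little-endian tensor buffer.
bytes = [word].pack("L<")

# Correct interpretation: one unsigned 32-bit integer.
as_u32 = bytes.unpack1("L<")

# Wrong interpretation (the F32 fallback): the same bits read as a float.
as_f32 = bytes.unpack1("e")

# Only the integer view lets the packed 4-bit values be recovered.
unpacked = (0...8).map { |i| (as_u32 >> (4 * i)) & 0xF }

puts format("as uint32:  0x%08X", as_u32)
puts "as float32: #{as_f32}"
puts "nibbles:    #{unpacked.inspect}"
```

The float view is a valid number, which is why nothing fails at load time; the error only surfaces later, when a quantized layer checks the dtype of its weight matrix.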
The if/elsif chain in _tensor_to_mlx duplicated the DTYPE_UNPACK constant declared at the top of the file, which is how the U16/U32/I8/I16 branches went missing from the chain in the first place. Table-driven lookup keeps the mapping in one place. F16 and BF16 stay as explicit branches because they take a different code path (uint16 stage + .view cast). Unknown-dtype F32 fallback is preserved to match prior behavior. Uses __send__ instead of send because MLX::Core defines a `send` method (takes 2..4 args) that would shadow Object#send.
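The shape of the table-driven lookup, and the reason for `__send__`, can be sketched in plain Ruby. Everything named here (`FakeCore`, `DTYPE_TABLE`, `decode`) is hypothetical and stands in for the real code; the actual structure of `DTYPE_UNPACK` and the body of `_tensor_to_mlx` are not shown in this PR text, so this is only an approximation of the idea.

```ruby
# Hypothetical stand-in for a module that, like MLX::Core as described above,
# defines its own `send` taking 2..4 arguments.
module FakeCore
  def self.send(array, dst, group = nil, stream = nil)
    raise NotImplementedError, "communication primitive, not method dispatch"
  end

  # Placeholder dtype accessors; real ones would return dtype objects.
  def self.uint32;  :uint32_dtype;  end
  def self.int8;    :int8_dtype;    end
  def self.float32; :float32_dtype; end
end

# Single source of truth (assumed shape):
# dtype name => [String#unpack directive, dtype accessor name].
DTYPE_TABLE = {
  "U32" => ["L<*", :uint32],
  "I8"  => ["c*",  :int8],
  "F32" => ["e*",  :float32]
}.freeze

def decode(dtype, bytes, core: FakeCore)
  # Unknown dtypes fall back to F32, mirroring the preserved prior behavior.
  directive, accessor = DTYPE_TABLE.fetch(dtype) { DTYPE_TABLE["F32"] }
  # __send__ reaches the accessor even though `send` is shadowed on core.
  [bytes.unpack(directive), core.__send__(accessor)]
end

values, dtype = decode("U32", [1, 2, 3].pack("L<*"))
puts "#{values.inspect} #{dtype}"
```

With this layout, supporting another integer dtype means adding one table row, so a declared dtype cannot silently lack a branch. Calling `FakeCore.send(:uint32)` in this sketch raises `ArgumentError`, because the shadowing method expects at least two arguments.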
Summary
`_tensor_to_mlx` in `lib/mlx_lm/weight_utils.rb` declared `U32` (and other integer dtypes) in the `DTYPE_UNPACK` constant, but the if/elsif chain below it had no `U32` branch. Packed 4-bit quantized weights (stored as `uint32` in mlx-community safetensors) silently fell through to the F32 fallback and were decoded as garbage floats, causing `[dequantize] The matrix should be given as a uint32` on the first `QuantizedEmbedding` forward.
Expected to reproduce on every 4-bit mlx-community model; tested against `mlx-community/Llama-3.2-1B-Instruct-4bit`:
Before: `RuntimeError: [dequantize] The matrix should be given as a uint32.`
After: `Paris. The capital of France is Paris. ...`
Commits
Duplication between `DTYPE_UNPACK` and the if/elsif chain is what let the U32 branch go missing in the first place. Table-driven lookup keeps dtype knowledge in one place. F16/BF16 stay explicit (they use a `uint16 + .view` bit-cast). `mx.send(sym)` would have collided with `MLX::Core#send` (MLX's distributed-communication primitive takes 2..4 args), so the refactor uses `mx.__send__(sym)`. F32 unknown-dtype fallback is preserved.
Side effect: `F64` safetensors now correctly unpack as doubles (previously the chain had no F64 branch, so F64 bytes were misread through the F32 fallback).
Test plan
`mlx_lm generate --model <mlx-community/Llama-3.2-1B-Instruct-4bit local path> --prompt 'The capital of France is' --max-tokens 12 --temp 0.0` produces coherent output.
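As background on why F16/BF16 are kept as explicit branches: Ruby's `String#unpack` has no 16-bit float directive, so the bytes have to be read as `uint16` first and then reinterpreted. The plain-Ruby sketch below shows the idea for BF16 only, relying on the fact that a bfloat16 is the upper 16 bits of an IEEE-754 float32. The helper name `bf16_bytes_to_floats` is hypothetical, and this is not the PR's code, which performs the reinterpretation on the MLX side via `.view`.

```ruby
# Decode little-endian bfloat16 bytes into Ruby Floats.
# Stage 1: read raw 16-bit unsigned integers.
# Stage 2: bit-cast by placing each one in the top half of a float32.
def bf16_bytes_to_floats(bytes)
  bytes.unpack("S<*").map do |bits|
    [bits << 16].pack("L<").unpack1("e")
  end
end

# 0x3F80 is 1.0, 0xC000 is -2.0, 0x4049 is 3.140625 in bfloat16.
raw = [0x3F80, 0xC000, 0x4049].pack("S<*")
p bf16_bytes_to_floats(raw) # => [1.0, -2.0, 3.140625]
```

A single unpack directive per dtype cannot express this two-stage decode, which is consistent with leaving these two dtypes outside the lookup table.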