Fix U32/U16/I8/I16 weight loading for quantized models#5

Open
duncanita wants to merge 2 commits into skryl:main from duncanita:fix-u32-weight-loading
@duncanita
Summary

_tensor_to_mlx in lib/mlx_lm/weight_utils.rb declared U32 (and other integer dtypes) in the DTYPE_UNPACK constant, but the if/elsif chain below it had no U32 branch. Packed 4-bit quantized weights (stored as uint32 in mlx-community safetensors) silently fell through to the F32 fallback and were decoded as garbage floats. The first QuantizedEmbedding forward pass then failed with `[dequantize] The matrix should be given as a uint32`.
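To make the failure mode concrete, here is a minimal sketch in plain Ruby of what happens when uint32 bytes are decoded through a float32 path. It uses only `Array#pack` and `String#unpack`; the sample words and variable names are made up for illustration, and it does not reproduce the actual code in weight_utils.rb.

```ruby
# Illustrative values only: two 32-bit words standing in for packed 4-bit weights.
packed_words = [0x76543210, 0xFEDCBA98]

# Safetensors stores tensor data little-endian, so pack as little-endian uint32.
raw_bytes = packed_words.pack("L<*")

# Correct decode: read the bytes back as little-endian uint32.
as_uint32 = raw_bytes.unpack("L<*")

# Fallthrough decode: read the very same bytes as little-endian float32.
as_float32 = raw_bytes.unpack("e*")

puts "uint32:  #{as_uint32.inspect}"
puts "float32: #{as_float32.inspect}"
```

The element count is the same in both cases, which is why nothing fails at load time: the values are simply meaningless floats, and the array carries the wrong dtype by the time dequantize sees it.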

Reproduces on mlx-community/Llama-3.2-1B-Instruct-4bit, and presumably on every other 4-bit mlx-community model:

bundle exec exe/mlx_lm generate \
  --model <cached_llama_3.2_1b_4bit> \
  --prompt 'The capital of France is'

Before: RuntimeError: [dequantize] The matrix should be given as a uint32.
After: Paris. The capital of France is Paris. ...

Commits

  1. Fix U32/U16/I8/I16 weight loading for quantized models: adds the missing branches to unblock quantized inference.
  2. Use DTYPE_UNPACK for table-driven dtype dispatch: the duplication between DTYPE_UNPACK and the if/elsif chain is what let the U32 branch go missing in the first place, and a table-driven lookup keeps dtype knowledge in one place. F16/BF16 stay explicit because they go through a uint16 stage plus a .view bit-cast. mx.send(sym) would have called MLX::Core's own send (MLX's distributed-communication primitive, which takes 2..4 args) instead of Object#send, so the refactor uses mx.__send__(sym). The F32 fallback for unknown dtypes is preserved.
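The table-driven dispatch and the send collision described in commit 2 can be sketched roughly as follows. Everything here is hypothetical: `FakeCore` is a stand-in for MLX::Core that only mimics a module-level `send` with 2..4 arguments, and `SKETCH_DTYPE_UNPACK` assumes a table shape of dtype name to unpack directive plus dtype accessor, which may differ from the real DTYPE_UNPACK constant.

```ruby
# Hypothetical stand-in for MLX::Core. It only mimics the property that
# matters here: a module-level `send` that takes 2..4 arguments.
module FakeCore
  def self.send(array, dst, group = nil, stream = nil)
    raise NotImplementedError, "distributed send is out of scope for this sketch"
  end

  def self.uint32; :uint32; end
  def self.int16; :int16; end
  def self.float32; :float32; end
  def self.float64; :float64; end
end

# Assumed table shape: safetensors dtype => [unpack directive, dtype accessor].
SKETCH_DTYPE_UNPACK = {
  "U32" => ["L<*", :uint32],
  "I16" => ["s<*", :int16],
  "F32" => ["e*", :float32],
  "F64" => ["E*", :float64]
}.freeze

def sketch_decode(dtype, raw_bytes, mx = FakeCore)
  # Unknown dtypes fall back to the F32 entry, mirroring the preserved fallback.
  directive, accessor = SKETCH_DTYPE_UNPACK.fetch(dtype) { SKETCH_DTYPE_UNPACK.fetch("F32") }
  # __send__ reaches the accessor even though the module defines its own `send`.
  [raw_bytes.unpack(directive), mx.__send__(accessor)]
end

values, dtype = sketch_decode("U32", [1, 2, 3].pack("L<*"))
puts "#{dtype}: #{values.inspect}"

begin
  FakeCore.send(:uint32) # resolves to FakeCore.send, not Object#send
rescue ArgumentError => e
  puts "send collides: #{e.message}"
end
```

With a single table, adding a dtype means adding one entry, so there is no second list of branches that can drift out of sync.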

Side effect: F64 safetensors now correctly unpack as doubles (previously the chain had no F64 branch, so F64 bytes were misread through the F32 fallback).
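As a rough illustration of the F64 case (plain Ruby, made-up sample values, not the project's code): reading 8-byte doubles through a 4-byte float directive changes both the values and the element count.

```ruby
# Illustrative values only.
sample_doubles = [1.0, -2.5]

# Little-endian float64 bytes, as an F64 tensor would be stored.
f64_bytes = sample_doubles.pack("E*")

# Decoding with the float64 directive recovers the original values.
decoded_as_f64 = f64_bytes.unpack("E*")

# Decoding with the float32 directive splits every double into two floats.
decoded_as_f32 = f64_bytes.unpack("e*")

puts "as F64: #{decoded_as_f64.inspect} (#{decoded_as_f64.length} elements)"
puts "as F32: #{decoded_as_f32.inspect} (#{decoded_as_f32.length} elements)"
```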

Test plan

  • mlx_lm generate --model <mlx-community/Llama-3.2-1B-Instruct-4bit local path> --prompt 'The capital of France is' --max-tokens 12 --temp 0.0 produces coherent output.
  • Refactor preserves behavior — same output before and after commit 2.
  • CI (run by reviewer).

_tensor_to_mlx declared U32 and other integer dtypes in DTYPE_UNPACK
but never branched on them in the if/elsif chain. Packed 4-bit
quantized weights (stored as uint32 in mlx-community safetensors) fell
through to the F32 fallback and were decoded as garbage floats,
causing `[dequantize] The matrix should be given as a uint32` on the
first QuantizedEmbedding forward.

Reproduces on mlx-community/Llama-3.2-1B-Instruct-4bit and presumably
every 4-bit model.
The if/elsif chain in _tensor_to_mlx duplicated the DTYPE_UNPACK
constant declared at the top of the file — which is how the U16/U32/
I8/I16 branches went missing from the chain in the first place. Table-
driven lookup keeps the mapping in one place.

F16 and BF16 stay as explicit branches because they take a different
code path (uint16 stage + .view cast). Unknown-dtype F32 fallback is
preserved to match prior behavior.

Uses __send__ instead of send because MLX::Core defines a `send`
method (takes 2..4 args) that would shadow Object#send.