Commit f8cfc73
Add MLX backend support for Gemma 4 31B (pytorch#19524)
Adds an Apple Silicon (MLX) backend for the Gemma 4 31B-IT model. The same
quantized checkpoint works for both CUDA and MLX; backend-specific
packing happens at load time.
Key changes:
- MLX packer converts Int4Tensor → IntxUnpackedToInt8Tensor for MLX's
quantized linear fusion
- Source transforms replace PyTorch ops with mlx.rope,
mlx.kv_cache_update, and mlx.custom_sdpa for optimized Metal kernels
- Proportional partial RoPE (full-attention layers) passes 1D
frequencies to mlx.rope with dims=rotary_dim; the C++ runtime is fixed
to pass base=nullopt when freqs is provided
- Single-method export with dynamic seq_len and host-side sampling
- C++ runner supports both backends via #ifdef, using shared
logits_to_token for MLX sampling
- Last-logits-only optimization: lm_head always runs on the last position
only, removing the full-logits codepath entirely
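The unpacking step in the first bullet can be illustrated with a minimal pure-Python sketch. This is not the packer from this commit: the function names, the low-nibble-first byte layout, and the fixed shift into the signed range -8..7 are all assumptions chosen for illustration.

```python
def unpack_int4(packed):
    """Split each byte into two 4-bit values, one output element per value.

    Assumption: the low nibble holds the even-indexed element. The real
    Int4Tensor layout may differ.
    """
    values = []
    for byte in packed:
        values.append(byte & 0x0F)         # low nibble
        values.append((byte >> 4) & 0x0F)  # high nibble
    return values


def to_signed_int8_range(values):
    """Shift unsigned 4-bit codes (0..15) into the signed range -8..7.

    Assumption: a fixed offset of 8; the actual conversion may instead
    adjust the zero point.
    """
    return [v - 8 for v in values]


packed = bytes([0xF0, 0x78])  # holds the codes 0, 15, 8, 7
codes = unpack_int4(packed)
signed = to_signed_int8_range(codes)
print(codes, signed)
```

The point of the sketch is only that the unpacked form stores one value per element, which is the shape a backend-specific packer can then rearrange at load time.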
Nothing changes in the CUDA backend code itself. The CUDA-side changes are
in the shared model/runner code:
- model.py: forward() now always does last-logits-only and temperature
is required (no None path). Affects both CUDA and MLX.
- sampler.py: Removed temperature=None passthrough.
- main.cpp: Unified temp_val clamping before the #ifdef. CUDA path
behavior unchanged.
- inference.py: Default temperature changed from 0.0 to 0.8 to match C++
runner default.
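The shared-code changes above can be sketched in plain Python: project only the last position through the LM head, then sample with a temperature that is always a float. This is an illustration, not the code in model.py or sampler.py. The function names and the clamp floor `MIN_TEMPERATURE` are hypothetical; only the 0.8 default comes from the commit message.

```python
import math
import random

MIN_TEMPERATURE = 1e-6  # hypothetical clamp floor; the commit does not state the value


def last_position_logits(hidden_states, lm_head):
    """Apply the LM head to the final position only.

    hidden_states: list of seq_len vectors, each of length hidden_dim
    lm_head: list of vocab_size rows, each of length hidden_dim
    """
    last = hidden_states[-1]
    return [sum(w * h for w, h in zip(row, last)) for row in lm_head]


def sample_token(logits, temperature=0.8, rng=None):
    """Sample a token id from temperature-scaled logits (no None path)."""
    rng = rng or random.Random()
    temp = max(float(temperature), MIN_TEMPERATURE)
    peak = max(logits)
    # Subtract the max before scaling so exp() cannot overflow.
    weights = [math.exp((x - peak) / temp) for x in logits]
    threshold = rng.random() * sum(weights)
    cumulative = 0.0
    for token_id, weight in enumerate(weights):
        cumulative += weight
        if threshold < cumulative:
            return token_id
    return len(weights) - 1  # guard against floating-point round-off


hidden = [[1.0, 0.0], [0.0, 1.0]]            # seq_len=2, hidden_dim=2
head = [[1.0, 1.0], [0.0, 3.0], [2.0, 2.0]]  # vocab_size=3
logits = last_position_logits(hidden, head)
token = sample_token(logits, temperature=0.0, rng=random.Random(0))
print(logits, token)
```

In this sketch, clamping the temperature to a tiny positive value makes sampling collapse onto the highest logit, so a request for 0.0 stays well-defined without a separate greedy branch.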
Run on my M1 MacBook Pro with 32 GB of RAM:
```
(executorch_dev) mnachin@mnachin-mbp executorch % cmake-out/examples/models/gemma4_31b/gemma4_31b_runner --model_path ~/repos/models/gemma-4-31B-it-HQQ-INT4/model.pte --tokenizer_path ~/repos/models/gemma-4-31B-it-HQQ-INT4/tokenizer.json --prompt "Write a short joke about saving RAM." --max_new_tokens 128
I tokenizers:regex.cpp:27] Registering override fallback regex
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1779218557.174278 43844526 re2.cc:237] Error parsing '((\<pad\>|ool\|\>1\x00\x00\
�\<t|respo|\<tool_call\|\>|\<bos\>|\<\|tool_response\>|\<\|think\|\>|\x0...': invalid UTF-8
I tokenizers:re2_regex.cpp:27] Re2 failed to compile regex: ((\<pad\>|ool\|\>1\x00\x00\
�\<t|respo|\<tool_call\|\>|\<bos\>|\<\|tool_response\>|\<\|think\|\>|\x00\x00\\\<|\<tool_response\|\>|\<mask\>|\<\|\"\|\>|all\|\>j\x00\x00\\|\<channel\|\>|\<\|turn\>|\<turn\|\>|\<\|image\>|\<\|$
I tokenizers:regex_lookahead.cpp:27] Creating PCRE2 regex
I tokenizers:pcre2_regex.cpp:48] PCRE2 UTF-8 validation failed at offset 27: UTF-8 error: byte 2 top bits not 0x80. Retrying without UTF flags.
Loading model...
Prompt tokens: 24
Why did the programmer get kicked out of the library?
He kept trying to free the memory.<turn|>
PyTorchObserver {"prefill_token_per_sec":7.56859,"decode_token_per_sec":2.09161,"prompt_tokens":24,"generated_tokens":20,"model_load_start_ms":1779218556804,"model_load_end_ms":1779218560048,"inference_start_ms":1779218560052,"inference_end_ms":1779218572785,"prompt_eval_end_ms":1779218563223,"first_token_ms":1779218563223,"aggregate_sampling_time_ms":0,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
```
---------
Co-authored-by: Claude <noreply@anthropic.com>
1 parent 3d86cc7 commit f8cfc73
22 files changed
Lines changed: 1295 additions & 159 deletions
File tree
- .github/workflows
- backends/mlx
- runtime
- test
- examples/models/gemma4_31b
- quant
- tests
- tests