Skip to content

Commit f8cfc73

Browse files
mergennachinclaude
andauthored
Add MLX backend support for Gemma 4 31B (pytorch#19524)
Adds Apple Silicon (MLX) backend for the Gemma 4 31B-IT model. The same quantized checkpoint works for both CUDA and MLX — backend-specific packing happens at load time. Key changes: - MLX packer converts Int4Tensor → IntxUnpackedToInt8Tensor for MLX's quantized linear fusion - Source transforms replace PyTorch ops with mlx.rope, mlx.kv_cache_update, mlx.custom_sdpa for optimized Metal kernels - Proportional partial RoPE (full-attention layers) passes 1D frequencies to mlx.rope with dims=rotary_dim, fixing the C++ runtime to pass base=nullopt when freqs is provided - Single-method export with dynamic seq_len and host-side sampling - C++ runner supports both backends via #ifdef, using shared logits_to_token for MLX sampling - Last-logits-only optimization: lm_head always runs on last position only, removing the full-logits codepath entirely Nothing in the CUDA backend code itself. The CUDA-side changes are in the shared model/runner code: - model.py: forward() now always does last-logits-only and temperature is required (no None path). Affects both CUDA and MLX. - sampler.py: Removed temperature=None passthrough. - main.cpp: Unified temp_val clamping before the #ifdef. CUDA path behavior unchanged. - inference.py: Default temperature changed from 0.0 to 0.8 to match C++ runner default. On my 32GB RAM M1 macbook pro ``` (executorch_dev) mnachin@mnachin-mbp executorch % cmake-out/examples/models/gemma4_31b/gemma4_31b_runner --model_path ~/repos/models/gemma-4-31B-it-HQQ-INT4/model.pte --tokenizer_path ~/repos/models/gemma-4-31B-it-HQQ-INT4/tokenizer.json --prompt "Write a short joke about saving RAM." --max_new_tokens 128 I tokenizers:regex.cpp:27] Registering override fallback regex WARNING: All log messages before absl::InitializeLog() is called are written to STDERR E0000 00:00:1779218557.174278 43844526 re2.cc:237] Error parsing '((\<pad\>|ool\|\>1\x00\x00\ �\<t|respo|\<tool_call\|\>|\<bos\>|\<\|tool_response\>|\<\|think\|\>|\x0...': invalid UTF-8 I tokenizers:re2_regex.cpp:27] Re2 failed to compile regex: ((\<pad\>|ool\|\>1\x00\x00\ �\<t|respo|\<tool_call\|\>|\<bos\>|\<\|tool_response\>|\<\|think\|\>|\x00\x00\\\<|\<tool_response\|\>|\<mask\>|\<\|\"\|\>|all\|\>j\x00\x00\\|\<channel\|\>|\<\|turn\>|\<turn\|\>|\<\|image\>|\<\|$ I tokenizers:regex_lookahead.cpp:27] Creating PCRE2 regex I tokenizers:pcre2_regex.cpp:48] PCRE2 UTF-8 validation failed at offset 27: UTF-8 error: byte 2 top bits not 0x80. Retrying without UTF flags. Loading model... Prompt tokens: 24 Why did the programmer get kicked out of the library? He kept trying to free the memory.<turn|> PyTorchObserver {"prefill_token_per_sec":7.56859,"decode_token_per_sec":2.09161,"prompt_tokens":24,"generated_tokens":20,"model_load_start_ms":1779218556804,"model_load_end_ms":1779218560048,"inference_start_ms":1779218560052,"inference_end_ms":1779218572785,"prompt_eval_end_ms":1779218563223,"first_token_ms":1779218563223,"aggregate_sampling_time_ms":0,"SCALING_FACTOR_UNITS_PER_SECOND":1000} ``` --------- Co-authored-by: Claude <noreply@anthropic.com>
1 parent 3d86cc7 commit f8cfc73

22 files changed

Lines changed: 1295 additions & 159 deletions

.github/workflows/cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ jobs:
150150
151151
# Run Gemma 4 31B tests (quant unit tests + pipeline integration tests)
152152
pip install gguf
153-
python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts="
153+
python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ --ignore=examples/models/gemma4_31b/tests/test_mlx_pipeline.py -v -o "addopts="
154154
155155
export-model-cuda-artifact:
156156
name: export-model-cuda-artifact

.github/workflows/mlx.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ jobs:
6060
backends/mlx/test/test_passes.py \
6161
backends/mlx/test/test_pattern_utils.py \
6262
backends/mlx/test/test_partitioner.py \
63+
examples/models/gemma4_31b/tests/test_mlx_pipeline.py \
6364
-v
6465
echo "::endgroup::"
6566

Makefile

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
#
9292
# ==============================================================================
9393

94-
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help
94+
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean help
9595

9696
help:
9797
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -127,6 +127,7 @@ help:
127127
@echo " gemma3-cuda - Build Gemma3 runner with CUDA backend"
128128
@echo " gemma3-cpu - Build Gemma3 runner with CPU backend"
129129
@echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend"
130+
@echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend"
130131
@echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend"
131132
@echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend"
132133
@echo " clean - Clean build artifacts"
@@ -435,6 +436,15 @@ gemma4_31b-cuda:
435436
@echo "✓ Build complete!"
436437
@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"
437438

439+
gemma4_31b-mlx:
440+
@echo "==> Building and installing ExecuTorch with MLX..."
441+
cmake --workflow --preset mlx-release
442+
@echo "==> Building Gemma 4 31B runner with MLX..."
443+
cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-mlx
444+
@echo ""
445+
@echo "✓ Build complete!"
446+
@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"
447+
438448
qwen3_5_moe-metal:
439449
@echo "==> Building and installing ExecuTorch with Metal..."
440450
cmake --workflow --preset llm-release-metal

backends/mlx/custom_ops.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,16 @@ def rope(
228228
# final angles: [1, 1, T, half]
229229
angles = (pos_range * inv_freq) * float(scale)
230230
else:
231-
# assume freqs is already per-position, just reshape to [1,1,T,half]
232-
angles = freqs.to(torch.float32).view(1, 1, T, half)
231+
if freqs.ndim == 1:
232+
# 1D raw frequencies: compute angles = positions * (1/freqs)
233+
inv_freq = (1.0 / freqs.to(torch.float32)).view(1, 1, 1, half)
234+
pos_range = torch.arange(
235+
pos, pos + T, device=x.device, dtype=torch.float32
236+
).view(1, 1, T, 1)
237+
angles = (pos_range * inv_freq) * float(scale)
238+
else:
239+
# 2D per-position angles: reshape to [1,1,T,half]
240+
angles = freqs.to(torch.float32).view(1, 1, T, half)
233241

234242
cos = angles.cos().to(x.dtype) # [1,1,T,half]
235243
sin = angles.sin().to(x.dtype) # [1,1,T,half]

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,11 @@ inline void exec_rope(const RopeNode& n, ExecutionState& st, StreamOrDevice s) {
242242
freqs_arr = st.const_tensor_ref(*n.freqs);
243243
}
244244

245+
// MLX requires exactly one of base or freqs — when freqs is provided,
246+
// base must be nullopt.
247+
std::optional<float> base =
248+
freqs_arr ? std::nullopt : std::optional<float>(n.base);
249+
245250
// MLX has two overloads: rope(..., int offset, ...) and rope(..., const
246251
// array& offset, ...) Call the appropriate one based on is_vid
247252
if (n.offset.is_vid) {
@@ -250,14 +255,14 @@ inline void exec_rope(const RopeNode& n, ExecutionState& st, StreamOrDevice s) {
250255
st.set_tensor(
251256
n.out,
252257
fast::rope(
253-
x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s));
258+
x, n.dims, n.traditional, base, n.scale, offset, freqs_arr, s));
254259
} else {
255260
// Tensor offset from Tid
256261
const array& offset = st.const_tensor_ref(n.offset.tid);
257262
st.set_tensor(
258263
n.out,
259264
fast::rope(
260-
x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s));
265+
x, n.dims, n.traditional, base, n.scale, offset, freqs_arr, s));
261266
}
262267
}
263268

backends/mlx/test/test_ops.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,6 +1803,82 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
18031803
return (q, k, pos_tensor)
18041804

18051805

1806+
class RopeCustomFreqsModel(nn.Module):
1807+
"""Model that applies RoPE with custom 1D frequencies (partial rotary)."""
1808+
1809+
def __init__(self, dims: int = 32, head_dim: int = 64):
1810+
super().__init__()
1811+
self.dims = dims
1812+
self.head_dim = head_dim
1813+
# Simulate proportional RoPE: compute freqs for rotary dims only
1814+
inv_freq = 1.0 / (
1815+
500000.0 ** (torch.arange(0, dims, 2, dtype=torch.float32) / head_dim)
1816+
)
1817+
self.register_buffer("freqs", 1.0 / inv_freq, persistent=False)
1818+
1819+
def forward(
1820+
self,
1821+
q: torch.Tensor,
1822+
k: torch.Tensor,
1823+
pos_tensor: torch.Tensor,
1824+
) -> Tuple[torch.Tensor, torch.Tensor]:
1825+
pos = pos_tensor.item()
1826+
q_rot = torch.ops.mlx.rope(q, self.dims, pos, False, 0.0, 1.0, self.freqs)
1827+
k_rot = torch.ops.mlx.rope(k, self.dims, pos, False, 0.0, 1.0, self.freqs)
1828+
return q_rot, k_rot
1829+
1830+
1831+
@register_test
1832+
class RopeCustomFreqsTest(OpTestCase):
1833+
"""Test RoPE with custom 1D frequencies (partial rotary, like Gemma 4)."""
1834+
1835+
name = "rope_custom_freqs"
1836+
rtol = 1e-4
1837+
atol = 1e-4
1838+
1839+
def __init__(
1840+
self,
1841+
batch_size: int = 1,
1842+
num_heads: int = 8,
1843+
seq_len: int = 4,
1844+
head_dim: int = 64,
1845+
dims: int = 32,
1846+
pos: int = 0,
1847+
):
1848+
self.batch_size = batch_size
1849+
self.num_heads = num_heads
1850+
self.seq_len = seq_len
1851+
self.head_dim = head_dim
1852+
self.dims = dims
1853+
self.pos = pos
1854+
self.name = "rope_custom_freqs"
1855+
1856+
@classmethod
1857+
def get_test_configs(cls) -> List["RopeCustomFreqsTest"]:
1858+
configs = [
1859+
cls(),
1860+
cls(pos=10),
1861+
cls(head_dim=128, dims=64),
1862+
]
1863+
for cfg in configs:
1864+
parts = ["rope_custom_freqs"]
1865+
if cfg.pos > 0:
1866+
parts.append(f"pos{cfg.pos}")
1867+
if cfg.head_dim != 64:
1868+
parts.append(f"hd{cfg.head_dim}")
1869+
cfg.name = "_".join(parts)
1870+
return configs
1871+
1872+
def create_model(self) -> nn.Module:
1873+
return RopeCustomFreqsModel(dims=self.dims, head_dim=self.head_dim)
1874+
1875+
def create_inputs(self) -> Tuple[torch.Tensor, ...]:
1876+
q = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim)
1877+
k = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim)
1878+
pos_tensor = torch.tensor(self.pos, dtype=torch.int64)
1879+
return (q, k, pos_tensor)
1880+
1881+
18061882
from executorch.backends.mlx.llm.cache import KVCache
18071883

18081884

examples/models/gemma4_31b/CMakeLists.txt

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,17 @@ list(
4242
extension_flat_tensor
4343
)
4444

45-
# CUDA backend (the only supported backend for this example for now)
45+
# Backend: CUDA or MLX (exactly one required)
4646
if(EXECUTORCH_BUILD_CUDA)
4747
find_package(CUDAToolkit REQUIRED)
4848
list(APPEND link_libraries aoti_cuda_backend)
4949
executorch_target_link_options_shared_lib(aoti_cuda_backend)
5050
add_compile_definitions(EXECUTORCH_BUILD_CUDA)
51+
elseif(TARGET mlxdelegate)
52+
list(APPEND link_libraries mlxdelegate mlx)
53+
executorch_target_link_options_shared_lib(mlxdelegate)
5154
else()
52-
message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON")
55+
message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON or EXECUTORCH_BUILD_MLX=ON")
5356
endif()
5457

5558
# Tokenizer (HuggingFace tokenizer.json)
@@ -63,5 +66,11 @@ target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries})
6366

6467
if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
6568
target_link_options_gc_sections(gemma4_31b_runner)
66-
target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s")
69+
if(NOT APPLE AND NOT MSVC)
70+
target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s")
71+
endif()
72+
endif()
73+
74+
if(TARGET mlxdelegate)
75+
executorch_target_copy_mlx_metallib(gemma4_31b_runner)
6776
endif()

examples/models/gemma4_31b/CMakePresets.json

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@
2323
"string": "${hostSystemName}",
2424
"list": ["Linux", "Windows"]
2525
}
26+
},
27+
{
28+
"name": "gemma4-31b-mlx",
29+
"displayName": "Gemma 4 31B runner (MLX)",
30+
"inherits": ["gemma4-31b-base"],
31+
"cacheVariables": {},
32+
"condition": {
33+
"type": "equals",
34+
"lhs": "${hostSystemName}",
35+
"rhs": "Darwin"
36+
}
2637
}
2738
],
2839
"buildPresets": [
@@ -31,6 +42,12 @@
3142
"displayName": "Build Gemma 4 31B runner (CUDA)",
3243
"configurePreset": "gemma4-31b-cuda",
3344
"targets": ["gemma4_31b_runner"]
45+
},
46+
{
47+
"name": "gemma4-31b-mlx",
48+
"displayName": "Build Gemma 4 31B runner (MLX)",
49+
"configurePreset": "gemma4-31b-mlx",
50+
"targets": ["gemma4_31b_runner"]
3451
}
3552
],
3653
"workflowPresets": [
@@ -47,6 +64,20 @@
4764
"name": "gemma4-31b-cuda"
4865
}
4966
]
67+
},
68+
{
69+
"name": "gemma4-31b-mlx",
70+
"displayName": "Configure and build Gemma 4 31B runner (MLX)",
71+
"steps": [
72+
{
73+
"type": "configure",
74+
"name": "gemma4-31b-mlx"
75+
},
76+
{
77+
"type": "build",
78+
"name": "gemma4-31b-mlx"
79+
}
80+
]
5081
}
5182
]
5283
}

examples/models/gemma4_31b/README.md

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Gemma 4 31B-IT
22

33
Text-only export of Google's Gemma 4 31B-IT to ExecuTorch with INT4/INT8
4-
weight quantization. Currently supports the CUDA backend.
4+
weight quantization. Supports CUDA and MLX (Apple Silicon) backends.
55

66
For architecture and design notes see [model.md](model.md).
77

@@ -67,6 +67,8 @@ recipe. Writes `model.safetensors`, `config.json`, and `tokenizer.json` into
6767

6868
## Export to ExecuTorch
6969

70+
### CUDA
71+
7072
```bash
7173
python examples/models/gemma4_31b/export.py \
7274
--prequantized ./gemma4_31b_int4 \
@@ -75,7 +77,20 @@ python examples/models/gemma4_31b/export.py \
7577
--backend cuda
7678
```
7779

78-
Writes `model.pte` and `model.ptd` into `--output-dir`.
80+
### MLX (Apple Silicon)
81+
82+
```bash
83+
python examples/models/gemma4_31b/export.py \
84+
--prequantized ./gemma4_31b_int4 \
85+
--output-dir ./gemma4_31b_exports_mlx \
86+
--max-seq-len 4096 \
87+
--backend mlx
88+
```
89+
90+
The same quantized checkpoint works for both backends. MLX exports a single
91+
method with dynamic sequence length and host-side sampling.
92+
93+
Writes `model.pte` (and optionally `model.ptd`) into `--output-dir`.
7994

8095
## Eager inference
8196

@@ -105,7 +120,8 @@ model produces sensible text.
105120
## Build the runner
106121

107122
```bash
108-
make gemma4_31b-cuda
123+
make gemma4_31b-cuda # Linux — CUDA backend
124+
make gemma4_31b-mlx # macOS — MLX backend (Apple Silicon)
109125
```
110126

111127
The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`.

0 commit comments

Comments
 (0)