fix(gemma3n): gate bf16 fused decode path off M5 Neural Accelerator #61
Merged
The fused Gemma3n bf16 decode path added in #60 (stacked AltUp predict/correct plus the `gemma3n_mlp_forward` bridge call) cuts Rust-to-C++ graph-construction overhead and improves decode on Apple Silicon without a Neural Accelerator (Mac Studio M1 Ultra: about +3.6%). A same-machine A/B on the MacBook Pro M5 Max shows the opposite there: the fused path regressed gemma-3n-E4B-it bf16 same-process decode by roughly 6.3% (about 39.0 down to 36.6 tok/s). The stacked-AltUp scheduling and the single fused MLP bridge call interact poorly with M5-class (Neural Accelerator) hardware, where the pre-fused per-op path schedules better.

Gate both fused paths behind a new `use_fused_decode_path()` helper, which is true only when the hardware is not a Neural Accelerator part (`!(has_neural_accelerator && macos_supports_na)`), mirroring the same hardware predicate already used elsewhere in the core. `MLP::forward` now takes the fused `gemma3n_mlp_forward` bridge call only off NA hardware and otherwise runs the per-op bf16 path. `DecoderLayer::forward` dispatches to `forward_stacked` (the fused, stacked-AltUp layer) on non-NA hardware and to `forward_split` (the pre-fused path where `AltUp::predict`/`correct` return per-plane Vecs) on M5-class hardware. Both code paths already existed, so this is a runtime dispatch by hardware class, not a revert: non-NA Apple Silicon keeps the faster fused path, and M5-class hardware avoids the regression.

On the M5 Max, gemma-3n-E4B-it bf16 decode returns to about 39 tok/s (representative 39.05) with coherent real-model output, holding near 80% of the mlx-lm reference. The M5 Max benchmark page is updated to the restored decode value, its vs-M1-Ultra ratio (39.05 / 35.65 = 1.10x) is recomputed against the M1 Ultra decode, and the note now records that M5-class hardware uses the split decode path while other Apple Silicon uses the fused path.

Verified with the gemma3n helper unit tests, a clean clippy build, and a coherent generation on gemma3n-e4b-bf16 (this M1 Ultra host exercises the fused branch; the split branch is the unchanged pre-fused code).
inureyes added a commit that referenced this pull request on May 21, 2026
…#62)

On M5-class (Neural Accelerator) hardware the non-quantized Gemma3n decode GEMVs (the MLP gate/up projections and the tied LM head) stream their weights faster when MLX is handed an already-materialized transposed weight rather than transposing on the fly inside every decode step. This materializes those transposes once at load time on M5-class hardware only, leaving every other code path untouched.

`MLP::gate_proj`/`up_proj` become a new `MlpInputProjection` enum. The `Standard` variant wraps `UnifiedLinear` exactly as before and is used for quantized layers and for all non-M5 hardware. The `Pretransposed` variant is selected only when the hardware is an M5-class part and the weight is not quantized: it transposes the projection weight, makes it contiguous, evaluates it at load time, and then `forward` is a plain `matmul` (plus optional bias) against the prepared weight.

The tied LM head gets the same treatment through `pretranspose_large_m5_embedding`, which materializes the wide embedding transpose under the same M5-and-non-quantized guard; `Gemma3nLanguageModel` caches it in `embed_tokens_weight_t` and a new `lm_head` helper uses it when present, otherwise falling back to `embed_tokens.as_linear`. Quantized models keep their specialized 4-bit paths unchanged, and the split-path per-layer update now consumes the corrected planes by iterator instead of copying the first plane.

This also adds an opt-in Metal GPU capture hook for profiling: when the process is launched with `MTL_CAPTURE_ENABLED=1` and `MLXCEL_CAPTURE_DECODE=<path>` is set, the generator captures exactly one warm decode token (at the second step, after the decode kernels are JIT-cached) to a `.gputrace` bundle and exits, so per-kernel timings are directly comparable with mlx-lm's `mx.metal.start_capture` script. The hook is inert unless both environment variables are set.

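The two-variable gate on the capture hook could look roughly like the sketch below. Only the environment variable names and the "second decode step" rule come from the commit message. The function names, the zero-based step index, and treating an empty path as unset are assumptions made for illustration, not the actual mlxcel code.

```rust
use std::path::{Path, PathBuf};

/// Pure gating logic: a capture target exists only when the Metal capture
/// flag is exactly "1" AND a destination path was supplied.
/// (Hypothetical helper; an empty path is treated as unset by assumption.)
pub fn capture_target(
    mtl_capture_enabled: Option<&str>,
    capture_decode: Option<&str>,
) -> Option<PathBuf> {
    match (mtl_capture_enabled, capture_decode) {
        (Some("1"), Some(path)) if !path.is_empty() => Some(PathBuf::from(path)),
        _ => None,
    }
}

/// Reads the two variables named in the commit message from the process
/// environment and applies the gate above.
pub fn capture_target_from_env() -> Option<PathBuf> {
    let enabled = std::env::var("MTL_CAPTURE_ENABLED").ok();
    let path = std::env::var("MLXCEL_CAPTURE_DECODE").ok();
    capture_target(enabled.as_deref(), path.as_deref())
}

/// Capture exactly one warm token: the second decode step, which is index 1
/// under the zero-based counting assumed here. Step 0 is skipped so the
/// decode kernels are already JIT-cached when the trace starts.
pub fn should_capture_step(step_index: usize, target: Option<&Path>) -> bool {
    target.is_some() && step_index == 1
}
```

Keeping the decision in a pure function that takes the two values as arguments makes the "inert unless both are set" property easy to unit-test without mutating the process environment.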
This complements the M5 split-path dispatch from #61: M5-class hardware now both avoids the fused path and gets the pretransposed decode weights. Verified with the gemma3n helper unit tests, a clean clippy build, and a coherent generation on gemma3n-e4b-bf16 (this M1 Ultra host takes the standard, non-pretransposed branch, so the dispatch and `lm_head` fallback are exercised end to end).
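As a rough illustration of the `MlpInputProjection` selection rule described in the commit message, the sketch below models the two variants with a plain row-major `Vec<f32>` matrix instead of MLX arrays. The `Matrix` type, the constructor signature, and the boolean flags are hypothetical stand-ins. Quantized kernels, bias, and `UnifiedLinear` itself are not modeled.

```rust
/// Row-major matrix used as a stand-in for an MLX array (assumption).
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix {
    pub rows: usize,
    pub cols: usize,
    pub data: Vec<f32>,
}

impl Matrix {
    /// Materialize the transpose into a new contiguous buffer.
    pub fn transposed(&self) -> Matrix {
        let mut data = vec![0.0; self.data.len()];
        for r in 0..self.rows {
            for c in 0..self.cols {
                data[c * self.rows + r] = self.data[r * self.cols + c];
            }
        }
        Matrix { rows: self.cols, cols: self.rows, data }
    }
}

pub enum MlpInputProjection {
    /// Weight kept as [out, in]; the transposed access happens on every call.
    Standard(Matrix),
    /// Weight stored as [in, out], prepared once when the layer is built.
    Pretransposed(Matrix),
}

impl MlpInputProjection {
    /// Pretranspose only on M5-class hardware with a non-quantized weight;
    /// every other combination keeps the standard layout.
    pub fn new(weight: Matrix, is_m5_class: bool, is_quantized: bool) -> Self {
        if is_m5_class && !is_quantized {
            MlpInputProjection::Pretransposed(weight.transposed())
        } else {
            MlpInputProjection::Standard(weight)
        }
    }

    /// Computes y = x * W^T for one decode token `x` of length `in`.
    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
        match self {
            MlpInputProjection::Standard(w) => (0..w.rows)
                .map(|o| (0..w.cols).map(|i| x[i] * w.data[o * w.cols + i]).sum::<f32>())
                .collect(),
            MlpInputProjection::Pretransposed(wt) => (0..wt.cols)
                .map(|o| (0..wt.rows).map(|i| x[i] * wt.data[i * wt.cols + o]).sum::<f32>())
                .collect(),
        }
    }
}
```

Both variants produce the same output for the same input. Only the storage layout differs, which is why the choice can be made once at load time behind a hardware guard.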
Summary
Gates the fused Gemma3n bf16 decode path introduced in #60 so it runs only on Apple Silicon without a Neural Accelerator. M5-class hardware, where #60's fused path regressed decode, falls back to the per-op split path.
Why
#60's fused stacked-AltUp predict/correct path and the single `gemma3n_mlp_forward` bridge call cut graph-construction overhead and improve decode on non-NA Apple Silicon (Mac Studio M1 Ultra: about +3.6%). A same-machine A/B on the MacBook Pro M5 Max showed the opposite there: gemma-3n-E4B-it bf16 same-process decode regressed by roughly 6.3% (about 39.0 down to 36.6 tok/s). The stacked scheduling and the fused MLP bridge call interact poorly with M5-class (Neural Accelerator) GPUs, which schedule the pre-fused per-op path better.
What changed
- New `use_fused_decode_path()` helper, true only when the hardware is not a Neural Accelerator part (`!(has_neural_accelerator && macos_supports_na)`), reusing the same hardware predicate already applied elsewhere in the core.
- `MLP::forward` takes the fused `gemma3n_mlp_forward` bridge call only off NA hardware; otherwise it runs the per-op bf16 path.
- `DecoderLayer::forward` dispatches to `forward_stacked` (the fused, stacked-AltUp layer) on non-NA hardware and to `forward_split` (the pre-fused path where `AltUp::predict`/`correct` return per-plane Vecs) on M5-class hardware.

Both code paths already existed, so this is a runtime dispatch by hardware class, not a revert: non-NA Apple Silicon keeps the faster fused path and M5-class hardware avoids the regression.
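The dispatch rule above can be sketched as a small pure function. The predicate expression is taken from the description. The `HardwareInfo` struct, the `DecodePath` enum, and passing the hardware facts as an argument are assumptions for illustration, since the real helper takes no arguments and presumably queries the hardware itself.

```rust
/// Hypothetical snapshot of the two hardware facts the predicate consults.
#[derive(Clone, Copy, Debug)]
pub struct HardwareInfo {
    pub has_neural_accelerator: bool,
    pub macos_supports_na: bool,
}

/// Which decoder-layer implementation to run (illustrative names that mirror
/// `forward_stacked` and `forward_split`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DecodePath {
    Stacked,
    Split,
}

/// True unless the chip has a Neural Accelerator AND the OS supports it.
pub fn use_fused_decode_path(hw: HardwareInfo) -> bool {
    !(hw.has_neural_accelerator && hw.macos_supports_na)
}

/// Runtime dispatch by hardware class: fused/stacked off NA hardware,
/// pre-fused split path on M5-class hardware.
pub fn select_decode_path(hw: HardwareInfo) -> DecodePath {
    if use_fused_decode_path(hw) {
        DecodePath::Stacked
    } else {
        DecodePath::Split
    }
}
```

One consequence of the predicate as written: a part that has a Neural Accelerator but runs a macOS version without NA support still takes the fused path, because both conditions must hold before the split path is selected.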
Performance
On the M5 Max, gemma-3n-E4B-it bf16 decode returns to about 39 tok/s (representative 39.05), holding near 80% of the mlx-lm reference. The M5 Max benchmark page is updated to the restored decode value, its vs-M1-Ultra ratio recomputed against the M1 Ultra decode (39.05 / 35.65 = 1.10x), and the note records that M5-class hardware uses the split decode path while other Apple Silicon uses the fused path.
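The quoted ratio and regression figures can be rechecked with a few lines of arithmetic. The helper names below are hypothetical, and the inputs are the rounded tok/s values quoted in this PR, so the computed drop (a little over 6%) only approximates the "roughly 6.3%" figure, which was presumably derived from unrounded measurements.

```rust
/// Throughput ratio of one machine against another.
pub fn speed_ratio(tok_s: f64, baseline_tok_s: f64) -> f64 {
    tok_s / baseline_tok_s
}

/// Relative decode regression in percent, measured against the "before" value.
pub fn percent_drop(before_tok_s: f64, after_tok_s: f64) -> f64 {
    (before_tok_s - after_tok_s) / before_tok_s * 100.0
}

/// Round to a fixed number of decimal places for reporting.
pub fn round_to(value: f64, decimals: i32) -> f64 {
    let scale = 10f64.powi(decimals);
    (value * scale).round() / scale
}
```

With the values from this PR, `round_to(speed_ratio(39.05, 35.65), 2)` reproduces the 1.10x ratio recorded on the benchmark page.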
Testing
`cargo clippy --features metal,accelerate --lib --bin mlxcel --tests -- -D warnings` is clean.