Skip to content

fix(gemma3n): gate bf16 fused decode path off M5 Neural Accelerator#61

Merged
inureyes merged 1 commit into
mainfrom
port/gemma3n-fused-decode-na-gate
May 21, 2026
Merged

fix(gemma3n): gate bf16 fused decode path off M5 Neural Accelerator#61
inureyes merged 1 commit into
mainfrom
port/gemma3n-fused-decode-na-gate

Conversation

@inureyes
Copy link
Copy Markdown
Member

Summary

Gates the fused Gemma3n bf16 decode path introduced in #60 so it runs only on Apple Silicon without a Neural Accelerator. M5-class hardware falls back to the per-op split path, where #60's fused path regressed decode.

Why

#60's fused stacked-AltUp predict/correct path and the single gemma3n_mlp_forward bridge call cut graph-construction overhead and improve decode on non-NA Apple Silicon (Mac Studio M1 Ultra: about +3.6%). A same-machine A/B on the MacBook Pro M5 Max showed the opposite there: gemma-3n-E4B-it bf16 same-process decode regressed by roughly 6.3% (about 39.0 down to 36.6 tok/s). The stacked scheduling and the fused MLP bridge call interact poorly with M5-class (Neural Accelerator) GPUs, which schedule the pre-fused per-op path better.

What changed

  • New use_fused_decode_path() helper, true only when the hardware is not a Neural Accelerator part (!(has_neural_accelerator && macos_supports_na)), reusing the same hardware predicate already applied elsewhere in the core.
  • MLP::forward takes the fused gemma3n_mlp_forward bridge call only off NA hardware; otherwise it runs the per-op bf16 path.
  • DecoderLayer::forward dispatches to forward_stacked (the fused, stacked-AltUp layer) on non-NA hardware and to forward_split (the pre-fused path where AltUp::predict/correct return per-plane Vecs) on M5-class hardware.

Both code paths already existed, so this is a runtime dispatch by hardware class, not a revert: non-NA Apple Silicon keeps the faster fused path and M5-class hardware avoids the regression.

Performance

On the M5 Max, gemma-3n-E4B-it bf16 decode returns to about 39 tok/s (representative 39.05), holding near 80% of the mlx-lm reference. The M5 Max benchmark page is updated to the restored decode value, its vs-M1-Ultra ratio recomputed against the M1 Ultra decode (39.05 / 35.65 = 1.10x), and the note records that M5-class hardware uses the split decode path while other Apple Silicon uses the fused path.

Testing

  • gemma3n helper unit tests pass.
  • cargo clippy --features metal,accelerate --lib --bin mlxcel --tests -- -D warnings is clean.
  • A short real-model generation on gemma3n-e4b-bf16 is coherent (this M1 Ultra host exercises the fused branch; the split branch is the unchanged pre-fused code).

The fused Gemma3n bf16 decode path added in #60 (stacked AltUp predict/correct plus the `gemma3n_mlp_forward` bridge call) cuts Rust to C++ graph-construction overhead and improves decode on Apple Silicon without a Neural Accelerator (Mac Studio M1 Ultra: about +3.6%). A same-machine A/B on the MacBook Pro M5 Max shows the opposite there: the fused path regressed gemma-3n-E4B-it bf16 same-process decode by roughly 6.3% (about 39.0 down to 36.6 tok/s). The stacked-AltUp scheduling and the single fused MLP bridge call interact poorly with M5-class (Neural Accelerator) hardware, where the pre-fused per-op path schedules better.

Gate both fused paths behind a new `use_fused_decode_path()` helper, which is true only when the hardware is not a Neural Accelerator part (`!(has_neural_accelerator && macos_supports_na)`), mirroring the same hardware predicate already used elsewhere in the core. `MLP::forward` now takes the fused `gemma3n_mlp_forward` bridge call only off NA hardware and otherwise runs the per-op bf16 path. `DecoderLayer::forward` dispatches to `forward_stacked` (the fused, stacked-AltUp layer) on non-NA hardware and to `forward_split` (the pre-fused path where `AltUp::predict`/`correct` return per-plane Vecs) on M5-class hardware. Both code paths already existed, so this is a runtime dispatch by hardware class, not a revert: non-NA Apple Silicon keeps the faster fused path, M5-class hardware avoids the regression.

On the M5 Max, gemma-3n-E4B-it bf16 decode returns to about 39 tok/s (representative 39.05) with coherent real-model output, holding near 80% of the mlx-lm reference. The M5 Max benchmark page is updated to the restored decode value, its vs-M1-Ultra ratio (39.05 / 35.65 = 1.10x) is recomputed against the M1 Ultra decode, and the note now records that M5-class hardware uses the split decode path while other Apple Silicon uses the fused path. Verified with the gemma3n helper unit tests, a clean clippy build, and a coherent generation on gemma3n-e4b-bf16 (this M1 Ultra host exercises the fused branch; the split branch is the unchanged pre-fused code).
@inureyes inureyes added type:bug Bug fixes, error corrections, or issue resolutions priority:medium Medium priority area:models Model architectures, weights, loading, metadata labels May 21, 2026
@inureyes inureyes merged commit 206a2d5 into main May 21, 2026
4 checks passed
@inureyes inureyes deleted the port/gemma3n-fused-decode-na-gate branch May 21, 2026 09:16
inureyes added a commit that referenced this pull request May 21, 2026
…#62)

On M5-class (Neural Accelerator) hardware the non-quantized Gemma3n decode GEMVs — the MLP gate/up projections and the tied LM head — stream their weights faster when MLX is handed an already-materialized transposed weight rather than transposing on the fly inside every decode step. This materializes those transposes once at load time on M5-class hardware only, leaving every other code path untouched.

`MLP::gate_proj`/`up_proj` become a new `MlpInputProjection` enum. The `Standard` variant wraps `UnifiedLinear` exactly as before and is used for quantized layers and for all non-M5 hardware. The `Pretransposed` variant is selected only when the hardware is an M5-class part and the weight is not quantized: it transposes the projection weight, makes it contiguous, evaluates it at load time, and then `forward` is a plain `matmul` (plus optional bias) against the prepared weight. The tied LM head gets the same treatment through `pretranspose_large_m5_embedding`, which materializes the wide embedding transpose under the same M5-and-non-quantized guard; `Gemma3nLanguageModel` caches it in `embed_tokens_weight_t` and a new `lm_head` helper uses it when present, otherwise falling back to `embed_tokens.as_linear`. Quantized models keep their specialized 4-bit paths unchanged, and the split-path per-layer update now consumes the corrected planes by iterator instead of copying the first plane.

This also adds an opt-in Metal GPU capture hook for profiling: when the process is launched with `MTL_CAPTURE_ENABLED=1` and `MLXCEL_CAPTURE_DECODE=<path>` is set, the generator captures exactly one warm decode token (at the second step, after the decode kernels are JIT-cached) to a `.gputrace` bundle and exits, so per-kernel timings are directly comparable with mlx-lm's `mx.metal.start_capture` script. The hook is inert unless both environment variables are set.

This complements the M5 split-path dispatch from #61: M5-class hardware now both avoids the fused path and gets the pretransposed decode weights. Verified with the gemma3n helper unit tests, a clean clippy build, and a coherent generation on gemma3n-e4b-bf16 (this M1 Ultra host takes the standard, non-pretransposed branch, so the dispatch and `lm_head` fallback are exercised end to end).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:models Model architectures, weights, loading, metadata priority:medium Medium priority type:bug Bug fixes, error corrections, or issue resolutions

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant