perf(gemma3n): improve M5 decode bandwidth with pretransposed weights#62
Merged
Conversation
On M5-class (Neural Accelerator) hardware the non-quantized Gemma3n decode GEMVs — the MLP gate/up projections and the tied LM head — stream their weights faster when MLX is handed an already-materialized transposed weight rather than transposing on the fly inside every decode step. This materializes those transposes once at load time on M5-class hardware only, leaving every other code path untouched. `MLP::gate_proj`/`up_proj` become a new `MlpInputProjection` enum. The `Standard` variant wraps `UnifiedLinear` exactly as before and is used for quantized layers and for all non-M5 hardware. The `Pretransposed` variant is selected only when the hardware is an M5-class part and the weight is not quantized: it transposes the projection weight, makes it contiguous, evaluates it at load time, and then `forward` is a plain `matmul` (plus optional bias) against the prepared weight. The tied LM head gets the same treatment through `pretranspose_large_m5_embedding`, which materializes the wide embedding transpose under the same M5-and-non-quantized guard; `Gemma3nLanguageModel` caches it in `embed_tokens_weight_t` and a new `lm_head` helper uses it when present, otherwise falling back to `embed_tokens.as_linear`. Quantized models keep their specialized 4-bit paths unchanged, and the split-path per-layer update now consumes the corrected planes by iterator instead of copying the first plane. This also adds an opt-in Metal GPU capture hook for profiling: when the process is launched with `MTL_CAPTURE_ENABLED=1` and `MLXCEL_CAPTURE_DECODE=<path>` is set, the generator captures exactly one warm decode token (at the second step, after the decode kernels are JIT-cached) to a `.gputrace` bundle and exits, so per-kernel timings are directly comparable with mlx-lm's `mx.metal.start_capture` script. The hook is inert unless both environment variables are set. This complements the M5 split-path dispatch from #61: M5-class hardware now both avoids the fused path and gets the pretransposed decode weights. Verified with the gemma3n helper unit tests, a clean clippy build, and a coherent generation on gemma3n-e4b-bf16 (this M1 Ultra host takes the standard, non-pretransposed branch, so the dispatch and `lm_head` fallback are exercised end to end).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Improves Gemma3n bf16 decode bandwidth on M5-class (Neural Accelerator) hardware by materializing the transposed decode weights once at load time, and adds an opt-in Metal GPU capture hook for per-kernel profiling. Quantized models and all non-M5 hardware are unchanged.
What changed
MLP::gate_proj/up_projbecome anMlpInputProjectionenum.StandardwrapsUnifiedLinearexactly as before (used for quantized layers and all non-M5 hardware).Pretransposedis selected only on M5-class hardware for non-quantized weights: it transposes the weight, makes it contiguous, and evaluates it at load time, soforwardbecomes a plainmatmul(plus optional bias) instead of an on-the-fly transpose every decode step.pretranspose_large_m5_embeddingmaterializes the wide tied-embedding transpose under the same M5-and-non-quantized guard;Gemma3nLanguageModelcaches it inembed_tokens_weight_t, and a newlm_headhelper uses it when present and otherwise falls back toembed_tokens.as_linear.MTL_CAPTURE_ENABLED=1andMLXCEL_CAPTURE_DECODE=<path>is set, the generator captures exactly one warm decode token (after decode kernels are JIT-cached) to a.gputracebundle and exits, so per-kernel timings line up with mlx-lm'smx.metal.start_capturescript. It is inert unless both environment variables are set.This complements the M5 split-path dispatch from #61: M5-class hardware now both avoids the fused path and uses the pretransposed decode weights.
Testing
cargo clippy --features metal,accelerate --lib --bin mlxcel --tests -- -D warningsis clean.lm_headfallback, so the new dispatch is exercised end to end; the M5 pretransposed branches are compile-verified here and were validated on M5 upstream.Note
The internal commit also added a
.claude/skills/mlxcel-gpu-profiling/SKILL.mddevelopment skill describing the profiling workflow. It is omitted here because this repository does not track.claude/tooling.