Skip to content

perf(gemma3n): improve M5 decode bandwidth with pretransposed weights#62

Merged
inureyes merged 1 commit into
mainfrom
port/gemma3n-m5-decode-bandwidth
May 21, 2026
Merged

perf(gemma3n): improve M5 decode bandwidth with pretransposed weights#62
inureyes merged 1 commit into
mainfrom
port/gemma3n-m5-decode-bandwidth

Conversation

@inureyes
Copy link
Copy Markdown
Member

Summary

Improves Gemma3n bf16 decode bandwidth on M5-class (Neural Accelerator) hardware by materializing the transposed decode weights once at load time, and adds an opt-in Metal GPU capture hook for per-kernel profiling. Quantized models and all non-M5 hardware are unchanged.

What changed

  • Pretransposed MLP input projections. MLP::gate_proj/up_proj become an MlpInputProjection enum. Standard wraps UnifiedLinear exactly as before (used for quantized layers and all non-M5 hardware). Pretransposed is selected only on M5-class hardware for non-quantized weights: it transposes the weight, makes it contiguous, and evaluates it at load time, so forward becomes a plain matmul (plus optional bias) instead of an on-the-fly transpose every decode step.
  • Pretransposed tied LM head. pretranspose_large_m5_embedding materializes the wide tied-embedding transpose under the same M5-and-non-quantized guard; Gemma3nLanguageModel caches it in embed_tokens_weight_t, and a new lm_head helper uses it when present and otherwise falls back to embed_tokens.as_linear.
  • Split-path cleanup. The per-layer update in the split path consumes the corrected planes by iterator rather than copying the first plane.
  • Metal capture hook (opt-in). When the process runs with MTL_CAPTURE_ENABLED=1 and MLXCEL_CAPTURE_DECODE=<path> is set, the generator captures exactly one warm decode token (after decode kernels are JIT-cached) to a .gputrace bundle and exits, so per-kernel timings line up with mlx-lm's mx.metal.start_capture script. It is inert unless both environment variables are set.

This complements the M5 split-path dispatch from #61: M5-class hardware now both avoids the fused path and uses the pretransposed decode weights.

Testing

  • gemma3n helper unit tests pass.
  • cargo clippy --features metal,accelerate --lib --bin mlxcel --tests -- -D warnings is clean.
  • A coherent generation on gemma3n-e4b-bf16. This M1 Ultra host takes the standard (non-pretransposed) branch and the lm_head fallback, so the new dispatch is exercised end to end; the M5 pretransposed branches are compile-verified here and were validated on M5 upstream.

Note

The internal commit also added a .claude/skills/mlxcel-gpu-profiling/SKILL.md development skill describing the profiling workflow. It is omitted here because this repository does not track .claude/ tooling.

On M5-class (Neural Accelerator) hardware the non-quantized Gemma3n decode GEMVs — the MLP gate/up projections and the tied LM head — stream their weights faster when MLX is handed an already-materialized transposed weight rather than transposing on the fly inside every decode step. This materializes those transposes once at load time on M5-class hardware only, leaving every other code path untouched.

`MLP::gate_proj`/`up_proj` become a new `MlpInputProjection` enum. The `Standard` variant wraps `UnifiedLinear` exactly as before and is used for quantized layers and for all non-M5 hardware. The `Pretransposed` variant is selected only when the hardware is an M5-class part and the weight is not quantized: it transposes the projection weight, makes it contiguous, evaluates it at load time, and then `forward` is a plain `matmul` (plus optional bias) against the prepared weight. The tied LM head gets the same treatment through `pretranspose_large_m5_embedding`, which materializes the wide embedding transpose under the same M5-and-non-quantized guard; `Gemma3nLanguageModel` caches it in `embed_tokens_weight_t` and a new `lm_head` helper uses it when present, otherwise falling back to `embed_tokens.as_linear`. Quantized models keep their specialized 4-bit paths unchanged, and the split-path per-layer update now consumes the corrected planes by iterator instead of copying the first plane.

This also adds an opt-in Metal GPU capture hook for profiling: when the process is launched with `MTL_CAPTURE_ENABLED=1` and `MLXCEL_CAPTURE_DECODE=<path>` is set, the generator captures exactly one warm decode token (at the second step, after the decode kernels are JIT-cached) to a `.gputrace` bundle and exits, so per-kernel timings are directly comparable with mlx-lm's `mx.metal.start_capture` script. The hook is inert unless both environment variables are set.

This complements the M5 split-path dispatch from #61: M5-class hardware now both avoids the fused path and gets the pretransposed decode weights. Verified with the gemma3n helper unit tests, a clean clippy build, and a coherent generation on gemma3n-e4b-bf16 (this M1 Ultra host takes the standard, non-pretransposed branch, so the dispatch and `lm_head` fallback are exercised end to end).
@inureyes inureyes added type:performance Performance improvements priority:medium Medium priority area:models Model architectures, weights, loading, metadata labels May 21, 2026
@inureyes inureyes merged commit 4434b23 into main May 21, 2026
4 checks passed
@inureyes inureyes deleted the port/gemma3n-m5-decode-bandwidth branch May 21, 2026 09:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:models Model architectures, weights, loading, metadata priority:medium Medium priority type:performance Performance improvements

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant