Skip to content

perf(gemma3n): reduce bf16 decode AltUp/MLP graph overhead#60

Merged
inureyes merged 1 commit into
mainfrom
port/gemma3n-bf16-decode-overhead
May 21, 2026
Merged

perf(gemma3n): reduce bf16 decode AltUp/MLP graph overhead#60
inureyes merged 1 commit into
mainfrom
port/gemma3n-bf16-decode-overhead

Conversation

@inureyes
Copy link
Copy Markdown
Member

Summary

Reduces graph-construction overhead on the Gemma3n bf16 decode hot path by keeping the per-layer AltUp update as a single stacked tensor and collapsing the dense bf16 MLP into one C++ bridge call.

What changed

  • Stacked AltUp. AltUp::predict/correct previously cast, stacked, projected, and then sliced the four AltUp planes back into a Vec<UniquePtr<MlxArray>> every layer — and correct re-stacked the same arrays just to add the correction. New predict_stacked/correct_stacked keep the prediction as one [altup, B, L, hidden] graph island, cast once after the stack (matching mlx-lm's x.astype(mx.float32) scheduling) instead of once per plane, and add the correction onto the existing stacked tensor. The Vec-returning predict/correct wrappers are kept as thin shims that delegate and split, so parity tests and external callers are unchanged.
  • Stacked decoder layer. DecoderLayer::forward consumes the stacked tensor end to end and only slices the active plane when needed. Three helpers back this: slice_altup_plane, split_altup_planes, and split_altup_after_per_layer_update (the last folds mlx-lm's corrected_predictions[1:] += first_prediction update into the split).
  • Fused bf16 MLP. The non-quantized bf16 language MLP now runs through one gemma3n_mlp_forward C++ bridge call (cast to bf16, gate/up, gelu_approx or gelu_topk, down, cast back) instead of four Rust ops. Matmuls deliberately stay outside mx::compile, like the existing compiled_swiglu_mlp_forward_fp16/compiled_gelu_mlp_forward_fp16 helpers, while the element-wise activation reuses the cached compiled kernels. Quantized weights keep the existing op-at-a-time path.

Performance

On the Mac Studio M1 Ultra, gemma-3n-E4B-it bf16 text decode improves from 34.41 to 35.65 tok/s (about +3.6%), crossing 90% mlx-lm parity (88% to 91%). The M1 Ultra benchmark page is updated to the new decode value with the parity count bumped to match.

Testing

  • New unit tests slice_altup_plane_selects_stacked_prediction_plane and split_altup_after_per_layer_update_preserves_plane_zero_and_updates_tail pass alongside the existing gemma3n helper tests.
  • cargo clippy --features metal,accelerate --lib --bin mlxcel --tests -- -D warnings is clean (the C++ bridge compiles and links).
  • A short real-model generation on gemma3n-e4b-bf16 produces coherent output through the new stacked + fused path.

The bf16 decode hot path built the per-layer AltUp update as a sequence of `Vec<UniquePtr<MlxArray>>` operations: `predict` cast each of the four AltUp planes to f32 individually, stacked them, ran the projection, then sliced the result back into four arrays; `correct` re-stacked those same four arrays only to add the correction and slice them apart again; and the dense MLP issued four separate Rust to C++ bridge calls (gate, up, activation, down). Every slice/stack/cast boundary is a fresh graph node, so a single decoder layer churned through far more graph construction than the underlying math requires.

Keep the AltUp prediction as one `[altup, B, L, hidden]` graph island for the whole layer. New `predict_stacked` and `correct_stacked` methods operate on the stacked tensor directly, casting once after the stack (matching mlx-lm's `x.astype(mx.float32)` scheduling) instead of once per plane, and `correct_stacked` adds the correction onto the existing stacked predictions rather than rebuilding them. The public `predict`/`correct` Vec-returning wrappers are preserved as thin shims that delegate to the stacked path and split, so the parity tests and any external callers are unaffected. `DecoderLayer::forward` now consumes the stacked tensor end to end and only slices the active plane when it actually needs it. Three helpers back this: `slice_altup_plane`, `split_altup_planes`, and `split_altup_after_per_layer_update` (which folds mlx-lm's `corrected_predictions[1:] += first_prediction` update into the split).

The non-quantized bf16 language MLP now runs through a single `gemma3n_mlp_forward` C++ bridge call (cast input to bf16, gate/up, gelu_approx or gelu_topk, down, cast back to bf16) instead of four Rust-side ops. Matmuls stay outside `mx::compile` for the same reason as the existing `compiled_swiglu_mlp_forward_fp16`/`compiled_gelu_mlp_forward_fp16` helpers — compiled matmul+transpose graphs can reuse the wrong per-layer constants — while the element-wise activation reuses the cached compiled kernels. Quantized weights fall back to the existing op-at-a-time path unchanged.

On the Mac Studio M1 Ultra, this lifts gemma-3n-E4B-it bf16 text decode from 34.41 to 35.65 tok/s (about +3.6%), crossing 90% mlx-lm parity (88% to 91%); the M1 Ultra benchmark page is updated to the new decode value and the parity count is bumped accordingly. Verified with the new `slice_altup_plane`/`split_altup_after_per_layer_update` unit tests, a clean clippy build of the C++ bridge, and a coherent real-model generation on gemma3n-e4b-bf16.
@inureyes inureyes added type:performance Performance improvements priority:medium Medium priority area:models Model architectures, weights, loading, metadata labels May 21, 2026
@inureyes inureyes merged commit 0e4faff into main May 21, 2026
4 checks passed
@inureyes inureyes deleted the port/gemma3n-bf16-decode-overhead branch May 21, 2026 09:08
inureyes added a commit that referenced this pull request May 21, 2026
)

The fused Gemma3n bf16 decode path added in #60 (stacked AltUp predict/correct plus the `gemma3n_mlp_forward` bridge call) cuts Rust to C++ graph-construction overhead and improves decode on Apple Silicon without a Neural Accelerator (Mac Studio M1 Ultra: about +3.6%). A same-machine A/B on the MacBook Pro M5 Max shows the opposite there: the fused path regressed gemma-3n-E4B-it bf16 same-process decode by roughly 6.3% (about 39.0 down to 36.6 tok/s). The stacked-AltUp scheduling and the single fused MLP bridge call interact poorly with M5-class (Neural Accelerator) hardware, where the pre-fused per-op path schedules better.

Gate both fused paths behind a new `use_fused_decode_path()` helper, which is true only when the hardware is not a Neural Accelerator part (`!(has_neural_accelerator && macos_supports_na)`), mirroring the same hardware predicate already used elsewhere in the core. `MLP::forward` now takes the fused `gemma3n_mlp_forward` bridge call only off NA hardware and otherwise runs the per-op bf16 path. `DecoderLayer::forward` dispatches to `forward_stacked` (the fused, stacked-AltUp layer) on non-NA hardware and to `forward_split` (the pre-fused path where `AltUp::predict`/`correct` return per-plane Vecs) on M5-class hardware. Both code paths already existed, so this is a runtime dispatch by hardware class, not a revert: non-NA Apple Silicon keeps the faster fused path, M5-class hardware avoids the regression.

On the M5 Max, gemma-3n-E4B-it bf16 decode returns to about 39 tok/s (representative 39.05) with coherent real-model output, holding near 80% of the mlx-lm reference. The M5 Max benchmark page is updated to the restored decode value, its vs-M1-Ultra ratio (39.05 / 35.65 = 1.10x) is recomputed against the M1 Ultra decode, and the note now records that M5-class hardware uses the split decode path while other Apple Silicon uses the fused path. Verified with the gemma3n helper unit tests, a clean clippy build, and a coherent generation on gemma3n-e4b-bf16 (this M1 Ultra host exercises the fused branch; the split branch is the unchanged pre-fused code).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:models Model architectures, weights, loading, metadata priority:medium Medium priority type:performance Performance improvements

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant