Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/benchmark_results/model_tests_m5max.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Compatibility and performance testing for mlxcel models on **MacBook Pro M5 Max
| gemma3 (4B) | gemma-3-4b-it-4bit | ✅ | 819.97 | 182.16 | **1.83x** | 81 tokens; full-budget raw prompt 183.77 tok/s |
| gemma3n (E2B) | gemma-3n-E2B-it-4bit | ✅ | 812.25 | 158.71 | **2.06x** | 71 tokens |
| gemma3n (E4B) | gemma-3n-E4B-it-4bit | ✅ | 601.09 | 110.24 | **1.83x** | 71 tokens |
| gemma3n (E4B bf16) | gemma-3n-E4B-it (bf16) | ✅ | 348.30 | 39.19 | **1.14x** | 72 tokens; issue #716; Gemma3n language MLP bf16 preserved, other bf16 materialized as f16; 78% of mlx-lm decode |
| gemma3n (E4B bf16) | gemma-3n-E4B-it (bf16) | ✅ | 348.30 | 39.05 | 1.10x | Gemma3n language MLP bf16 preserved, other bf16 materialized as f16; M5 (Neural Accelerator) uses the split decode path while other Apple Silicon uses the fused path; ~80% of mlx-lm decode |
| gemma4 (26B MoE) | gemma-4-26b-a4b-it-4bit | ✅ | 539.57 | 137.12 | **1.87x** | 37 tokens |
| gemma4 (31B) | gemma-4-31b-4bit | ✅ | 71.51 | 28.59 | **1.42x** | 100 tokens |
| gemma4 (31B IT) | gemma-4-31b-it-4bit | ✅ | 144.09 | 27.34 | **1.43x** | 25 tokens |
Expand Down Expand Up @@ -299,7 +299,7 @@ snapshot.
| gemma3-4b-4bit | 182.16 | 181.66 | **100%** |
| gemma3n-e2b-4bit | 158.71 | FAIL | - |
| gemma3n-e4b-4bit | 110.24 | FAIL | - |
| gemma3n-e4b-bf16 | 39.19 | 48.72 | 80% |
| gemma3n-e4b-bf16 | 39.05 | 48.72 | 80% |
| glm4-flash-4bit | 104.30 | 104.03 | **100%** |
| gpt-oss-120b-4bit | 114.03 | 110.35 | **103%** |
| gpt-oss-20b-mxfp4 | 172.33 | 168.33 | **102%** |
Expand Down
111 changes: 106 additions & 5 deletions src/models/gemma3n.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,18 @@ pub struct MLP {
pub std_multiplier: f32,
}

/// #60 introduced a fused Gemma3n decode path (stacked AltUp predict/
/// correct plus the `gemma3n_mlp_forward` bridge call) that cuts Rust↔C++
/// graph-construction overhead. It improves decode on Apple Silicon without a
/// Neural Accelerator (M1 Ultra: +3.6%) but regresses M5-class hardware
/// (about -6.3% on M5 Max `gemma3n-e4b-bf16`). Use the fused path only off NA
/// hardware; M5-class GPUs fall back to the per-op split path.
#[inline]
fn use_fused_decode_path() -> bool {
let hw = mlxcel_core::hardware::get_hardware();
!(hw.has_neural_accelerator && hw.macos_supports_na)
}

impl MLP {
/// Apply gelu_topk activation with sparsity.
/// Uses a compiled fused kernel matching Python's @mx.compile gelu_topk.
Expand All @@ -686,11 +698,17 @@ impl MLP {
}

pub fn forward(&self, x: &MlxArray) -> UniquePtr<MlxArray> {
if let (Some(gate), Some(up), Some(down)) = (
self.gate_proj.regular_weight(),
self.up_proj.regular_weight(),
self.down_proj.regular_weight(),
) {
// The fused MLP bridge call (`gemma3n_mlp_forward`, added in #60) cuts
// decode graph overhead on non-NA Apple Silicon but regresses M5-class
// hardware. Gate it off NA hardware so M5 uses the per-op bf16 path
// below.
if use_fused_decode_path()
&& let (Some(gate), Some(up), Some(down)) = (
self.gate_proj.regular_weight(),
self.up_proj.regular_weight(),
self.down_proj.regular_weight(),
)
{
let gate_bias_ptr = gate
.bias
.as_ref()
Expand Down Expand Up @@ -813,6 +831,26 @@ impl DecoderLayer {
mask: Option<&MlxArray>,
cache: &mut KVCache,
per_layer_input: &MlxArray,
) -> Vec<UniquePtr<MlxArray>> {
// The stacked AltUp path (added in #60) speeds up decode on non-NA
// Apple Silicon but regresses M5-class hardware. Dispatch to the split
// path on NA hardware; keep the stacked path elsewhere.
if use_fused_decode_path() {
self.forward_stacked(x, mask, cache, per_layer_input)
} else {
self.forward_split(x, mask, cache, per_layer_input)
}
}

/// Fused-path layer forward (added in #60): keeps AltUp predictions stacked
/// from predict through correct, slicing only the active plane. Faster on
/// Apple Silicon without a Neural Accelerator; selected by `forward`.
fn forward_stacked(
&self,
x: &[UniquePtr<MlxArray>],
mask: Option<&MlxArray>,
cache: &mut KVCache,
per_layer_input: &MlxArray,
) -> Vec<UniquePtr<MlxArray>> {
// AltUp predict. Keep the prediction tensor stacked until correction
// so decode avoids slicing four planes and stacking them again within
Expand Down Expand Up @@ -869,6 +907,69 @@ impl DecoderLayer {
)
}

/// Split-path layer forward (the path before #60's fused path):
/// `AltUp::predict`/`correct` return per-plane Vecs. Used on M5-class
/// (Neural Accelerator) hardware where the stacked path regresses decode.
fn forward_split(
&self,
x: &[UniquePtr<MlxArray>],
mask: Option<&MlxArray>,
cache: &mut KVCache,
per_layer_input: &MlxArray,
) -> Vec<UniquePtr<MlxArray>> {
// AltUp predict
let predictions = self.altup.predict(x);
let active_prediction = &predictions[self.altup_active_idx];

// Input layernorm
let active_normed = self.input_layernorm.forward(active_prediction);

// LAUREL
let laurel_output = self.laurel.forward(&active_normed);

// Self attention
let attn = self.self_attn.forward(&active_normed, mask, cache);
let attn = self.post_attention_layernorm.forward(&attn);

// Residual + LAUREL
let attn_gated = mlxcel_core::add(active_prediction, &attn);

let sum = mlxcel_core::add(&attn_gated, &laurel_output);
let attn_laurel = mlxcel_core::multiply_scalar(&sum, std::f32::consts::FRAC_1_SQRT_2);

// FFN
let attn_norm = self.pre_feedforward_layernorm.forward(&attn_laurel);
let ffw = self.mlp.forward(&attn_norm);
let ffw_norm = self.post_feedforward_layernorm.forward(&ffw);
let ffw_gated = mlxcel_core::add(&attn_laurel, &ffw_norm);
let ffw_gated = mlxcel_core::astype(&ffw_gated, mlxcel_core::dtype::BFLOAT16);

// AltUp correct
let corrected = self.altup.correct(&predictions, &ffw_gated);

// Per-layer input processing
let first = &corrected[self.altup_active_idx];
let first = if self.altup_correct_scale {
mlxcel_core::multiply(first, &self.altup.correct_output_scale)
} else {
mlxcel_core::copy(first)
};

let first = self.per_layer_input_gate.forward(&first);
let first = mlxcel_core::compiled_geglu_approx_activation(&first, per_layer_input);
let first = self.per_layer_projection.forward(&first);
let first_prediction = self.post_per_layer_input_norm.forward(&first);

// Add first_prediction to corrected[1:].
let mut result = Vec::with_capacity(corrected.len());
result.push(mlxcel_core::copy(&corrected[0]));
for item in corrected.iter().skip(1) {
result.push(mlxcel_core::add(item, &first_prediction));
}

result
}

pub fn from_weights(
weights: &WeightMap,
config: &TextConfig,
Expand Down