Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/benchmark_results/model_tests_m1ultra.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Compatibility and performance testing for mlxcel models on **Mac Studio M1 Ultra
| gemma4 (E4B 8bit) | gemma-4-e4b-it-8bit | ✅ | 165.49 | 59.35 | - | mlx-lm: FAIL; only 39 tokens |
| gemma3n | gemma-3n-E2B-it-4bit | ✅ | 238.51 | 76.86 | - | mlx-lm: FAIL; only 69 tokens |
| gemma3n (E4B) | gemma-3n-E4B-it-4bit | ✅ | 169.94 | 60.18 | - | mlx-lm: FAIL; only 74 tokens |
| gemma3n (E4B bf16) | gemma-3n-E4B-it (bf16) | ✅ | 169.01 | 34.41 | 88% | mlx-lm: 39.02; bf16 prefill path retune (PR #727); bf16; only 72 tokens |
| gemma3n (E4B bf16) | gemma-3n-E4B-it (bf16) | ✅ | 169.01 | 35.65 | 91% | mlx-lm: 39.02; bf16; AltUp/MLP decode graph scheduling |
| recurrent_gemma | - | ⏳ | - | - | - | Griffin SSM+attention hybrid |

## EXAONE
Expand Down Expand Up @@ -294,7 +294,7 @@ Numbers are decode tok/s. `mlxcel vs mlx-lm` is `mlxcel / mlx-lm` as a percentag

- **Comparable text pairs**: 73
- **mlxcel >= mlx-lm**: 20 / 73 (27%)
- **mlxcel >= 90% parity**: 65 / 73 (89%)
- **mlxcel >= 90% parity**: 66 / 73 (90%)
- **Average mlxcel/mlx-lm**: 96% (median 97%, range 47%-113%)

### Aggregate (VLM, models with >=5 generated tokens both sides)
Expand Down Expand Up @@ -341,7 +341,7 @@ Numbers are decode tok/s. `mlxcel vs mlx-lm` is `mlxcel / mlx-lm` as a percentag
| gemma3-4b-4bit | 112.95 | 109.48 | **103%** |
| gemma3n-e2b-4bit | 76.86 | FAIL | - |
| gemma3n-e4b-4bit | 60.18 | FAIL | - |
| gemma3n-e4b-bf16 | 34.41 | 39.02 | 88% |
| gemma3n-e4b-bf16 | 35.65 | 39.02 | 91% |
| glm4-flash-4bit | 47.32 | 49.47 | 96% |
| gpt-oss-120b-4bit | 58.89 | 57.58 | **102%** |
| gpt-oss-20b-mxfp4 | 88.89 | 89.51 | 99% |
Expand Down
50 changes: 50 additions & 0 deletions src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1991,6 +1991,56 @@ std::unique_ptr<MlxArray> compiled_gelu_mlp_forward_fp16(
return std::make_unique<MlxArray>(std::move(down));
}

// Gemma3n dense MLP forward for non-quantized bf16 language MLP weights:
// cast input to bf16 -> gate/up -> gelu_approx or gelu_topk -> down -> bf16.
//
// Matmul operations intentionally stay outside mx::compile for the same reason
// as compiled_swiglu_mlp_forward_fp16 / compiled_gelu_mlp_forward_fp16: compiled
// matmul+transpose graphs can reuse the wrong constants across layers. The
// element-wise activation still uses the cached compiled kernels, while the
// whole MLP is built through one C++ bridge call instead of four Rust calls.
std::unique_ptr<MlxArray> gemma3n_mlp_forward(
const MlxArray& x,
const MlxArray& gate_weight,
const MlxArray& up_weight,
const MlxArray& down_weight,
const MlxArray* gate_bias,
const MlxArray* up_bias,
const MlxArray* down_bias,
float activation_sparsity,
float std_multiplier
) {
auto x_mlp = x.inner.dtype() == mlx::core::bfloat16
? x.inner
: mlx::core::astype(x.inner, mlx::core::bfloat16);

auto gate_t = mlx::core::transpose(gate_weight.inner);
auto gate = mlx::core::matmul(x_mlp, gate_t);
if (gate_bias) gate = mlx::core::add(gate, gate_bias->inner);

auto up_t = mlx::core::transpose(up_weight.inner);
auto up = mlx::core::matmul(x_mlp, up_t);
if (up_bias) up = mlx::core::add(up, up_bias->inner);

auto hidden = [&]() {
if (activation_sparsity > 0.0f) {
static auto compiled_topk = get_compiled_gelu_topk();
auto mult = array(std_multiplier);
auto activated = compiled_topk({gate, mult});
return mlx::core::multiply(activated[0], up);
}
static auto compiled_geglu_approx = get_compiled_geglu_approx();
auto activated = compiled_geglu_approx({gate, up});
return activated[0];
}();

auto down_t = mlx::core::transpose(down_weight.inner);
auto down = mlx::core::matmul(hidden, down_t);
if (down_bias) down = mlx::core::add(down, down_bias->inner);

return std::make_unique<MlxArray>(mlx::core::astype(down, mlx::core::bfloat16));
}

std::unique_ptr<MlxArray> transformer_layer_forward(
const MlxArray& x,
const MlxArray& attn_norm_weight,
Expand Down
16 changes: 16 additions & 0 deletions src/lib/mlxcel-core/cpp/mlx_cxx_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,22 @@ std::unique_ptr<MlxArray> compiled_gelu_mlp_forward_fp16(
const MlxArray* down_bias
);

// Gemma3n dense MLP forward for non-quantized bf16 language MLP weights:
// cast input to bf16 -> gate/up -> gelu_approx or gelu_topk -> down -> bf16.
// Keeps the same bf16 semantics as the Rust op-at-a-time path while collapsing
// the decode-hot MLP graph construction into one C++ bridge call.
std::unique_ptr<MlxArray> gemma3n_mlp_forward(
const MlxArray& x,
const MlxArray& gate_weight,
const MlxArray& up_weight,
const MlxArray& down_weight,
const MlxArray* gate_bias,
const MlxArray* up_bias,
const MlxArray* down_bias,
float activation_sparsity,
float std_multiplier
);

// Full transformer layer forward (maximum FFI reduction)
// Combines: attention + MLP + residuals + norms
std::unique_ptr<MlxArray> transformer_layer_forward(
Expand Down
17 changes: 17 additions & 0 deletions src/lib/mlxcel-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,23 @@ mod ffi {
down_bias: *const MlxArray,
) -> UniquePtr<MlxArray>;

/// Gemma3n dense MLP forward for non-quantized bf16 language MLP weights:
/// cast input to bf16, run gate/up + gelu_approx or gelu_topk + down,
/// then cast back to bf16. Preserves the Gemma3n precision policy
/// while avoiding several Rust↔C++ bridge round-trips per layer.
/// Used by: Gemma3n bf16 language MLP decode path
unsafe fn gemma3n_mlp_forward(
x: &MlxArray,
gate_weight: &MlxArray,
up_weight: &MlxArray,
down_weight: &MlxArray,
gate_bias: *const MlxArray,
up_bias: *const MlxArray,
down_bias: *const MlxArray,
activation_sparsity: f32,
std_multiplier: f32,
) -> UniquePtr<MlxArray>;

/// Full transformer layer forward (maximum FFI reduction)
unsafe fn transformer_layer_forward(
x: &MlxArray,
Expand Down
158 changes: 95 additions & 63 deletions src/models/gemma3n.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@

use crate::models::gemma3n_helpers::{
apply_softcap, compute_magnitude, mean_arrays, normalize_magnitudes,
normalize_magnitudes_from_idx, slice_layer_input, stack_arrays,
normalize_magnitudes_from_idx, slice_altup_plane, slice_layer_input,
split_altup_after_per_layer_update, split_altup_planes, stack_arrays,
};
use mlxcel_core::generate::LanguageModel;
use mlxcel_core::layers::{KVCache, Linear, RMSNorm, UnifiedEmbedding, UnifiedLinear};
Expand Down Expand Up @@ -249,8 +250,14 @@ impl AltUp {
mlxcel_core::tanh(&routed)
}

/// Predict: expand inputs through altup_num_inputs parallel paths
pub fn predict(&self, x: &[UniquePtr<MlxArray>]) -> Vec<UniquePtr<MlxArray>> {
/// Predict: expand inputs through altup_num_inputs parallel paths.
///
/// The decode-hot layer path consumes the stacked tensor directly and
/// avoids immediately slicing the four AltUp planes only to stack them
/// again during correction. Keeping this as one `[altup, B, L, hidden]`
/// graph island mirrors mlx-lm's tensor scheduling more closely while
/// preserving the public Vec-returning wrapper below for parity tests.
pub fn predict_stacked(&self, x: &[UniquePtr<MlxArray>]) -> UniquePtr<MlxArray> {
// x is [altup_num_inputs] arrays, each [B, L, hidden_size]
// Get active input for routing
let active = &x[self.altup_active_idx];
Expand All @@ -269,14 +276,11 @@ impl AltUp {
let all_coefs = mlxcel_core::reshape(&coefs, &[b, l, n, n]);
let all_coefs = mlxcel_core::transpose_axes(&all_coefs, &[0, 1, 3, 2]);

// Convert x to float32 for computation
let x_f32: Vec<_> = x
.iter()
.map(|arr| mlxcel_core::astype(arr, mlxcel_core::dtype::FLOAT32))
.collect();

// Stack x to [B, L, hidden, altup]
let x_stacked = stack_arrays(&x_f32, 0);
// Stack first, then cast once to match mlx-lm's `x.astype(mx.float32)`
// scheduling. This keeps the four AltUp planes in one graph island
// instead of creating one bf16→f32 cast node per plane.
let x_stacked_native = stack_arrays(x, 0);
let x_stacked = mlxcel_core::astype(&x_stacked_native, mlxcel_core::dtype::FLOAT32);
// x_stacked shape: [altup, B, L, hidden]
let x_permuted = mlxcel_core::transpose_axes(&x_stacked, &[1, 2, 3, 0]);
// x_permuted shape: [B, L, hidden, altup]
Expand All @@ -288,39 +292,31 @@ impl AltUp {

// Add residual
let predictions = mlxcel_core::add(&predictions, &x_stacked);
let predictions = mlxcel_core::astype(&predictions, mlxcel_core::array_dtype(&x[0]));

// Split back to individual arrays
let mut result = Vec::with_capacity(self.altup_num_inputs);
for i in 0..self.altup_num_inputs {
let start = vec![i as i32, 0, 0, 0];
let hidden = mlxcel_core::array_shape(&x_f32[0])[2];
let stop = vec![(i + 1) as i32, b, l, hidden];
let sliced = mlxcel_core::slice(&predictions, &start, &stop);
let squeezed = mlxcel_core::squeeze_axis(&sliced, 0);
result.push(squeezed);
}
mlxcel_core::astype(&predictions, mlxcel_core::array_dtype(&x[0]))
}

result
/// Predict: expand inputs through altup_num_inputs parallel paths.
pub fn predict(&self, x: &[UniquePtr<MlxArray>]) -> Vec<UniquePtr<MlxArray>> {
let predictions = self.predict_stacked(x);
split_altup_planes(&predictions, self.altup_num_inputs)
}

/// Correct: apply correction to predictions based on activated output
pub fn correct(
/// Correct: apply correction to stacked predictions based on activated output.
pub fn correct_stacked(
&self,
predictions: &[UniquePtr<MlxArray>],
predictions: &MlxArray,
active_prediction: &MlxArray,
activated: &MlxArray,
) -> Vec<UniquePtr<MlxArray>> {
) -> UniquePtr<MlxArray> {
let modalities = self.compute_router_modalities(activated);

// correction_coefs output shape: [B, L, altup_num_inputs]
let all_coefs = self.correction_coefs.forward(&modalities);
let one = mlxcel_core::full_f32(&[1], 1.0, mlxcel_core::array_dtype(&all_coefs));
let all_coefs = mlxcel_core::add(&all_coefs, &one);

// Get active prediction
let active_x = &predictions[self.altup_active_idx];
// innovation = activated - active_prediction
let innovation = mlxcel_core::subtract(activated, active_x);
let innovation = mlxcel_core::subtract(activated, active_prediction);

let shape = mlxcel_core::array_shape(&all_coefs);
let b = shape[0];
Expand All @@ -340,25 +336,24 @@ impl AltUp {
// Element-wise multiply: [1, B, L, hidden] * [altup, B, L, 1] = [altup, B, L, hidden]
let correction = mlxcel_core::multiply(&innovation_expanded, &coefs_expanded);

// Stack predictions and add correction
let preds_stacked = stack_arrays(predictions, 0);
let corrected = mlxcel_core::add(&preds_stacked, &correction);
// Add correction to the existing stacked prediction tensor.
let corrected = mlxcel_core::add(predictions, &correction);

// Cast back to original dtype
let original_dtype = mlxcel_core::array_dtype(activated);
let corrected = mlxcel_core::astype(&corrected, original_dtype);

// Split back to individual arrays
let mut result = Vec::with_capacity(self.altup_num_inputs);
for i in 0..self.altup_num_inputs {
let start = vec![i as i32, 0, 0, 0];
let stop = vec![(i + 1) as i32, b, l, hidden];
let sliced = mlxcel_core::slice(&corrected, &start, &stop);
let squeezed = mlxcel_core::squeeze_axis(&sliced, 0);
result.push(squeezed);
}
mlxcel_core::astype(&corrected, original_dtype)
}

result
/// Correct: apply correction to predictions based on activated output.
pub fn correct(
&self,
predictions: &[UniquePtr<MlxArray>],
activated: &MlxArray,
) -> Vec<UniquePtr<MlxArray>> {
let predictions_stacked = stack_arrays(predictions, 0);
let active_prediction = &predictions[self.altup_active_idx];
let corrected = self.correct_stacked(&predictions_stacked, active_prediction, activated);
split_altup_planes(&corrected, self.altup_num_inputs)
}

pub fn from_weights(
Expand Down Expand Up @@ -691,6 +686,41 @@ impl MLP {
}

pub fn forward(&self, x: &MlxArray) -> UniquePtr<MlxArray> {
if let (Some(gate), Some(up), Some(down)) = (
self.gate_proj.regular_weight(),
self.up_proj.regular_weight(),
self.down_proj.regular_weight(),
) {
let gate_bias_ptr = gate
.bias
.as_ref()
.map(|b| b.as_ref().unwrap() as *const MlxArray)
.unwrap_or(std::ptr::null());
let up_bias_ptr = up
.bias
.as_ref()
.map(|b| b.as_ref().unwrap() as *const MlxArray)
.unwrap_or(std::ptr::null());
let down_bias_ptr = down
.bias
.as_ref()
.map(|b| b.as_ref().unwrap() as *const MlxArray)
.unwrap_or(std::ptr::null());
return unsafe {
mlxcel_core::gemma3n_mlp_forward(
x,
&gate.weight,
&up.weight,
&down.weight,
gate_bias_ptr,
up_bias_ptr,
down_bias_ptr,
self.activation_sparsity,
self.std_multiplier,
)
};
}

let x_cast;
let x_mlp = if mlxcel_core::array_dtype(x) == mlxcel_core::dtype::BFLOAT16 {
x
Expand Down Expand Up @@ -784,12 +814,14 @@ impl DecoderLayer {
cache: &mut KVCache,
per_layer_input: &MlxArray,
) -> Vec<UniquePtr<MlxArray>> {
// AltUp predict
let predictions = self.altup.predict(x);
let active_prediction = &predictions[self.altup_active_idx];
// AltUp predict. Keep the prediction tensor stacked until correction
// so decode avoids slicing four planes and stacking them again within
// the same layer.
let predictions = self.altup.predict_stacked(x);
let active_prediction = slice_altup_plane(&predictions, self.altup_active_idx);

// Input layernorm
let active_normed = self.input_layernorm.forward(active_prediction);
let active_normed = self.input_layernorm.forward(&active_prediction);

// LAUREL
let laurel_output = self.laurel.forward(&active_normed);
Expand All @@ -799,7 +831,7 @@ impl DecoderLayer {
let attn = self.post_attention_layernorm.forward(&attn);

// Residual + LAUREL
let attn_gated = mlxcel_core::add(active_prediction, &attn);
let attn_gated = mlxcel_core::add(&active_prediction, &attn);

let sum = mlxcel_core::add(&attn_gated, &laurel_output);
let attn_laurel = mlxcel_core::multiply_scalar(&sum, std::f32::consts::FRAC_1_SQRT_2);
Expand All @@ -812,29 +844,29 @@ impl DecoderLayer {
let ffw_gated = mlxcel_core::astype(&ffw_gated, mlxcel_core::dtype::BFLOAT16);

// AltUp correct
let corrected = self.altup.correct(&predictions, &ffw_gated);
let corrected = self
.altup
.correct_stacked(&predictions, &active_prediction, &ffw_gated);

// Per-layer input processing
let first = &corrected[self.altup_active_idx];
let first = slice_altup_plane(&corrected, self.altup_active_idx);
let first = if self.altup_correct_scale {
mlxcel_core::multiply(first, &self.altup.correct_output_scale)
mlxcel_core::multiply(&first, &self.altup.correct_output_scale)
} else {
mlxcel_core::copy(first)
first
};

let first = self.per_layer_input_gate.forward(&first);
let first = mlxcel_core::compiled_geglu_approx_activation(&first, per_layer_input);
let first = self.per_layer_projection.forward(&first);
let first_prediction = self.post_per_layer_input_norm.forward(&first);

// Add first_prediction to corrected[1:]
let mut result = Vec::with_capacity(corrected.len());
result.push(mlxcel_core::copy(&corrected[0]));
for item in corrected.iter().skip(1) {
result.push(mlxcel_core::add(item, &first_prediction));
}

result
// Add first_prediction to corrected[1:] and split for the next layer.
split_altup_after_per_layer_update(
&corrected,
&first_prediction,
self.altup.altup_num_inputs,
)
}

pub fn from_weights(
Expand Down
Loading