Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/lib/mlxcel-core/src/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,25 @@ impl CxxGenerator {
ffi::export_to_dot_pair(&path, &next_tok, &next_log);
}
}
// Optional Metal GPU capture of one warm decode token for
// per-kernel profiling vs mlx-lm. Fires at n==2 so
// all decode kernels are JIT-cached. Requires the process to be
// launched with `MTL_CAPTURE_ENABLED=1`; writes a `.gputrace`
// bundle to the given path, comparable with mlx-lm's
// `mx.metal.start_capture`.
if n == 2 {
if let Ok(path) = std::env::var("MLXCEL_CAPTURE_DECODE") {
ffi::metal_start_capture(&path);
ffi::eval(&next_tok);
ffi::metal_stop_capture();
// Exit immediately so the GPU trace document finalizes
// with exactly one captured decode token and no further
// GPU work polluting it (mirrors mlx-lm's capture-script
// lifecycle). Capture mode is a profiling-only path.
eprintln!("[capture] wrote one decode token to {path}");
std::process::exit(0);
}
}
if force_sync {
ffi::eval(&next_tok);
} else {
Expand Down
114 changes: 104 additions & 10 deletions src/models/gemma3n.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,13 +665,74 @@ impl Gemma3nAttention {

// MLP with gelu_topk activation.
pub struct MLP {
pub gate_proj: UnifiedLinear,
pub up_proj: UnifiedLinear,
pub gate_proj: MlpInputProjection,
pub up_proj: MlpInputProjection,
pub down_proj: UnifiedLinear,
pub activation_sparsity: f32,
pub std_multiplier: f32,
}

// M5 non-quantized Gemma3n decode GEMVs stream gate/up weights faster when MLX
// sees materialized transposed weights. Quantized layers keep UnifiedLinear so
// their specialized 4bit path is unchanged.
pub enum MlpInputProjection {
Standard(UnifiedLinear),
Pretransposed {
weight_t: UniquePtr<MlxArray>,
bias: Option<UniquePtr<MlxArray>>,
},
}

impl MlpInputProjection {
fn from_weights_maybe_pretransposed(
weights: &WeightMap,
prefix: &str,
group_size: i32,
bits: i32,
) -> Result<Self, String> {
let hw = mlxcel_core::hardware::get_hardware();
let is_m5_na = hw.has_neural_accelerator && hw.macos_supports_na;
let scales_name = format!("{}.scales", prefix);
if !is_m5_na || weights.contains_key(&scales_name) {
return Ok(Self::Standard(UnifiedLinear::from_weights(
weights, prefix, group_size, bits,
)?));
}

let weight_name = format!("{}.weight", prefix);
let weight = weights
.get(&weight_name)
.ok_or_else(|| format!("Weight not found: {}", weight_name))?;
let weight_t = mlxcel_core::transpose(weight);
let weight_t = mlxcel_core::contiguous(&weight_t, false);
mlxcel_core::eval(&weight_t);

let bias_name = format!("{}.bias", prefix);
let bias = weights.get(&bias_name).map(|b| mlxcel_core::copy(b));
Ok(Self::Pretransposed { weight_t, bias })
}

fn forward(&self, x: &MlxArray) -> UniquePtr<MlxArray> {
match self {
Self::Standard(linear) => linear.forward(x),
Self::Pretransposed { weight_t, bias } => {
let out = mlxcel_core::matmul(x, weight_t);
match bias {
Some(bias) => mlxcel_core::add(&out, bias),
None => out,
}
}
}
}

fn regular_weight(&self) -> Option<&Linear> {
match self {
Self::Standard(linear) => linear.regular_weight(),
Self::Pretransposed { .. } => None,
}
}
}

/// #60 introduced a fused Gemma3n decode path (stacked AltUp predict/
/// correct plus the `gemma3n_mlp_forward` bridge call) that cuts Rust↔C++
/// graph-construction overhead. It improves decode on Apple Silicon without a
Expand Down Expand Up @@ -775,14 +836,18 @@ impl MLP {
.map(|q| q.bits as i32)
.unwrap_or(4);

let gate_proj = UnifiedLinear::from_weights(
let gate_proj = MlpInputProjection::from_weights_maybe_pretransposed(
weights,
&format!("{}.gate_proj", prefix),
group_size,
bits,
)?;
let up_proj =
UnifiedLinear::from_weights(weights, &format!("{}.up_proj", prefix), group_size, bits)?;
let up_proj = MlpInputProjection::from_weights_maybe_pretransposed(
weights,
&format!("{}.up_proj", prefix),
group_size,
bits,
)?;
let down_proj = UnifiedLinear::from_weights(
weights,
&format!("{}.down_proj", prefix),
Expand Down Expand Up @@ -962,9 +1027,12 @@ impl DecoderLayer {

// Add first_prediction to corrected[1:].
let mut result = Vec::with_capacity(corrected.len());
result.push(mlxcel_core::copy(&corrected[0]));
for item in corrected.iter().skip(1) {
result.push(mlxcel_core::add(item, &first_prediction));
let mut corrected = corrected.into_iter();
if let Some(first) = corrected.next() {
result.push(first);
}
for item in corrected {
result.push(mlxcel_core::add(&item, &first_prediction));
}

result
Expand Down Expand Up @@ -1068,6 +1136,7 @@ impl DecoderLayer {
// Language Model.
pub struct Gemma3nLanguageModel {
pub embed_tokens: UnifiedEmbedding,
pub embed_tokens_weight_t: Option<UniquePtr<MlxArray>>,
pub embed_tokens_per_layer: UnifiedEmbedding,
pub per_layer_model_projection: UnifiedLinear,
pub per_layer_projection_norm: RMSNorm,
Expand All @@ -1082,6 +1151,29 @@ pub struct Gemma3nLanguageModel {
}

impl Gemma3nLanguageModel {
fn pretranspose_large_m5_embedding(
embedding: &UnifiedEmbedding,
) -> Option<UniquePtr<MlxArray>> {
let hw = mlxcel_core::hardware::get_hardware();
if embedding.is_quantized() || !(hw.has_neural_accelerator && hw.macos_supports_na) {
return None;
}

// The tied LM head is a very wide decode GEMV; materializing the
// transpose improves M5 bandwidth on non-quantized Gemma3n.
let weight_t = mlxcel_core::transpose(embedding.weight());
let weight_t = mlxcel_core::contiguous(&weight_t, false);
mlxcel_core::eval(&weight_t);
Some(weight_t)
}

fn lm_head(&self, out: &MlxArray) -> UniquePtr<MlxArray> {
match &self.embed_tokens_weight_t {
Some(weight_t) => mlxcel_core::matmul(out, weight_t),
None => self.embed_tokens.as_linear(out),
}
}

pub fn forward(&self, inputs: &MlxArray, caches: &mut [KVCache]) -> UniquePtr<MlxArray> {
// Embed tokens
let h = self.embed_tokens.forward(inputs);
Expand Down Expand Up @@ -1177,7 +1269,7 @@ impl Gemma3nLanguageModel {
};

// LM head (tied embeddings)
let mut logits = self.embed_tokens.as_linear(&out);
let mut logits = self.lm_head(&out);

// Apply logit softcapping if configured
if let Some(cap) = self.config.final_logit_softcapping {
Expand Down Expand Up @@ -1338,7 +1430,7 @@ impl Gemma3nLanguageModel {
} else {
mlxcel_core::astype(&out, mlxcel_core::array_dtype(self.embed_tokens.weight()))
};
let mut logits = self.embed_tokens.as_linear(&out);
let mut logits = self.lm_head(&out);

if let Some(cap) = self.config.final_logit_softcapping {
logits = apply_softcap(&logits, cap);
Expand Down Expand Up @@ -1377,6 +1469,7 @@ impl Gemma3nLanguageModel {
group_size,
bits,
)?;
let embed_tokens_weight_t = Self::pretranspose_large_m5_embedding(&embed_tokens);
let embed_tokens_per_layer = UnifiedEmbedding::from_weights(
weights,
&format!("{}.embed_tokens_per_layer", prefix),
Expand Down Expand Up @@ -1473,6 +1566,7 @@ impl Gemma3nLanguageModel {

Ok(Self {
embed_tokens,
embed_tokens_weight_t,
embed_tokens_per_layer,
per_layer_model_projection,
per_layer_projection_norm,
Expand Down