Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 86 additions & 4 deletions src/execution/quant_advisor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

use std::path::Path;

use mlxcel_core::hardware::{HardwareCapabilities, QuantRecommendation, recommend_quantization};
use mlxcel_core::hardware::{
HardwareCapabilities, KvCacheParams, QuantRecommendation, kv_cache_bytes_from_params,
recommend_quantization,
};
use mlxcel_core::weights::weight_footprint_bytes;

// ── Model-size estimation ─────────────────────────────────────────────────────
Expand Down Expand Up @@ -181,12 +184,15 @@ pub struct QuantAdvice {
pub exact_weight_bytes: Option<u64>,
/// True when the model config declares BFloat16 weights.
pub model_uses_bfloat16: bool,
/// KV cache memory estimate in bytes (None if architecture info is unavailable).
pub kv_cache_bytes: Option<u64>,
}

/// Produce a complete quantization recommendation for a model directory.
///
/// Uses the model's `config.json` to estimate the parameter count, then calls
/// [`recommend_quantization`] against the provided hardware capabilities.
/// Uses the model's `config.json` to estimate the parameter count and KV cache
/// memory requirements, then calls [`recommend_quantization`] against the
/// provided hardware capabilities.
///
/// When `model_params_override` is `Some(n)`, that value is used instead of
/// both the safetensors-derived and config-derived estimates.
Expand All @@ -197,6 +203,12 @@ pub struct QuantAdvice {
/// `bytes / 2 / 1e9` (assumes FP16 as the reference dtype).
/// 3. Analytical estimate from `config.json`.
/// 4. Hard-coded 7 B fallback.
///
/// The KV cache headroom passed to [`recommend_quantization`] is derived from
/// the model architecture when it can be extracted from `config.json`. The
/// default context length is 8192 tokens; `int8_kv` defaults to `false`.
/// When architecture fields are unavailable the function falls back to the
/// built-in 2 GiB constant (`kv_cache_headroom_bytes = None`).
pub fn advise_quantization(
model_path: &Path,
hw: &HardwareCapabilities,
Expand All @@ -218,7 +230,10 @@ pub fn advise_quantization(
.or(estimated_params)
.unwrap_or(7.0); // safe fallback: assume 7B when unknown

let recommendation = recommend_quantization(params, hw.unified_memory_gb, hw);
// Attempt to derive KV cache headroom from the model config.
let kv_bytes = estimate_kv_cache_bytes_from_path(model_path, 8192, false);

let recommendation = recommend_quantization(params, hw.unified_memory_gb, hw, kv_bytes);

let uses_bf16 = model_uses_bfloat16(model_path);

Expand All @@ -227,6 +242,7 @@ pub fn advise_quantization(
estimated_params_billions: estimated_params,
exact_weight_bytes,
model_uses_bfloat16: uses_bf16,
kv_cache_bytes: kv_bytes,
}
}

Expand All @@ -243,6 +259,72 @@ fn format_bytes(bytes: u64) -> String {
}
}

/// Estimate KV cache memory in bytes from a model's `config.json`.
///
/// Returns `None` when the required architecture fields (`num_hidden_layers`,
/// `num_key_value_heads`, `hidden_size`, `num_attention_heads`) cannot be
/// extracted.
///
/// `ctx_len` is the requested context length (tokens); `int8_kv` controls
/// whether 1-byte (INT8) or 2-byte (FP16) KV storage is assumed.
pub fn estimate_kv_cache_bytes_from_path(
model_path: &Path,
ctx_len: u64,
int8_kv: bool,
) -> Option<u64> {
let config_path = model_path.join("config.json");
let config_str = std::fs::read_to_string(&config_path).ok()?;
let config: serde_json::Value = serde_json::from_str(&config_str).ok()?;
estimate_kv_cache_bytes_from_config(&config, ctx_len, int8_kv)
}

fn estimate_kv_cache_bytes_from_config(
config: &serde_json::Value,
ctx_len: u64,
int8_kv: bool,
) -> Option<u64> {
let text_cfg = config.get("text_config").unwrap_or(config);

let num_layers = text_cfg
.get("num_hidden_layers")
.or_else(|| text_cfg.get("n_layers"))
.or_else(|| text_cfg.get("num_layers"))
.and_then(|v| v.as_u64())?;

let hidden_size = text_cfg
.get("hidden_size")
.or_else(|| text_cfg.get("d_model"))
.and_then(|v| v.as_u64())?;

let num_heads = text_cfg
.get("num_attention_heads")
.or_else(|| text_cfg.get("num_heads"))
.and_then(|v| v.as_u64())
.unwrap_or(1);

let num_kv_heads = text_cfg
.get("num_key_value_heads")
.and_then(|v| v.as_u64())
.unwrap_or(num_heads);

// head_dim is usually hidden_size / num_heads; fall back to 64 if num_heads is 0.
let head_dim = if num_heads > 0 {
hidden_size / num_heads
} else {
64
};

let params = KvCacheParams {
num_layers,
num_kv_heads,
head_dim,
int8_kv,
ctx_len,
batch: 1,
};
Some(kv_cache_bytes_from_params(&params))
}

/// Print a human-readable quantization recommendation to stdout.
pub fn print_quant_advice(advice: &QuantAdvice, hw: &HardwareCapabilities) {
println!();
Expand Down
Loading