Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 120 additions & 3 deletions src/execution/quant_advisor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
use std::path::Path;

use mlxcel_core::hardware::{HardwareCapabilities, QuantRecommendation, recommend_quantization};
use mlxcel_core::weights::weight_footprint_bytes;

// ── Model-size estimation ─────────────────────────────────────────────────────

Expand Down Expand Up @@ -172,6 +173,12 @@ pub struct QuantAdvice {
pub recommendation: QuantRecommendation,
/// Estimated parameter count in billions (None if config parsing failed).
pub estimated_params_billions: Option<f64>,
/// Byte-accurate weight footprint read from the safetensors header before
/// any tensors are loaded. `Some(bytes)` when available from
/// `model.safetensors.index.json` (sharded) or a single `model.safetensors`
/// binary header; `None` when neither is present and the analytical estimate
/// from `estimated_params_billions` is the only available sizing signal.
pub exact_weight_bytes: Option<u64>,
/// True when the model config declares BFloat16 weights.
pub model_uses_bfloat16: bool,
}
Expand All @@ -182,14 +189,34 @@ pub struct QuantAdvice {
/// [`recommend_quantization`] against the provided hardware capabilities.
///
/// When `model_params_override` is `Some(n)`, that value is used instead of
/// the estimate from `config.json`.
/// both the safetensors-derived and config-derived estimates.
///
/// Resolution order for the size fed to [`recommend_quantization`]:
/// 1. `model_params_override` (caller-supplied explicit value)
/// 2. Exact bytes from the safetensors header, converted to billions via
/// `bytes / 2 / 1e9` (assumes FP16 as the reference dtype).
/// 3. Analytical estimate from `config.json`.
/// 4. Hard-coded 7 B fallback.
pub fn advise_quantization(
model_path: &Path,
hw: &HardwareCapabilities,
model_params_override: Option<f64>,
) -> QuantAdvice {
let estimated_params = estimate_model_params_billions(model_path);
let params = model_params_override.or(estimated_params).unwrap_or(7.0); // safe fallback: assume 7B when unknown
let exact_weight_bytes = weight_footprint_bytes(model_path);

// Convert exact bytes to a billions-of-parameters estimate.
// The safetensors total_size is raw parameter bytes (e.g. 2 bytes per BF16
// parameter). Dividing by 2 yields an FP16-equivalent parameter count;
// dividing by 1e9 converts to billions. For 16-bit checkpoints this matches
// the real parameter count. NOTE: INT8/INT4 checkpoints store fewer than
// 2 bytes per parameter, so they yield a figure *smaller* than their true
// parameter count (FP32 checkpoints yield a larger one); the value tracks the
// on-disk weight footprint rather than the architecture size.
let exact_params_billions: Option<f64> = exact_weight_bytes.map(|b| b as f64 / 2.0 / 1e9);

let params = model_params_override
.or(exact_params_billions)
.or(estimated_params)
.unwrap_or(7.0); // safe fallback: assume 7B when unknown

let recommendation = recommend_quantization(params, hw.unified_memory_gb, hw);

Expand All @@ -198,10 +225,24 @@ pub fn advise_quantization(
QuantAdvice {
recommendation,
estimated_params_billions: estimated_params,
exact_weight_bytes,
model_uses_bfloat16: uses_bf16,
}
}

/// Render a byte count for display.
///
/// Values of at least one GiB are shown in GiB with two decimals, values of
/// at least one MiB in MiB with one decimal; in both cases the raw byte count
/// follows in parentheses. Anything smaller is shown as the raw count only.
fn format_bytes(bytes: u64) -> String {
    const MIB: u64 = 1 << 20;
    const GIB: u64 = 1 << 30;

    match bytes {
        b if b >= GIB => format!("{:.2} GiB ({bytes} bytes)", b as f64 / GIB as f64),
        b if b >= MIB => format!("{:.1} MiB ({bytes} bytes)", b as f64 / MIB as f64),
        _ => format!("{bytes} bytes"),
    }
}

/// Print a human-readable quantization recommendation to stdout.
pub fn print_quant_advice(advice: &QuantAdvice, hw: &HardwareCapabilities) {
println!();
Expand All @@ -217,7 +258,18 @@ pub fn print_quant_advice(advice: &QuantAdvice, hw: &HardwareCapabilities) {
});
println!(" Memory: {} GB unified", hw.unified_memory_gb);

if let Some(params) = advice.estimated_params_billions {
if let Some(exact_bytes) = advice.exact_weight_bytes {
println!(
" Model size: {} (exact, from safetensors header)",
format_bytes(exact_bytes)
);
if let Some(params) = advice.estimated_params_billions {
println!(
" ~{:.1}B parameters (analytical estimate for reference)",
params
);
}
} else if let Some(params) = advice.estimated_params_billions {
println!(" Model size: ~{:.1}B parameters (estimated)", params);
} else {
println!(" Model size: unknown (could not parse config.json)");
Expand Down Expand Up @@ -350,4 +402,69 @@ mod tests {
}
);
}

/// A model directory containing no files is expected to produce advice whose
/// `exact_weight_bytes` is `None`.
#[test]
fn advise_quantization_exact_bytes_field_is_none_for_empty_dir() {
    use mlxcel_core::hardware::{AppleSiliconGen, HardwareCapabilities};

    // Freshly created temporary directory: nothing is written into it.
    let empty_dir = tempfile::tempdir().unwrap();

    let caps = HardwareCapabilities {
        silicon_gen: AppleSiliconGen::M5,
        gpu_core_count: 10,
        has_neural_accelerator: false,
        metal_version: 4,
        macos_supports_na: false,
        memory_bandwidth_gbps: 100.0,
        unified_memory_gb: 16,
    };

    let result = advise_quantization(empty_dir.path(), &caps, None);
    assert_eq!(result.exact_weight_bytes, None);
}

/// When the model directory holds a `model.safetensors.index.json` with a
/// `metadata.total_size` entry, that number is expected to come back unchanged
/// in `exact_weight_bytes`.
#[test]
fn advise_quantization_uses_exact_bytes_from_index() {
    use mlxcel_core::hardware::{AppleSiliconGen, HardwareCapabilities};

    // Known total_size: 7B parameters at FP16 is ~14 GB = 14_000_000_000 bytes.
    const INDEX_JSON: &str =
        r#"{"metadata": {"total_size": 14000000000}, "weight_map": {"w": "x.safetensors"}}"#;

    let model_dir = tempfile::tempdir().unwrap();
    let index_path = model_dir.path().join("model.safetensors.index.json");
    std::fs::write(index_path, INDEX_JSON.as_bytes()).unwrap();

    let caps = HardwareCapabilities {
        silicon_gen: AppleSiliconGen::M5,
        gpu_core_count: 10,
        has_neural_accelerator: false,
        metal_version: 4,
        macos_supports_na: false,
        memory_bandwidth_gbps: 100.0,
        unified_memory_gb: 16,
    };

    let result = advise_quantization(model_dir.path(), &caps, None);
    assert_eq!(result.exact_weight_bytes, Some(14_000_000_000));
}

/// Inputs of one GiB or more must mention the GiB unit and keep the raw
/// byte count in the rendered text.
#[test]
fn format_bytes_gib() {
    // Exactly 2 GiB, i.e. 2^31 bytes.
    const TWO_GIB: u64 = 1 << 31;

    let s = format_bytes(TWO_GIB);
    assert!(s.contains("GiB"), "expected GiB in: {s}");
    assert!(s.contains("2147483648"), "expected raw bytes in: {s}");
}

/// Inputs between one MiB and one GiB must mention the MiB unit.
#[test]
fn format_bytes_mib() {
    // Exactly 5 MiB.
    const FIVE_MIB: u64 = 5 << 20;

    let s = format_bytes(FIVE_MIB);
    assert!(s.contains("MiB"), "expected MiB in: {s}");
}

/// Inputs below one MiB are rendered as the plain byte count with no unit
/// conversion.
#[test]
fn format_bytes_small() {
    assert_eq!(format_bytes(42), "42 bytes");
}
}
Loading