Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/commands/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ fn print_runtime_setup(runtime: &RuntimeSetup) {
max_memory as f64 / (1024.0 * 1024.0 * 1024.0)
);
}
// Issue #55: surface the soft allocator cap when the operator set one
// via MLXCEL_MEMORY_LIMIT, so the preflight intent is visible at boot.
if let Some(memory_limit) = runtime.memory_limit_bytes {
println!(
"MLX allocator memory limit: {:.1} GB (MLXCEL_MEMORY_LIMIT)",
memory_limit as f64 / (1024.0 * 1024.0 * 1024.0)
);
}
}

fn load_generation_model(
Expand All @@ -122,7 +130,29 @@ fn load_generation_model(
load_model(&args.model.model)
}?;
let load_elapsed = load_start.elapsed();
println!("Model loaded in {:.3}s.", load_elapsed.as_secs_f64());
// Issue #55: surface "resident after load" so operators (and the
// capstone preflight #56) can see how much MLX-allocator memory the
// model actually consumed once weight realisation finished. On
// Apple Silicon (Metal) this reads from the Metal allocator; on
// Linux/CUDA from the CUDA allocator; on CPU-only it reads from the
// no-gpu common allocator. Each backend may use a different
// definition of "active", but the number is always whatever MLX
// itself will compare against `memory_limit()` next.
let snap = mlxcel_core::memory::snapshot();
println!(
"Model loaded in {:.3}s (resident: {:.2} GB, peak: {:.2} GB).",
load_elapsed.as_secs_f64(),
snap.active_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
snap.peak_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
);
tracing::info!(
active_bytes = snap.active_bytes,
peak_bytes = snap.peak_bytes,
cache_bytes = snap.cache_bytes,
limit_bytes = snap.limit_bytes,
load_seconds = load_elapsed.as_secs_f64(),
"Model resident after load",
);
Ok(result)
}

Expand Down
39 changes: 39 additions & 0 deletions src/execution/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ use std::fmt;

const RUNTIME_DEVICE_ENV: &str = "MLXCEL_DEVICE";
const WIRED_LIMIT_ENV: &str = "MLXCEL_WIRED_LIMIT";
/// Issue #55: optional soft cap on the MLX allocator. When set, the
/// runtime calls `mlxcel_core::memory::set_memory_limit(...)` at startup
/// so MLX raises an exception once allocations would push the working
/// set past this value, instead of thrashing or OOM-killing the process.
/// Used by the future preflight capstone (#56). Accepts the same syntax
/// as `MLXCEL_WIRED_LIMIT`: plain bytes, `NGB`, or `NMB`. Unset means
/// "do not override MLX's default limit".
const MEMORY_LIMIT_ENV: &str = "MLXCEL_MEMORY_LIMIT";

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuntimeDevice {
Expand Down Expand Up @@ -55,6 +63,10 @@ impl fmt::Display for RuntimeDevice {
pub struct RuntimeSetup {
pub device: RuntimeDevice,
pub wired_limit_bytes: Option<usize>,
/// Soft MLX allocator memory limit applied via `MLXCEL_MEMORY_LIMIT`
/// (issue #55). `None` when the env var was unset or invalid and
/// MLX's default limit is in effect.
pub memory_limit_bytes: Option<usize>,
pub invalid_device_override: Option<String>,
}

Expand All @@ -78,9 +90,16 @@ pub fn initialize_runtime() -> RuntimeSetup {
None
};

// Issue #55: apply optional soft allocator cap regardless of device.
// The MLX no-gpu CPU allocator also honours `set_memory_limit()`, so
// the preflight (#56) can use this on Linux/CI just as on Apple
// Silicon.
let memory_limit_bytes = resolve_memory_limit();

RuntimeSetup {
device,
wired_limit_bytes,
memory_limit_bytes,
invalid_device_override,
}
}
Expand Down Expand Up @@ -118,6 +137,26 @@ fn resolve_wired_limit() -> Option<usize> {
}
}

/// Resolve the MLX allocator soft limit from MLXCEL_MEMORY_LIMIT (issue #55).
///
/// Returns the limit actually applied to MLX, or `None` when the env var
/// is unset / explicitly disabled. This is the hook the capstone preflight
/// (#56) drives when a model is too large to fit comfortably — calling
/// `mlxcel_core::memory::set_memory_limit` makes MLX raise an exception
/// during evaluation instead of thrashing the system allocator.
fn resolve_memory_limit() -> Option<usize> {
let raw = std::env::var(MEMORY_LIMIT_ENV).ok();
let bytes = match raw.as_deref() {
Some("0") | Some("none") | Some("NONE") | None | Some("") => return None,
Some(s) => parse_memory_size(s)?,
};
if bytes == 0 {
return None;
}
mlxcel_core::memory::set_memory_limit(bytes as u64);
Some(bytes)
}

/// Parse a memory size string: plain bytes, "NGB", or "NMB".
fn parse_memory_size(s: &str) -> Option<usize> {
let s = s.trim().to_ascii_uppercase();
Expand Down
33 changes: 33 additions & 0 deletions src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3815,6 +3815,39 @@ size_t get_wired_limit() {
return 0;
}

// MLX runtime memory accounting (issue #55) — thin one-line forwarders to
// the canonical entry points in `mlx/memory.h`. The active allocator
// (Metal / CUDA / no-gpu CommonAllocator) decides what each value means;
// see the header comment in `mlx_cxx_bridge.h` for the cross-backend
// semantics.
size_t get_active_memory() {
return mlx::core::get_active_memory();
}

size_t get_peak_memory() {
return mlx::core::get_peak_memory();
}

size_t get_cache_memory() {
return mlx::core::get_cache_memory();
}

size_t set_memory_limit(size_t limit) {
return mlx::core::set_memory_limit(limit);
}

size_t get_memory_limit() {
return mlx::core::get_memory_limit();
}

size_t set_cache_limit(size_t limit) {
return mlx::core::set_cache_limit(limit);
}

void reset_peak_memory() {
mlx::core::reset_peak_memory();
}

size_t gpu_max_memory_size() {
auto& info = mlx::core::device_info();
// Metal backend uses "max_recommended_working_set_size"
Expand Down
19 changes: 19 additions & 0 deletions src/lib/mlxcel-core/cpp/mlx_cxx_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,25 @@ void set_default_device(bool gpu);
size_t set_wired_limit(size_t limit);
size_t get_wired_limit();

// MLX runtime memory accounting (issue #55).
//
// These wrap `mlx::core::get_active_memory()` / `get_peak_memory()` /
// `get_cache_memory()` / `set_memory_limit()` / `get_memory_limit()` /
// `set_cache_limit()` / `reset_peak_memory()` from `mlx/memory.h`. The
// numbers are populated by whichever allocator is active (Metal, CUDA, or
// the no-gpu common allocator) — see the per-backend implementations in
// `mlx/backend/<metal|cuda|no_gpu>/allocator.cpp`. On the no-gpu CPU
// allocator `get_cache_memory()` / `set_cache_limit()` are inert no-ops
// and return 0 by design; this matches MLX upstream semantics and lets
// the same Rust wrapper compile and run on Linux without panicking.
size_t get_active_memory();
size_t get_peak_memory();
size_t get_cache_memory();
size_t set_memory_limit(size_t limit);
size_t get_memory_limit();
size_t set_cache_limit(size_t limit);
void reset_peak_memory();

// GPU memory info (works across Metal and CUDA backends)
size_t gpu_max_memory_size();

Expand Down
39 changes: 39 additions & 0 deletions src/lib/mlxcel-core/src/ffi_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,45 @@ fn test_memory_functions() {
set_wired_limit(0);
}

#[test]
fn test_runtime_memory_apis_smoke(/* issue #55 */) {
// FFI smoke test: the raw runtime memory APIs (`get_active_memory`,
// `get_peak_memory`, `get_memory_limit`, `set_memory_limit`,
// `reset_peak_memory`) compile, link, and return plausible values on
// every backend mlxcel currently builds for. The typed-wrapper
// module `crate::memory` has the cross-platform / monotonicity
// assertions; this test just guards the raw cxx surface.

// Force at least one allocation against the MLX allocator so the
// counters have something to report.
let arr = from_slice_f32(&[1.0_f32; 1024], &[1024]);
eval(&arr);

// Counters return usize on the cxx boundary.
let _active = get_active_memory();
let _peak = get_peak_memory();
let _cache = get_cache_memory();
let _limit = get_memory_limit();

// `set_memory_limit` must return the previous limit so callers can
// restore it. Round-trip with a huge cap to avoid evicting any live
// arrays held by parallel tests.
let original = get_memory_limit();
let huge: usize = 1usize << 40;
let prev = set_memory_limit(huge);
assert_eq!(
prev, original,
"set_memory_limit should return the previous limit",
);
// Restore.
let _ = set_memory_limit(original);

// `reset_peak_memory` must execute without panicking. We do not
// assert what `get_peak_memory` returns afterwards because parallel
// tests sharing this process keep allocating arrays.
reset_peak_memory();
}

#[test]
fn test_scalar_helpers_preserve_bf16_and_f16_dtype() {
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
Expand Down
36 changes: 36 additions & 0 deletions src/lib/mlxcel-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1560,6 +1560,37 @@ mod ffi {
/// Get current wired memory limit
fn get_wired_limit() -> usize;

// MLX runtime memory accounting (issue #55).
//
// These are raw bridge entry points wired to `mlx::core::get_*_memory`
// and friends in `mlx/memory.h`. The active allocator (Metal /
// CUDA / no-gpu CommonAllocator) decides what each value means.
// Prefer the typed wrappers in `crate::memory` over calling these
// directly — they return `u64` for cross-platform clarity and
// bundle the four most useful counters into a single snapshot.

/// Bytes actively allocated by the MLX allocator (excludes cache).
fn get_active_memory() -> usize;

/// Peak active bytes seen since process start or last reset.
fn get_peak_memory() -> usize;

/// Bytes held in the allocator's free-buffer cache (0 on CPU-only backend).
fn get_cache_memory() -> usize;

/// Set the soft allocator memory limit in bytes. Returns previous limit.
fn set_memory_limit(limit: usize) -> usize;

/// Get the current soft allocator memory limit in bytes.
fn get_memory_limit() -> usize;

/// Set the allocator cache limit in bytes. Returns previous limit.
/// On the no-gpu CPU backend this is a no-op that returns 0.
fn set_cache_limit(limit: usize) -> usize;

/// Reset the recorded peak memory counter to 0.
fn reset_peak_memory();

/// Get max GPU memory size (works across Metal and CUDA backends)
fn gpu_max_memory_size() -> usize;

Expand Down Expand Up @@ -2514,6 +2545,11 @@ pub mod dtype;
// Public so that mlxcel (the main crate) can log hardware info at startup.
pub mod hardware;

// Typed wrappers around MLX's runtime memory accounting APIs (issue #55).
// Public so that the CLI generate path can surface post-load resident
// memory and the preflight (#56) can call `set_memory_limit` to fail fast.
pub mod memory;

// RoPE variants that are not exposed directly by `mlx::core::fast::rope`.
// Currently: proportional RoPE used by Gemma 4 full-attention layers.
pub mod rope_proportional;
Expand Down
Loading