Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 35 additions & 31 deletions src/commands/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,7 @@ fn log_estimate_vs_actual_delta(est: &MemoryEstimate, snap: &mlxcel_core::memory

let est_bytes = est.total_bytes;
let actual = snap.active_bytes;
let (delta_label, delta_bytes) = if actual >= est_bytes {
("over-estimated by", actual.saturating_sub(est_bytes))
} else {
("under-estimated by", est_bytes.saturating_sub(actual))
};
let (delta_label, delta_bytes) = estimate_delta_label_and_bytes(est_bytes, actual);
let ratio = if est_bytes > 0 {
actual as f64 / est_bytes as f64
} else {
Expand All @@ -239,6 +235,19 @@ fn log_estimate_vs_actual_delta(est: &MemoryEstimate, snap: &mlxcel_core::memory
);
}

/// Describe how far a memory estimate landed from the measured usage.
///
/// Returns a `(label, gap)` pair where `gap` is the absolute distance in
/// bytes between `estimate` and `actual`. The label is
/// `"under-estimated by"` whenever `actual` is at least `estimate` (so an
/// exact match reports an under-estimate of 0 bytes) and
/// `"over-estimated by"` when `actual` is strictly smaller.
fn estimate_delta_label_and_bytes(estimate: u64, actual: u64) -> (&'static str, u64) {
    let label = if actual < estimate {
        "over-estimated by"
    } else {
        "under-estimated by"
    };
    (label, estimate.abs_diff(actual))
}

/// Compute the context length used to size the KV cache in the memory
/// preflight.
///
/// The result is `prompt_tokens + max_tokens`, where the addition saturates
/// at `usize::MAX`. The value is never 0 (an empty request yields 1) and is
/// clamped to `u64::MAX` if the sum does not fit in a `u64`.
fn memory_preflight_ctx_len(prompt_tokens: usize, max_tokens: usize) -> u64 {
    let budget = prompt_tokens.saturating_add(max_tokens);
    match u64::try_from(budget) {
        // Lift zero to one so the returned length is always at least 1.
        Ok(tokens) => tokens.max(1),
        Err(_) => u64::MAX,
    }
}

/// Run the `--estimate-memory` preflight for `mlxcel generate`.
///
/// Returns `Some(estimate)` when the user passed `--estimate-memory`
Expand All @@ -251,28 +260,27 @@ fn log_estimate_vs_actual_delta(est: &MemoryEstimate, snap: &mlxcel_core::memory
/// figure and the override flags. Always prints the formatted
/// breakdown before aborting so operators can see the same byte
/// table `mlxcel inspect` would have shown.
fn run_memory_preflight(args: &GenerateArgs) -> Result<Option<MemoryEstimate>> {
fn run_memory_preflight(
args: &GenerateArgs,
prompt_token_count: usize,
) -> Result<Option<MemoryEstimate>> {
if !args.generation.estimate_memory {
return Ok(None);
}

// Derive int8 KV from the existing --cache-type-k / --cache-type-v
// pair so the preflight reflects what the loaded cache will
// actually allocate. Mixed-precision configurations fall back to
// FP16 sizing because the size formula does not model them
// directly — surfaced in the printed breakdown.
let kv_int8 = matches!(
(
args.generation.turbo.cache_type_k.as_deref(),
args.generation.turbo.cache_type_v.as_deref(),
),
(Some("int8"), Some("int8")) | (Some("i8"), Some("i8"))
);
let kv_cache_mode = resolve_kv_cache_mode(
args.generation.turbo.cache_type_k.as_deref(),
args.generation.turbo.cache_type_v.as_deref(),
args.generation.turbo.kv_cache_mode.as_deref(),
)
.map_err(|e| anyhow::anyhow!("{}", e))?;
let kv_int8 = matches!(kv_cache_mode, KVCacheMode::Int8);

// Use the user's `--max-tokens` as the KV ctx_len input. This
// matches the way `mlxcel inspect --max-tokens N` sizes the KV
// estimate, so the preflight and the inspect view never disagree.
let ctx_len = args.generation.max_tokens.max(1) as u64;
// Size the KV cache for the tokens that can actually enter the cache:
// rendered prompt tokens plus the requested decode budget. This still runs
// before model load, but after tokenizer/template processing has made the
// prompt length knowable.
let ctx_len = memory_preflight_ctx_len(prompt_token_count, args.generation.max_tokens);

let estimate =
estimate_total_memory(&args.model.model, ctx_len, 1, QuantHint::Default, kv_int8);
Expand Down Expand Up @@ -1052,15 +1060,6 @@ pub(crate) fn run_generate(args: GenerateArgs) -> Result<()> {
}
}

// Memory preflight (issue #56). Runs the unified estimator and
// aborts when total > available. Skipped when --estimate-memory
// was not passed. --force / --no-memory-check downgrades the
// abort to a warning. Sub-issue C's `MLXCEL_MEMORY_LIMIT` env hook
// is honoured transparently by the estimator's
// `resolve_available_memory` step (MLX allocator soft cap wins
// over OS RAM when nonzero).
let preflight_estimate = run_memory_preflight(&args)?;

let pipeline_requested = cli_pipeline_requested(&args);
let tokenizer = load_tokenizer(&args.model.model)?;
let prompt = load_cli_prompt(
Expand All @@ -1071,6 +1070,11 @@ pub(crate) fn run_generate(args: GenerateArgs) -> Result<()> {
);
let mut prompt_tokens = tokenize_prompt(&tokenizer, &prompt)?;

// Memory preflight (issue #56). Runs after prompt rendering/tokenization so
// long prompts are included in the KV-cache budget, but still before the
// model weights are loaded.
let preflight_estimate = run_memory_preflight(&args, prompt_tokens.len())?;

let sampling_config =
build_cli_sampling_config(&args, mlxcel::read_eos_token_ids(&args.model.model));

Expand Down
25 changes: 22 additions & 3 deletions src/commands/generate_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
// limitations under the License.

use super::{
apply_user_chat_template, cli_pipeline_requested, generated_suffix,
generation_stats_from_duration, resolve_cli_pipeline_assignments, resolve_cli_prompt,
validate_pipeline_parallel_args, validate_tensor_parallel_args,
apply_user_chat_template, cli_pipeline_requested, estimate_delta_label_and_bytes,
generated_suffix, generation_stats_from_duration, memory_preflight_ctx_len,
resolve_cli_pipeline_assignments, resolve_cli_prompt, validate_pipeline_parallel_args,
validate_tensor_parallel_args,
};
use mlxcel::server::chat_template::ChatTemplateProcessor;
use std::fs;
Expand All @@ -40,6 +41,24 @@ fn generation_stats_from_duration_handles_zero_elapsed_time() {
assert_eq!(stats.decode_tok_per_sec, 0.0);
}

#[test]
fn estimate_delta_labels_match_actual_direction() {
    // Each row is (estimate, actual, expected label); the gap is 25 bytes in
    // both directions.
    let cases = [
        (100, 125, "under-estimated by"),
        (125, 100, "over-estimated by"),
    ];
    for (estimate, actual, label) in cases {
        assert_eq!(
            estimate_delta_label_and_bytes(estimate, actual),
            (label, 25)
        );
    }
}

#[test]
fn memory_preflight_ctx_len_includes_prompt_and_generation_budget() {
    // Prompt tokens and the decode budget are added together.
    let with_prompt = memory_preflight_ctx_len(4096, 128);
    assert_eq!(with_prompt, 4224);

    // An empty prompt with a zero budget still yields a length of one.
    let empty = memory_preflight_ctx_len(0, 0);
    assert_eq!(empty, 1);
}

#[test]
fn apply_user_chat_template_wraps_prompt_as_user_message() {
let processor = ChatTemplateProcessor::with_template(
Expand Down
21 changes: 9 additions & 12 deletions src/commands/inspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@

use anyhow::{Result, anyhow};

use mlxcel::cli::turbo_args::resolve_kv_cache_mode;
use mlxcel::memory_estimate::{QuantHint, estimate_total_memory, format_estimate};
use mlxcel_core::cache::KVCacheMode;

use crate::InspectArgs;

Expand All @@ -40,18 +42,13 @@ pub(crate) fn run_inspect(args: InspectArgs) -> Result<()> {
// Translate the user-facing `--quant` label into the typed hint.
let quant = parse_quant_hint(&args.quant)?;

// Translate the K/V cache flag pair into the int8/fp16 decision the
// estimator understands. Both flags must point at int8 for KV
// bytes to halve; any other combination is treated as fp16 (the
// default) since mixed-precision KV is not directly modelled in
// the size formula. Surface the consequence in the printed output.
let kv_int8 = matches!(
(
args.turbo.cache_type_k.as_deref(),
args.turbo.cache_type_v.as_deref(),
),
(Some("int8"), Some("int8")) | (Some("i8"), Some("i8"))
);
let kv_cache_mode = resolve_kv_cache_mode(
args.turbo.cache_type_k.as_deref(),
args.turbo.cache_type_v.as_deref(),
args.turbo.kv_cache_mode.as_deref(),
)
.map_err(|e| anyhow!("{}", e))?;
let kv_int8 = matches!(kv_cache_mode, KVCacheMode::Int8);

let estimate = estimate_total_memory(&args.model, args.max_tokens, args.batch, quant, kv_int8);

Expand Down
51 changes: 35 additions & 16 deletions src/commands/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
//! on schema and routing.

use mlxcel::cli::speculative_args::{env_fallback_draft_block_size, env_fallback_draft_kind};
use mlxcel::cli::turbo_args::resolve_kv_cache_mode;
use mlxcel::memory_estimate::{QuantHint, estimate_total_memory, format_bytes, format_estimate};
use mlxcel::server::{
ServerStartupInput, env_fallback_apc_block_size, env_fallback_apc_enabled,
Expand All @@ -31,6 +32,7 @@ use mlxcel::server::{
env_fallback_prompt_cache_max_entries, env_fallback_prompt_cache_min_prefix,
env_fallback_prompt_cache_ttl, env_fallback_reasoning_budget, start_server,
};
use mlxcel_core::cache::KVCacheMode;

/// Run the `mlxcel serve` subcommand.
#[tokio::main]
Expand All @@ -57,24 +59,18 @@ fn run_serve_memory_preflight(args: &crate::ServeArgs) -> anyhow::Result<()> {
return Ok(());
}

let kv_int8 = matches!(
(
args.turbo.cache_type_k.as_deref(),
args.turbo.cache_type_v.as_deref(),
),
(Some("int8"), Some("int8")) | (Some("i8"), Some("i8"))
);
let kv_cache_mode = resolve_kv_cache_mode(
args.turbo.cache_type_k.as_deref(),
args.turbo.cache_type_v.as_deref(),
args.turbo.kv_cache_mode.as_deref(),
)
.map_err(|e| anyhow::anyhow!("{}", e))?;
let kv_int8 = matches!(kv_cache_mode, KVCacheMode::Int8);

// `--ctx-size 0` is the "use model default" sentinel; in that
// case we fall back to 8192 to match the historical sizing used
// by `--recommend-quant`.
let ctx_len = if args.ctx_size > 0 {
args.ctx_size as u64
} else {
mlxcel::memory_estimate::DEFAULT_CTX_LEN
};
let ctx_len = serve_preflight_ctx_len(args);
let batch = serve_preflight_batch(args);

let estimate = estimate_total_memory(&args.model, ctx_len, 1, QuantHint::Default, kv_int8);
let estimate = estimate_total_memory(&args.model, ctx_len, batch, QuantHint::Default, kv_int8);

let banner = format_estimate(&args.model, &estimate);
println!("{banner}");
Expand All @@ -101,6 +97,29 @@ fn run_serve_memory_preflight(args: &crate::ServeArgs) -> anyhow::Result<()> {
Ok(())
}

/// Pick the context length the `serve` memory preflight sizes the KV cache
/// for.
///
/// A positive `--ctx-size` is used as given; otherwise (0 is the "use model
/// default" sentinel) the estimator's `DEFAULT_CTX_LEN` is used, matching the
/// historical sizing of `--recommend-quant`. A positive `--max-kv-size` then
/// caps that length. The returned value is never 0.
fn serve_preflight_ctx_len(args: &crate::ServeArgs) -> u64 {
    let requested = if args.ctx_size > 0 {
        args.ctx_size as u64
    } else {
        mlxcel::memory_estimate::DEFAULT_CTX_LEN
    };

    // `--max-kv-size` only applies when it is set to a positive value.
    let capped = if args.max_kv_size > 0 {
        requested.min(args.max_kv_size as u64)
    } else {
        requested
    };

    capped.max(1)
}

/// Pick the batch size (number of sequences) the `serve` memory preflight
/// passes to the estimator.
///
/// With `no_batch` set the answer is 1. Otherwise it is `max_batch_size`
/// when present, falling back to `n_parallel`, raised to a minimum of 1 and
/// clamped to `u64::MAX` if the value does not convert to `u64`.
fn serve_preflight_batch(args: &crate::ServeArgs) -> u64 {
    let sequences = if args.no_batch {
        1
    } else {
        args.max_batch_size.unwrap_or(args.n_parallel).max(1)
    };
    u64::try_from(sequences).unwrap_or(u64::MAX)
}

fn build_startup_input(mut args: crate::ServeArgs) -> anyhow::Result<ServerStartupInput> {
// Translate `--turbo-boundary-v` into the `MLXCEL_KV_BOUNDARY_V_LAYERS`
// env var before any caller of `mlxcel-core` constructs a cache.
Expand Down
32 changes: 31 additions & 1 deletion src/commands/serve_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use std::path::PathBuf;

use super::build_startup_input;
use super::{build_startup_input, serve_preflight_batch, serve_preflight_ctx_len};

fn sample_args() -> crate::ServeArgs {
crate::ServeArgs {
Expand Down Expand Up @@ -152,6 +152,36 @@ fn build_startup_input_preserves_edge_flags_for_normalization() {
assert_eq!(input.decode_storage_backend, None);
}

#[test]
fn serve_preflight_batch_uses_max_batch_size_when_batching_enabled() {
    // The fixture is used unmodified. The expected 4 is presumably its
    // `max_batch_size`; confirm against `sample_args`, whose body is not
    // visible here.
    let batch = serve_preflight_batch(&sample_args());
    assert_eq!(batch, 4);
}

#[test]
fn serve_preflight_batch_falls_back_to_parallelism_and_honors_no_batch() {
    let mut args = sample_args();

    // Without an explicit batch size the helper falls back to `n_parallel`.
    // The expected 3 is presumably the fixture's `n_parallel`; confirm
    // against `sample_args`.
    args.max_batch_size = None;
    let fallback = serve_preflight_batch(&args);
    assert_eq!(fallback, 3);

    // Disabling batching pins the result to a single sequence.
    args.no_batch = true;
    let unbatched = serve_preflight_batch(&args);
    assert_eq!(unbatched, 1);
}

#[test]
fn serve_preflight_ctx_len_uses_default_and_max_kv_cap() {
    let mut args = sample_args();

    // `ctx_size == 0` selects the estimator's default context length.
    // NOTE(review): this relies on the fixture's `max_kv_size` not capping
    // below the default; confirm against `sample_args`.
    args.ctx_size = 0;
    let defaulted = serve_preflight_ctx_len(&args);
    assert_eq!(defaulted, mlxcel::memory_estimate::DEFAULT_CTX_LEN);

    // A positive `max_kv_size` smaller than `ctx_size` caps the result.
    args.ctx_size = 8192;
    args.max_kv_size = 2048;
    let capped = serve_preflight_ctx_len(&args);
    assert_eq!(capped, 2048);
}

#[test]
fn build_startup_input_propagates_decode_storage_backend() {
let mut args = sample_args();
Expand Down
Loading