fix: tighten memory estimator preflight coverage#68
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Follow-up audit for epic #52 after reviewing the merged sub-issue work (#53, #54, #55, #56 and PRs #64-#67). The original architecture was in place, but several CLI/runtime paths could still report a memory estimate that did not match the load that would actually happen.
This PR tightens those gaps so the estimator remains conservative and consistent across
inspect,generate --estimate-memory,serve --estimate-memory, and--recommend-quant.Refs #52.
Review Findings Fixed
estimate_total_memory(..., batch, ...)accepted a batch argument but still used a KV-cache helper hardcoded to batch 1.generate --estimate-memoryestimated KV cache from--max-tokensonly, so long prompts were not counted before model load.serve --estimate-memoryignored serving concurrency/batching knobs and sized KV as batch 1.inspect/serve --estimate-memoryrun before runtime initialization, soMLXCEL_MEMORY_LIMITwas not reflected in available-memory calculation.inspect,generate, andservemanually inferred int8 KV only from split K/V flags, missing the legacy--kv-cache-mode int8path and shared validation behavior.head_dimand common field aliases, underestimating models whose head dimension differs fromhidden_size / num_attention_heads.model.safetensors.index.json::metadata.total_sizeeven when the referenced shards were not the files the loader would use.Changes
kv_cache_params_from_path(..., batch)so batch scales KV bytes.MLXCEL_MEMORY_LIMITdirectly in the estimator before checking the already-applied MLX allocator cap.dim,model_dim,n_heads,n_head,num_kv_heads,n_kv_heads,n_head_kv,multi_query_group_num,head_dim, andhead_size.generate --estimate-memorypreflight after prompt rendering/tokenization but before model load, usingprompt_tokens + max_tokensfor KV context length.inspect,generate, andserveuse the sharedresolve_kv_cache_modehelper.serve --estimate-memoryusectx_size,max_kv_size,max_batch_size,n_parallel, andno_batchto derive the preflight shape.weight_footprint_bytestrust index metadata only after validating shard filenames and existence, sum valid shard headers when total size is missing, and fall back to local safetensors headers for stale indexes.Validation
Passed:
cargo fmt --all -- --checkcargo clippy -p mlxcel --lib --bin mlxcel --tests -- -D warningscargo clippy -p mlxcel-core --lib --tests -- -D warningscargo test -p mlxcel memory_estimate --libcargo test -p mlxcel quant_advisor --libcargo test -p mlxcel --bin mlxcelcargo test -p mlxcel-core weights::testsAdditional broader checks:
cargo test -p mlxcel --libran 2,893 tests and failed onlymodels::sanitize_tests::load_and_sanitize_weights_dequantizes_nvfp4_gemma4_checkpoint. I verified the same test fails onorigin/mainwith the same assertion, so it is not introduced by this PR.cargo test -p mlxcel-core --libran 774 tests and failed two existing environment-sensitive FFI tests:ffi_tests::test_from_bytes_fp16_native_dtypeandffi_tests::test_memory_functions. The focusedweights::testssuite covering this PR passed.Security / Performance Notes
total_size; it must match validated local shards or fall back to actual local safetensors headers.