Fix excessive memory allocation for static-shape attention ops #2636
Pranaykarvi wants to merge 1 commit into apple:main
Conversation
Hi @TobyRoseman, just a gentle follow-up in case this slipped through.
TobyRoseman left a comment
Your new unit tests don't pass.
| from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target
| from coremltools.converters.mil.mil import Builder as mb
| from coremltools.converters.mil.mil import types
| from coremltools.converters.mil.mil.types.symbolic import is_symbolic
I don't think you need this line.
| # allocation issues as static shapes, so the higher threshold is appropriate.
| logger.debug(
| f"skipping SDPA op, Q seq_length is {q_seq_length} (minimum seq length needed: {self._min_seq_length}"
| f"skipping SDPA op, Q seq_length is dynamic (symbolic), "
This shouldn't be an f-string since there is no variable being used.
Looks like this is also an issue in several other places of this PR.
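For example, the quoted line could simply drop the f prefix (a minimal sketch of the suggested fix; the continuation text of the message is assumed):

```python
# Plain string literals, since nothing is interpolated here.
logger.debug(
    "skipping SDPA op, Q seq_length is dynamic (symbolic), "
    "so slicing is not applied"
)
```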
| "common::remove_symbolic_reshape", | ||
| "common::noop_elimination", | ||
| # Apply attention slicing early to reduce memory allocation for static sequence lengths. | ||
| # This pass replaces scaled_dot_product_attention with a memory-efficient sliced implementation. |
Remove this line of the comment. It doesn't really add much and can easily become outdated/inaccurate.
| Defines the size of the chunks of Q being processed in SDPA (chunk_size = seq_length / seq_length_divider)
| """
|
| # Default threshold for dynamic-shape models. Dynamic shapes use runtime allocation
The added comments in this file are far too long. They need to be much more concise.
| 3. The model can be converted successfully
| """
| # Create a minimal transformer attention block with static seq_len=128
| batch_size = 1
These are constants? If so, the variable names should be all caps.
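For example (a minimal sketch; only batch_size appears in the quoted diff, and SEQ_LEN's value is assumed from the test's static seq_len=128 setup):

```python
# Test constants named in ALL_CAPS.
BATCH_SIZE = 1
SEQ_LEN = 128
```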
| return output
|
| # Apply the default pass pipeline which includes the slicing pass
| from coremltools.converters.mil.mil.passes.pass_pipeline import PassPipeline
Import statements should be at the top of the file. If for some reason they can't be at the top of the file, do them at the top of the function.
Looks like this is also an issue elsewhere.
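For example, the import from the quoted diff could move up next to the other module-level imports (a minimal sketch):

```python
# Top of the test file, alongside the existing coremltools imports.
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil.passes.pass_pipeline import PassPipeline
```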
| f"This indicates the memory allocation fix is not working correctly." | ||
| ) | ||
|
|
||
| # Verify the program structure is correct |
Did you mean to delete this comment?
| """ | ||
| Regression test for memory allocation bug with static sequence length transformers. | ||
|
|
||
| This test verifies that exporting a Llama-style Transformer with a static sequence |
I might be wrong, but I don't think any of this is specific to a Llama-style Transformer.
| # The key verification is that attention ops are sliced and tensor sizes are reasonable
| # which we've already checked above
|
| def test_static_seq_len_128_with_quantization(self):
There is a lot of duplicated code here with the previous method. Please create a helper function.
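One possible shape for that refactor (a sketch only; the helper name, its parameters, and _build_attention_prog / _apply_default_pipeline are hypothetical stand-ins for the duplicated setup code):

```python
def _convert_with_slicing(self, seq_len=128, quantize=False):
    # Hypothetical shared helper: build the static-shape attention program,
    # run the default pass pipeline over it, and return the transformed program
    # so each test method keeps only the assertions that differ.
    prog = self._build_attention_prog(seq_len=seq_len, quantize=quantize)
    self._apply_default_pipeline(prog)
    return prog
```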
Summary
This PR fixes excessive memory allocation for Transformer attention ops when the
sequence length is statically known at compile time (e.g. seq_len=128).

For static-shape attention, Metal may eagerly allocate large intermediate buffers
(e.g. QKᵀ matrices), which can lead to multi-GB allocations and OOM on iOS devices.
The existing attention slicing pass was gated behind a high sequence-length
threshold and did not trigger for smaller static shapes.
This change enables memory-efficient attention slicing for static sequence lengths
while preserving the existing behavior for dynamic-shape models.
Problem
When exporting Transformer models with a statically known sequence length,
scaled_dot_product_attention may materialize large intermediate tensors during lowering. On iOS, this can result in excessive Metal buffer allocation (observed ~10GB) and OOM during inference or benchmarking, even for relatively small models (e.g. Llama-style models with seq_len=128).

Solution

- Applies the scaled_dot_product_attention_sliced_q pass to break the computation into smaller chunks and reduce peak memory usage.
- Preserves the existing default threshold (1280) and behavior for dynamic-shape models to avoid unnecessary overhead.
This approach limits the change to the pathological static-shape case and avoids
global behavior changes.
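As a point of reference, the sliced SDPA pass can already be opted into through the public pass-pipeline API. A minimal sketch (assuming a traced PyTorch model `traced_model` and its `inputs`; the option names follow the quoted diff, and the exact values shown are illustrative):

```python
import coremltools as ct

# Opt into chunked SDPA with a lower threshold so a static seq_len=128 qualifies.
pipeline = ct.PassPipeline.DEFAULT
pipeline.append_pass("common::scaled_dot_product_attention_sliced_q")
pipeline.set_options(
    "common::scaled_dot_product_attention_sliced_q",
    {"min_seq_length": 128, "seq_length_divider": 16},
)

mlmodel = ct.convert(traced_model, inputs=inputs, pass_pipeline=pipeline)
```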
Testing
- Added unit tests that convert a transformer attention block with a static sequence length (seq_len=128).
- The tests verify that attention ops are sliced and intermediate tensor sizes stay bounded, avoiding pathological buffer materialization.
Notes
eager buffer allocation.
Fixes #2590.