Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4/M5).
O(N) memory instead of O(N²), enabling 100K+ sequence lengths on unified memory.
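To make the O(N) vs O(N²) claim concrete, here is a rough back-of-the-envelope memory calculation for a 100K-token sequence (the shapes match the chunked-attention example later in this README; the flash-side figure ignores the small per-tile workspace):

```python
B, H, N, D = 1, 8, 100_000, 64
bytes_fp16 = 2

# Naive attention materializes the full N x N score matrix per head:
naive_scores = B * H * N * N * bytes_fp16
# Flash attention only keeps Q, K, V, and the output resident (plus small tiles):
flash_tensors = 4 * B * H * N * D * bytes_fp16

print(f"naive scores: {naive_scores / 1e9:.0f} GB")    # ~160 GB
print(f"flash tensors: {flash_tensors / 1e9:.2f} GB")  # ~0.41 GB
```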
Benchmarked on Apple Silicon (M1/M2/M3/M4/M5):
| Seq Length | Speedup vs PyTorch SDPA | Notes |
|---|---|---|
| 1024 | 1.1-2.0x faster | Crossover point |
| 2048 | 1.7-3.7x faster | Sweet spot |
| 4096 | 2.0-3.9x faster | Peak performance |
| 8192+ | 3-4x faster | SDPA often OOMs |
Average speedup: 1.8x across all configurations.
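A minimal sketch of how one such data point can be sanity-checked by hand; the shapes and timing loop are illustrative, not the project's benchmark harness (that lives in `mps_flash_attn.benchmark`, shown further down):

```python
import time
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

def timed(fn, iters=20):
    # Warm up, then average wall-clock time with explicit MPS syncs
    for _ in range(3):
        fn()
    torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.mps.synchronize()
    return (time.perf_counter() - start) / iters

q, k, v = (torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16) for _ in range(3))
t_flash = timed(lambda: flash_attention(q, k, v))
t_sdpa = timed(lambda: F.scaled_dot_product_attention(q, k, v))
print(f"flash: {t_flash * 1e3:.2f} ms  sdpa: {t_sdpa * 1e3:.2f} ms  speedup: {t_sdpa / t_flash:.2f}x")
```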
Install from PyPI:

```bash
pip install mps-flash-attn
```

Or build from source:

```bash
git clone --recursive https://github.com/mpsops/mps-flash-attention.git
cd mps-flash-attention

# Build Swift bridge
cd swift-bridge && swift build -c release && cd ..

# Install (uses the torch already in your environment)
pip install -e . --no-build-isolation
```

To build the full PyPI matrix (5 Python versions, dual torch ABIs), use
scripts/build_all_wheels.sh. The wheels land in dist/. Run
bash tests/wheel_matrix/test_wheel_matrix.sh to verify each wheel against
the torch versions it claims to support.
Basic usage:

```python
import torch
from mps_flash_attn import flash_attention

# (B, H, N, D) format
q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)

out = flash_attention(q, k, v)
```

Causal masking:

```python
out = flash_attention(q, k, v, is_causal=True)
```

Sliding-window (local) attention:

```python
# Only attend to the last 4096 tokens
out = flash_attention(q, k, v, is_causal=True, window_size=4096)
```

FP8-quantized KV:

```python
from mps_flash_attn import flash_attention_fp8, quantize_kv_fp8

# Quantize K/V to FP8
k_quant, k_scale = quantize_kv_fp8(k)
v_quant, v_scale = quantize_kv_fp8(v)

# Run attention with quantized KV
out = flash_attention_fp8(q, k_quant, v_quant, k_scale, v_scale)
```

Chunked attention for very long sequences:

```python
from mps_flash_attn import flash_attention_chunked

# Process 100K tokens without OOM
q = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
k = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
v = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)

out = flash_attention_chunked(q, k, v, chunk_size=8192)
```

Drop-in replacement for PyTorch SDPA:

```python
from mps_flash_attn import replace_sdpa

replace_sdpa()  # Patches F.scaled_dot_product_attention
# Now all PyTorch attention uses Flash Attention on MPS
```

torch.compile() via the registered custom op:

```python
from mps_flash_attn import register_custom_op

register_custom_op()

@torch.compile
def my_attention(q, k, v):
    return torch.ops.mfa.flash_attention(q, k, v, False, None, None)
```

BF16 backward pass:

```python
out = flash_attention(q, k, v, bf16_backward=True)  # 2x faster backward
loss = out.sum()
loss.backward()
```

Benchmarking from the command line:

```bash
# Quick benchmark
python -m mps_flash_attn.benchmark --suite quick

# Full suite with report
python -m mps_flash_attn.benchmark --suite full --output report.html
```

Or from Python:

```python
from mps_flash_attn.benchmark import run_suite, compare_vs_sdpa

results = run_suite(seq_lengths=[1024, 2048, 4096])
compare_vs_sdpa()
```

| Feature | Supported | Notes |
|---|---|---|
| Forward pass | yes | FP16/BF16/FP32 |
| Backward pass | yes | Full gradient support |
| Causal masking | yes | Native kernel support |
| Attention masks | yes | Boolean masks |
| Sliding window | yes | For local attention models |
| GQA/MQA | yes | Grouped-query attention (see the sketch after this table) |
| Quantized KV | yes | FP8, INT8, NF4 |
| Chunked attention | yes | 100K+ tokens |
| torch.compile() | yes | Custom op backend |
| Dropout | no | Not supported |
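The table lists GQA/MQA among supported features; a minimal sketch of grouped-query attention, assuming flash_attention accepts K/V with fewer heads than Q and maps query-head groups onto them (the head counts here are illustrative, not taken from the project's docs):

```python
import torch
from mps_flash_attn import flash_attention

# 32 query heads share 8 KV heads (4 query heads per KV head).
# Assumes the kernel broadcasts query-head groups over the smaller KV head count.
q = torch.randn(1, 32, 2048, 64, device='mps', dtype=torch.float16)
k = torch.randn(1, 8, 2048, 64, device='mps', dtype=torch.float16)
v = torch.randn(1, 8, 2048, 64, device='mps', dtype=torch.float16)

out = flash_attention(q, k, v, is_causal=True)  # expected shape: (1, 32, 2048, 64)
```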
The stack, from the Python API down to the GPU:

    Python API (mps_flash_attn)
            |
    C++ Extension (_C_legacy.so or _C_modern.so, picked at import)
            | dlopen
    Swift Bridge (MFABridge.swift)
            |
    Metal Flash Attention (kernel generation)
            |
    Metal GPU Shaders
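A hedged sketch of the "picked at import" step in the stack above, purely to illustrate the mechanism; the module names come from the diagram, but the actual selection logic and version cutoff inside the package may differ:

```python
import torch

def _load_extension():
    # Illustrative only: choose the compiled extension whose torch ABI matches
    # the installed torch. The (2, 9) cutoff here is an assumption.
    major, minor = (int(p) for p in torch.__version__.split(".")[:2])
    if (major, minor) >= (2, 9):
        from mps_flash_attn import _C_modern as _C
    else:
        from mps_flash_attn import _C_legacy as _C
    return _C
```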
- macOS 14+ (Sonoma, Sequoia, or Tahoe)
- Apple Silicon (M1/M2/M3/M4/M5)
- Python 3.10, 3.11, 3.12, 3.13, or 3.14
- PyTorch 2.5 through 2.11
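A quick environment check against the requirements above (a sketch using only standard torch and stdlib APIs; the package itself may or may not ship an equivalent helper):

```python
import platform
import torch

assert platform.machine() == "arm64", "Apple Silicon (arm64) required"
assert torch.backends.mps.is_available(), "MPS backend unavailable (requires macOS 14+)"

major, minor = (int(p) for p in torch.__version__.split(".")[:2])
assert (2, 5) <= (major, minor) < (2, 12), f"unsupported torch {torch.__version__}"
print(f"OK: torch {torch.__version__} on macOS {platform.mac_ver()[0]} with MPS")
```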
| | torch 2.5 | torch 2.6 | torch 2.7 | torch 2.8 | torch 2.9 | torch 2.10 | torch 2.11 |
|---|---|---|---|---|---|---|---|
| py 3.10 | OK | OK | OK | OK | OK | OK | OK |
| py 3.11 | OK | OK | OK | OK | OK | OK | OK |
| py 3.12 | OK | OK | OK | OK | OK | OK | OK |
| py 3.13 | | OK | OK | OK | OK | OK | OK |
| py 3.14 | | | | | OK | OK | OK |
Empty cells are combinations PyTorch itself does not ship wheels for.
The dependencies field in the wheel pins torch>=2.5,<2.12, so pip install mps-flash-attn will refuse to install on a torch outside this range rather
than installing and crashing at runtime. When a new torch minor is released we
cut a release that bumps the upper bound.
- metal-flash-attention by Philip Turner
- Flash Attention paper by Tri Dao et al.
MIT