Add support for x86 SIMD (AVX2) #3019
Conversation
```cpp
#if !defined(MLX_USE_ACCELERATE)
#if defined(__AVX512F__)
#include "mlx/backend/cpu/simd/avx512_simd.h"
#elif defined(__AVX2__)
#include "mlx/backend/cpu/simd/avx_simd.h"
#elif defined(__SSE4_2__)
#include "mlx/backend/cpu/simd/sse_simd.h"
#endif
#endif
```
I'm wondering if this will break our Linux x86 distribution in some cases. If we build with AVX-512 and someone then tries to run it on a machine which doesn't support AVX-512, it will crash, right?
Actually it looks like just the lowest level is enabled by default. So we should be ok.
@dhiltgen what are you thinking for next steps here? I might suggest we split this out into multiple PRs to make it easier to review and incorporate. The first PR could be the basic SSE backend for x86, which we should definitely integrate. Following that we could add the extra backends (there is also a question of how to test those). We will probably also want a NEON-only backend for Linux ARM (i.e., not through Accelerate).
Splitting this up into smaller chunks sounds like a reasonable approach. I'll probably keep this in draft for a bit, while we focus on full GPU load for best performance.
Sounds good!
I've updated this branch with a more focused implementation targeting just AVX2, fleshed out to provide a real-world performance boost for mlx_lm models running on the CPU.
I think a lot of the changes can be submitted as separate PRs, for example the JIT compiler and allocator changes, which we can merge much faster.
Integrate BufferCache into the CPU allocator to enable memory reuse for CPU-only builds. Previously the no_gpu allocator called malloc/free on every allocation with no caching, while the Metal and CUDA backends had buffer caching for better performance.

Track cached buffers by their physical capacity when they are reused so get_cache_memory(), active memory, and cache limit enforcement continue to reflect retained memory. Add a regression test for reusing a larger cached block for a smaller request.

Changes:
- Add CpuCachedBuffer struct with intrusive freelist for object pooling
- Use BufferCache to recycle freed buffers with a 32MB default cache limit
- Preserve cached block capacity across reuse and avoid caching zero-size allocations
- Implement get_cache_memory(), set_cache_limit(), clear_cache() (were no-ops)
- Cache-first allocation path with fallback to OS malloc on cache miss
Leak the IO ThreadPool singletons and CPU CompilerCache using the same process-lifetime pattern already used by the Scheduler singleton.

The CompilerCache owns dlopen handles for JIT shared libraries. Destroying it during static teardown can dlclose generated code while stream worker threads may still be winding down. The IO loader thread pools have the same shutdown-order risk on Windows CRT teardown. These objects are process-lifetime infrastructure, and the OS reclaims their resources at exit.

Changes:
- Leak CompilerCache so JIT libraries remain mapped through process exit
- Leak IO ThreadPool singletons to avoid teardown-order races
- Clarify the Scheduler singleton comment that documents this pattern
Enable CPU mx.compile() on Windows by detecting and using clang-cl bundled with Visual Studio, or MSVC cl.exe, for JIT compilation. Keep GPU compile availability independent from the CPU compiler probe so CPU+GPU builds do not disable GPU mx.compile() when a host C++ compiler is unavailable.

Changes:
- Add clang-cl detection via vswhere and prefer a compiler matching the build toolchain
- Add JitCompiler::available() to probe CPU JIT availability
- Emit and load .dll JIT libraries on Windows
- Support both MSVC and GCC/Clang preamble generation scripts, including optional SIMD flags
- Use WIN32 shell detection and pass preamble SIMD flags through CMake
- Define NOMINMAX/WIN32_LEAN_AND_MEAN on all WIN32 compilers
Add an AVX2 SIMD backend, CPU thread pool, and vectorized implementations for the major CPU operations used by CPU inference on x86_64. This makes x86 CPU inference practical for small models and substantially improves CPU throughput versus the scalar baseline. Exact speedups depend on model, quantization, prompt/generation mix, BLAS implementation, and CPU power profile, so benchmark details belong in the PR notes rather than the commit message.

SIMD foundation:
- avx_simd.h: Simd<T,8> for float/double/int/float16/bfloat16 with F16C conversion, comparisons, and reductions
- x86_simd_macros.h: comparison predicates and boolean mask operations
- base_simd.h: int64/uint64 additions and x86 conditional includes
- math.h: x86 special functions with Newton-Raphson refinement

Thread pool:
- GCD backend on Apple platforms and persistent std::thread backend on Linux/Windows
- parallel_for with serialized dispatch and per-worker spin-then-sleep wakeup
- Physical-core default thread count, MLX_CPU_THREADS override, and optional OpenBLAS single-thread coordination

Vectorized ops:
- quantized.cpp/quantized_avx2.h: multi-column dequantize+FMA for Q4/Q8
- norms.cpp: RMSNorm and LayerNorm with SIMD parallel reduction
- rope.cpp/rope_avx2.h: AVX2 interleaved sin/cos rotation
- sdpa.cpp: tiled Q*K^T with online softmax, threaded across heads
- compiled.cpp: SIMD codegen plus parallel dispatch
- binary.h, unary.h, copy.cpp, indexing.cpp, reduce.cpp, softmax.cpp, gemms/: SIMD and threading improvements throughout
@zcbenz I've split a few pieces out of this one and rebased it so it's ready for another look. |
Proposed changes
Implement AVX2 SIMD support for better performance on CPU-only x86 systems. Quantized matmul leverages int8 maddubs for 4-bit/8-bit weights, with FP4 and FP8 support. Fast implementations for SDPA, RoPE, norms, softmax, and reduce. Thread pool coordination with OpenBLAS/GCD to utilize all CPU cores. JIT support for CPU SIMD.
Unless stated otherwise, all benchmarks were run with `mlx_lm.benchmark -p 2048 -g 128` (5 trials, averages reported).

Windows 11, AMD Ryzen 9 7950X (Zen 4)
4-bit Quantized
8-bit Quantized
bf16 (Unquantized)
vs Upstream MLX (unoptimized)
Upstream is too slow for p2048/g128, so both sides use p16/g4.

Linux, Intel Core i7-11700K @ 3.60GHz (Rocket Lake)
4-bit Quantized
8-bit Quantized
bf16 (Unquantized)
macOS 26.0, M3 Max (CPU-only build)
Not the focus of this PR, but included to demonstrate a net improvement from the threading addition.
4-bit Quantized
8-bit Quantized
bf16 (Unquantized)
vs Upstream MLX
Shorter settings used.
Checklist
Put an `x` in the boxes that apply.

- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes