[CPU - Linux] AVX SIMD backend for fp16 and bf16 matmul#3502

Open
acsweet wants to merge 4 commits into ml-explore:main from acsweet:simd-backend-avx

Conversation


@acsweet acsweet commented May 9, 2026

Proposed changes

This PR adds an AVX SIMD backend for fp16 and bf16 matmul (GEMM and GEMV) on CPU for Linux. It follows from the discussion in #2037 and is a precursor to a follow-up PR adding the full set of AVX SIMD instructions. Let me know what you think; I'd appreciate any feedback (including adjustments to the benchmarking methodology).
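For context on what the fp16/bf16 paths have to do: a bf16 value widens to fp32 by placing its 16 bits in the high half of a 32-bit word (fp16 needs a real conversion, e.g. hardware F16C instructions such as the `_mm256_cvtph_ps` intrinsic on the AVX side). A minimal scalar sketch of the bf16 case in Python, for illustration only and not taken from this PR:

```python
import struct

def bf16_to_f32(bits: int) -> float:
    """Widen a bf16 bit pattern to fp32: bf16 is the top 16 bits of fp32."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]

def f32_to_bf16(x: float) -> int:
    """Narrow fp32 to bf16 by truncation (real kernels typically use
    round-to-nearest-even instead of plain truncation)."""
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16

# 1.0 survives the round trip exactly: its mantissa fits in bf16's 7 bits
assert bf16_to_f32(f32_to_bf16(1.0)) == 1.0
```

A vectorized version of this shift (e.g. `_mm256_slli_epi32` over 8 lanes) is the natural AVX analogue, though the exact intrinsics used in this PR may differ.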

I modified bench_gemm.py and bench_gemv.py in benchmarks/python/blas so they complete in a reasonable amount of time, then ran them with a build of mlx from this PR and against the official mlx release for comparison. Note that I left the other dtypes out of the results printed below due to potential build differences (which could be an error on my part). I built mlx with:

CMAKE_ARGS="-DMLX_BUILD_CPU=ON -DMLX_BUILD_CUDA=OFF -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas" CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .

Bench setup

  • OS: Arch Linux, kernel 6.18.9-arch1-2 x86_64
  • CPU: Intel Core i7-10700K
  • mlx baseline: mlx-cpu==0.31.2
  • torch comparison: torch==2.5.1+cpu
  • benchmark commands:
python benchmarks/python/blas/bench_gemm.py --quick --verbose --single-threaded
python benchmarks/python/blas/bench_gemv.py --quick --verbose --single-threaded

Bench results

GEMM - branch (this PR)

B M N K dtype t torch_gf mlx_gf diff
16 234 768 3072 float16 nn 1.510 93.479 +6091.42%
1 1024 1024 2048 float16 nn 1.103 82.306 +7362.47%
16 234 768 3072 float16 nt 2.319 91.380 +3840.00%
1 1024 1024 2048 float16 nt 2.318 81.883 +3431.84%
16 234 768 3072 float16 tn 4.056 95.101 +2244.59%
1 1024 1024 2048 float16 tn 4.073 83.882 +1959.25%

GEMM - mlx-cpu==0.31.2

B M N K dtype t torch_gf mlx_gf diff
16 234 768 3072 float16 nn 1.623 3.641 +124.27%
1 1024 1024 2048 float16 nn 1.511 3.645 +141.26%
16 234 768 3072 float16 nt 2.320 3.884 +67.40%
1 1024 1024 2048 float16 nt 2.317 3.878 +67.42%
16 234 768 3072 float16 tn 4.067 3.532 -13.15%
1 1024 1024 2048 float16 tn 4.080 3.459 -15.22%
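Reading the GEMM tables: t is the transpose configuration (nn/nt/tn), and torch_gf/mlx_gf are presumably throughput in GFLOP/s. The diff column is consistent with a simple relative comparison against torch (my inference from the numbers, not taken from the bench script):

```python
def diff_pct(mlx_gf: float, torch_gf: float) -> float:
    # Relative difference of mlx throughput vs. torch, matching the
    # sign convention of the tables: positive means mlx is faster.
    return (mlx_gf - torch_gf) / torch_gf * 100.0

# First GEMM row above: torch 1.510, mlx 93.479 -> roughly +6091%
# (the printed +6091.42% was presumably computed from unrounded values)
print(f"{diff_pct(93.479, 1.510):+.2f}%")
```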

GEMV - branch (this PR)

============================================================
gemv | float16 | device: cpu
============================================================
--- sweep out_vec_len (fixed in_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  512, out=  128, mlx=   8.92 GB/s, torch=   1.66 GB/s, diff=+436.2%
  in=  512, out= 1024, mlx=  21.82 GB/s, torch=   1.89 GB/s, diff=+1055.4%
  in=  512, out= 4096, mlx=  26.47 GB/s, torch=   1.88 GB/s, diff=+1306.6%
  in=  512, out=11008, mlx=  16.17 GB/s, torch=   1.63 GB/s, diff=+892.2%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in= 2048, out=  128, mlx=  26.23 GB/s, torch=   1.83 GB/s, diff=+1330.2%
  in= 2048, out= 1024, mlx=  30.29 GB/s, torch=   1.53 GB/s, diff=+1882.1%
  in= 2048, out= 4096, mlx=  20.94 GB/s, torch=   1.85 GB/s, diff=+1033.7%
  in= 2048, out=11008, mlx=  18.57 GB/s, torch=   1.59 GB/s, diff=+1070.8%
--- sweep in_vec_len (fixed out_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out=  512, mlx=  12.20 GB/s, torch=   1.23 GB/s, diff=+895.1%
  in= 1024, out=  512, mlx=  29.48 GB/s, torch=   1.75 GB/s, diff=+1585.6%
  in= 4096, out=  512, mlx=  25.17 GB/s, torch=   1.66 GB/s, diff=+1413.1%
  in=11008, out=  512, mlx=  39.61 GB/s, torch=   1.75 GB/s, diff=+2167.0%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out= 2048, mlx=  20.01 GB/s, torch=   2.21 GB/s, diff=+803.5%
  in= 1024, out= 2048, mlx=  34.55 GB/s, torch=   2.29 GB/s, diff=+1410.6%
  in= 4096, out= 2048, mlx=  16.50 GB/s, torch=   2.08 GB/s, diff=+692.2%
  in=11008, out= 2048, mlx=  19.15 GB/s, torch=   1.91 GB/s, diff=+900.7%


============================================================
gemv_t | float16 | device: cpu
============================================================
--- sweep out_vec_len (fixed in_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  512, out=  128, mlx=   6.97 GB/s, torch=   1.77 GB/s, diff=+294.4%
  in=  512, out= 1024, mlx=  13.19 GB/s, torch=   0.95 GB/s, diff=+1290.3%
  in=  512, out= 4096, mlx=  15.76 GB/s, torch=   0.81 GB/s, diff=+1839.1%
  in=  512, out=11008, mlx=  12.32 GB/s, torch=   0.95 GB/s, diff=+1193.0%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in= 2048, out=  128, mlx=  10.15 GB/s, torch=   0.85 GB/s, diff=+1099.1%
  in= 2048, out= 1024, mlx=  13.90 GB/s, torch=   0.87 GB/s, diff=+1499.0%
  in= 2048, out= 4096, mlx=  12.03 GB/s, torch=   0.55 GB/s, diff=+2090.4%
  in= 2048, out=11008, mlx=  16.71 GB/s, torch=   1.43 GB/s, diff=+1066.2%
--- sweep in_vec_len (fixed out_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out=  512, mlx=  11.21 GB/s, torch=   2.00 GB/s, diff=+460.8%
  in= 1024, out=  512, mlx=  14.76 GB/s, torch=   1.17 GB/s, diff=+1161.1%
  in= 4096, out=  512, mlx=  17.52 GB/s, torch=   1.16 GB/s, diff=+1412.8%
  in=11008, out=  512, mlx=  15.96 GB/s, torch=   1.07 GB/s, diff=+1385.2%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out= 2048, mlx=  16.26 GB/s, torch=   1.27 GB/s, diff=+1180.0%
  in= 1024, out= 2048, mlx=  22.60 GB/s, torch=   1.41 GB/s, diff=+1502.3%
  in= 4096, out= 2048, mlx=  16.50 GB/s, torch=   0.58 GB/s, diff=+2748.3%
  in=11008, out= 2048, mlx=  18.33 GB/s, torch=   0.43 GB/s, diff=+4210.5%

GEMV - mlx-cpu==0.31.2

============================================================
gemv | float16 | device: cpu
============================================================
--- sweep out_vec_len (fixed in_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  512, out=  128, mlx=   1.06 GB/s, torch=   2.13 GB/s, diff=-50.2%
  in=  512, out= 1024, mlx=   1.15 GB/s, torch=   2.23 GB/s, diff=-48.5%
  in=  512, out= 4096, mlx=   1.17 GB/s, torch=   2.24 GB/s, diff=-48.1%
  in=  512, out=11008, mlx=   1.11 GB/s, torch=   2.20 GB/s, diff=-49.6%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in= 2048, out=  128, mlx=   0.94 GB/s, torch=   2.22 GB/s, diff=-57.7%
  in= 2048, out= 1024, mlx=   0.97 GB/s, torch=   2.25 GB/s, diff=-57.0%
  in= 2048, out= 4096, mlx=   0.97 GB/s, torch=   2.20 GB/s, diff=-56.2%
  in= 2048, out=11008, mlx=   0.96 GB/s, torch=   2.20 GB/s, diff=-56.5%
--- sweep in_vec_len (fixed out_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out=  512, mlx=   1.06 GB/s, torch=   2.01 GB/s, diff=-47.3%
  in= 1024, out=  512, mlx=   1.06 GB/s, torch=   2.25 GB/s, diff=-53.0%
  in= 4096, out=  512, mlx=   0.83 GB/s, torch=   2.26 GB/s, diff=-63.5%
  in=11008, out=  512, mlx=   0.58 GB/s, torch=   2.21 GB/s, diff=-73.6%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out= 2048, mlx=   1.18 GB/s, torch=   2.15 GB/s, diff=-45.1%
  in= 1024, out= 2048, mlx=   1.08 GB/s, torch=   2.25 GB/s, diff=-51.8%
  in= 4096, out= 2048, mlx=   0.83 GB/s, torch=   2.25 GB/s, diff=-63.3%
  in=11008, out= 2048, mlx=   0.58 GB/s, torch=   2.20 GB/s, diff=-73.5%

============================================================
gemv_t | float16 | device: cpu
============================================================
--- sweep out_vec_len (fixed in_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  512, out=  128, mlx=   1.01 GB/s, torch=   1.70 GB/s, diff=-40.6%
  in=  512, out= 1024, mlx=   0.90 GB/s, torch=   1.42 GB/s, diff=-36.4%
  in=  512, out= 4096, mlx=   0.92 GB/s, torch=   1.57 GB/s, diff=-41.2%
  in=  512, out=11008, mlx=   0.96 GB/s, torch=   1.90 GB/s, diff=-49.3%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in= 2048, out=  128, mlx=   0.87 GB/s, torch=   1.40 GB/s, diff=-38.1%
  in= 2048, out= 1024, mlx=   0.81 GB/s, torch=   1.42 GB/s, diff=-43.1%
  in= 2048, out= 4096, mlx=   0.56 GB/s, torch=   0.68 GB/s, diff=-18.4%
  in= 2048, out=11008, mlx=   0.78 GB/s, torch=   1.55 GB/s, diff=-49.6%
--- sweep in_vec_len (fixed out_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out=  512, mlx=   1.04 GB/s, torch=   2.09 GB/s, diff=-50.0%
  in= 1024, out=  512, mlx=   0.87 GB/s, torch=   1.44 GB/s, diff=-39.1%
  in= 4096, out=  512, mlx=   0.70 GB/s, torch=   1.42 GB/s, diff=-50.4%
  in=11008, out=  512, mlx=   0.39 GB/s, torch=   1.18 GB/s, diff=-66.7%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out= 2048, mlx=   0.92 GB/s, torch=   1.53 GB/s, diff=-40.0%
  in= 1024, out= 2048, mlx=   0.88 GB/s, torch=   1.57 GB/s, diff=-44.2%
  in= 4096, out= 2048, mlx=   0.49 GB/s, torch=   0.58 GB/s, diff=-14.9%
  in=11008, out= 2048, mlx=   0.31 GB/s, torch=   0.41 GB/s, diff=-25.2%
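For the GEMV tables, GB/s presumably counts the bytes of the matrix and vectors touched per call. A hedged sketch of that arithmetic, under an assumed bandwidth model that may not match the bench script exactly:

```python
def gemv_gb_per_s(in_len: int, out_len: int, seconds: float) -> float:
    # Hypothetical bandwidth model (not from the PR's bench script):
    # an (out x in) fp16 matrix plus the input and output vectors,
    # 2 bytes per element, each touched once per GEMV.
    bytes_moved = 2 * (in_len * out_len + in_len + out_len)
    return bytes_moved / seconds / 1e9

# e.g. in=2048, out=2048 completing in 1 ms -> about 8.4 GB/s under this model
print(gemv_gb_per_s(2048, 2048, 0.001))
```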

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@zcbenz
Collaborator

zcbenz commented May 10, 2026

@dhiltgen I remember ollama was doing something similar. Can you please check whether this would coexist with your work?
