@hxu296 hxu296 commented Jan 19, 2026

Summary

Adds NAX Split-K GEMM support to address high runtime variance and slow tail cases in large-K matmuls on the M5 GPU. For GEMMs with very large K dimensions, partitioning work along K substantially reduces variance and removes the slow tail, giving up to ~1.6× speedup over the fused NAX GEMM.
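For readers new to Split-K, here is a minimal pure-Python sketch of the idea: each K slice produces a partial M×N product, and the partials are summed at the end. It is a reference illustration only, not the Metal kernel in this PR; the function names `partial_matmul` and `splitk_matmul` are hypothetical.

```python
def partial_matmul(a, b, k_start, k_end):
    """Multiply a (M x K) by b (K x N), using only k in [k_start, k_end)."""
    m, n = len(a), len(b[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for k in range(k_start, k_end):
            a_ik = a[i][k]
            for j in range(n):
                out[i][j] += a_ik * b[k][j]
    return out


def splitk_matmul(a, b, partition_size):
    """Split the K dimension into slices and accumulate the partial products."""
    k_total = len(b)
    m, n = len(a), len(b[0])
    result = [[0.0] * n for _ in range(m)]
    # On a GPU each slice would be handled by separate threadgroups;
    # here the slices are processed one after another.
    for k_start in range(0, k_total, partition_size):
        k_end = min(k_start + partition_size, k_total)
        partial = partial_matmul(a, b, k_start, k_end)
        for i in range(m):
            for j in range(n):
                result[i][j] += partial[i][j]
    return result


a = [[1.0, 2.0, 3.0, 4.0, 5.0],
     [6.0, 7.0, 8.0, 9.0, 10.0]]
b = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, 3.0], [4.0, 1.0]]
full = partial_matmul(a, b, 0, len(b))   # ordinary matmul over all of K
split = splitk_matmul(a, b, partition_size=2)
```

With small integer-valued inputs the two results agree exactly; with real floating-point data the summation order differs, so results may differ by rounding.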

Changes

  • New kernel steel_gemm_splitk_nax partitions K-dimension work across threadgroups, then accumulates partial sums
  • Dispatch heuristic: batch_size==1 AND M×N >= 2048^2 AND K >= 10240 AND K >= 3×max(M,N) AND NAX available
  • Added a benchmark script splitk_gemm_bench.py to compare fused NAX vs Split-K NAX GEMM performance
  • Refactored the NAX gemm_loop parameters so the loop can be reused across the fused and Split-K paths
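The dispatch heuristic listed above can be read as a single predicate. The sketch below restates it in Python for clarity; the function name `should_use_splitk_nax` and its parameters are hypothetical and do not correspond to the actual dispatch code in this PR.

```python
def should_use_splitk_nax(batch_size, m, n, k, nax_available):
    """Restates the dispatch conditions from the PR description (illustrative)."""
    return (
        nax_available
        and batch_size == 1
        and m * n >= 2048 ** 2     # output must be large enough
        and k >= 10240             # K must be large in absolute terms
        and k >= 3 * max(m, n)     # and large relative to M and N
    )


# Shapes from the benchmark tables satisfy the conditions:
uses_splitk = should_use_splitk_nax(1, 4096, 4096, 12288, nax_available=True)
# A square-ish problem with K < 3 * max(M, N) does not:
stays_fused = should_use_splitk_nax(1, 4096, 4096, 10240, nax_available=True)
```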

Performance on M5

Performance (bfloat16):
    M     N      K     Regular     Split-K     Speedup
----------------------------------------------------------------------
 2048  2048  10240    156.11ms    142.52ms      1.10x
 2048  3072  10240    236.23ms    211.46ms      1.12x
 3072  3072  10240    376.75ms    315.02ms      1.20x
 3072  3072  12288    500.96ms    372.27ms      1.35x
 3072  4096  12288    621.92ms    494.85ms      1.26x
 4096  4096  12288    844.15ms    659.48ms      1.28x
 4096  4096  18432   1472.46ms    984.36ms      1.50x
 4096  4096  21504   1781.44ms   1148.95ms      1.55x
 4096  6144  21504   2807.22ms   1734.34ms      1.62x
 6144  6144  21504   4078.63ms   2635.42ms      1.55x

Performance (float16):
    M     N      K     Regular     Split-K     Speedup
----------------------------------------------------------------------
 2048  2048  10240    155.20ms    147.03ms      1.06x
 2048  3072  10240    226.05ms    217.93ms      1.04x
 3072  3072  10240    378.34ms    325.91ms      1.16x
 3072  3072  12288    469.55ms    386.22ms      1.22x
 3072  4096  12288    618.64ms    521.50ms      1.19x
 4096  4096  12288    904.95ms    702.39ms      1.29x
 4096  4096  18432   1474.75ms   1079.02ms      1.37x
 4096  4096  21504   1865.51ms   1281.52ms      1.46x
 4096  6144  21504   2755.68ms   1965.28ms      1.40x
 6144  6144  21504   4086.48ms   3017.21ms      1.35x

Performance (float32):
    M     N      K     Regular     Split-K     Speedup
----------------------------------------------------------------------
 2048  2048  10240    193.15ms    196.28ms      0.98x
 2048  3072  10240    293.95ms    302.57ms      0.97x
 3072  3072  10240    449.79ms    473.58ms      0.95x
 3072  3072  12288    547.38ms    566.81ms      0.97x
 3072  4096  12288    727.40ms    758.03ms      0.96x
 4096  4096  12288    989.01ms   1014.35ms      0.98x
 4096  4096  18432   1808.03ms   1629.34ms      1.11x
 4096  4096  21504   2692.04ms   1901.86ms      1.42x
 4096  6144  21504   4260.02ms   2867.28ms      1.49x
 6144  6144  21504   6762.67ms   4345.39ms      1.56x

Design Notes

  1. Some dispatch thresholds (namely M×N >= 2048^2 and K >= 10240) and split_k_partition_size (3072) were determined empirically on the M5 GPU. Further tuning will be needed for additional NAX devices as they become available.
  2. Added the MLX_DISABLE_SPLITK_NAX environment variable for benchmarking, so Split-K can be compared against the regular GEMM. It can be removed before merging.
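To give a feel for what a partition size of 3072 means for the benchmarked shapes, the sketch below slices K into contiguous ranges. It assumes a simple ceiling division with a shorter final slice; the kernel in this PR may round or align partitions differently, and `k_partitions` is a hypothetical helper.

```python
def k_partitions(k, partition_size=3072):
    """Return (start, end) ranges covering [0, k), assuming ceil-division slicing."""
    return [
        (start, min(start + partition_size, k))
        for start in range(0, k, partition_size)
    ]


# Partition counts for the K values used in the benchmark tables,
# under the assumption stated above.
counts = {k: len(k_partitions(k)) for k in (10240, 12288, 18432, 21504)}
```

Under this assumption K = 10240 yields three full slices plus a shorter one, while 12288, 18432 and 21504 divide evenly.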

Fixes #3017

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@hxu296 changed the title from "Splitk nax pr" to "Add NAX Split-K GEMM for large-K matrix multiplications" Jan 19, 2026
@hxu296 changed the title from "Add NAX Split-K GEMM for large-K matrix multiplications" to "Add NAX Split-K GEMM for large-K matmuls to improve performance on M5" Jan 19, 2026
@hxu296 changed the title from "Add NAX Split-K GEMM for large-K matmuls to improve performance on M5" to "Add NAX Split-K GEMM for large-K matmuls to improve performance" Jan 19, 2026
Development

Successfully merging this pull request may close these issues.

[Enhancement] Investigate NAX Split-K for large-K GEMM stability on M5