Add NAX Split-K GEMM for large-K matmuls to improve performance #3018
+578
−20
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
Adds NAX Split-K GEMM support to address high runtime variance and slow tail cases in large-K matmals on M5 GPU. For GEMMs with very large K dimensions, partitioning work along K substantially reduces variance and eliminates slow tail cases, achieving up to ~1.6× speedup over the fused NAX GEMM.
Changes
steel_gemm_splitk_naxpartitions K-dimension work across threadgroups, then accumulates partial sumssplitk_gemm_bench.pyto compare fused NAX vs Split-K NAX GEMM performancePerformance on M5
Design Notes
MLX_DISABLE_SPLITK_NAXenv var for benchmarking purposes to compare against regular GEMM. Can be removed before merging.Fixes #3017
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes