[CUDA] Guard qmm_naive scale and bias loads at tile boundaries#3509

Open
Lyxot wants to merge 1 commit intoml-explore:mainfrom
Lyxot:fix/cuda-qmm-boundary

Conversation

Contributor

@Lyxot Lyxot commented May 10, 2026

Fixes #3487

Proposed changes

Fix invalid global reads in the CUDA qmm_naive kernel when the output N dimension only partially fills the current tile. Scale and bias loads are now predicated by the existing N-tile predicate, so out-of-range columns no longer touch global memory.
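The fix can be illustrated with a small Python sketch. The names here (BN, n_tile, load_scales_guarded) are illustrative only, not the actual kernel identifiers; the real kernel expresses the same predicate with CuTe predication over the N tile.

```python
# Minimal sketch of the boundary guard, assuming a tile of BN output
# columns. Lanes whose global column index n falls past N skip the
# load and keep a harmless default instead of reading out of bounds.
def load_scales_guarded(scales, n_tile, BN, N):
    """Load one tile of per-column scales, predicated on n < N."""
    tile = [0.0] * BN
    for i in range(BN):
        n = n_tile * BN + i
        if n < N:  # the N-tile predicate that now also guards scale/bias loads
            tile[i] = scales[n]
    return tile

scales = [0.5]  # N = 1: only one valid column of scales
print(load_scales_guarded(scales, n_tile=0, BN=4, N=1))  # → [0.5, 0.0, 0.0, 0.0]
```

Before the fix, the lanes past N still issued the load, which is exactly the out-of-bounds read compute-sanitizer reports below.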

Reproduce

Using the script repro_qmm_partial_n.py:

import mlx.core as mx

mx.set_default_device(mx.gpu)

M, N, K = 17, 1, 16384
group_size = 64
bits = 4
transpose = True

key = mx.random.key(0)
k1, k2 = mx.random.split(key)

x = (mx.random.normal((M, K), key=k1) / K**0.5).astype(mx.float16)
w = (mx.random.normal((N, K), key=k2) / K**0.5).astype(mx.float16)
w_q, scales, biases = mx.quantize(w, group_size, bits)

print(
    "shapes:",
    f"x={x.shape}",
    f"w_q={w_q.shape}",
    f"scales={scales.shape}",
    f"biases={biases.shape}",
)

y = mx.quantized_matmul(x, w_q, scales, biases, transpose, group_size, bits)
mx.eval(y)
mx.synchronize()

print("completed")
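These shapes are chosen to force a partial N tile. The value BN = 128 below is taken from the kernel's tile shape visible in the sanitizer output (cute::C<(int)128>), so treat it as an assumption about this build's configuration:

```python
# Why N = 1 triggers the bug: the kernel's N tile is 128 columns wide,
# so only 1 of 128 lanes maps to a valid column; the other 127 lanes
# read scales/biases out of bounds unless predicated.
N, K, group_size = 1, 16384, 64
BN = 128  # N-tile width, assumed from the sanitizer log

groups = K // group_size                 # quantization groups per row
print("scales/biases per row:", groups)  # 256, matching the (1, 256) shapes
print("unguarded lanes in last N tile:", BN - (N % BN or BN))  # 127
```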

Run with:

CUDA_LAUNCH_BLOCKING=1 MLX_USE_CUDA_GRAPHS=0 \
/usr/local/cuda/bin/compute-sanitizer --tool memcheck \
--padding 4096 --error-exitcode 99 \
python repro_qmm_partial_n.py

Before:

========= COMPUTE-SANITIZER
========= Invalid __global__ read of size 2 bytes
=========     at void cutlass_gemm::qmm_naive_kernel<(bool)1, (bool)0, (bool)1, cutlass::half_t, cutlass::integer_subbyte<(int)4, (bool)0>, cutlass::half_t, cute::tuple<int, int, int, int>, cute::tuple<cute::C<(int)32>, cute::C<(int)128>, cute::C<(int)64>>, cute::tuple<int, cute::C<(int)1>, int>, cute::tuple<int, cute::C<(int)1>, int>, cute::Layout<cute::tuple<int, cute::tuple<cute::C<(int)64>, int>, int>, cute::tuple<int, cute::tuple<cute::C<(int)0>, cute::C<(int)1>>, int>>, cute::tuple<int, cute::C<(int)1>, int>, cute::TiledMMA<cute::MMA_Atom<cute::SM80_16x8x16_F32F16F16F32_TN>, cute::Layout<cute::tuple<cute::C<(int)2>, cute::C<(int)2>, cute::C<(int)1>>, cute::tuple<cute::C<(int)1>, cute::C<(int)2>, cute::C<(int)0>>>, cute::tuple<cute::C<(int)32>, cute::C<(int)32>, cute::C<(int)16>>>>(T7, T8, const T4 *, T9, const T5 *, T10, const T6 *, const T4 *, T11, const unsigned int *, const unsigned int *, T4 *, T12, T13)+0x10d0
=========     by thread (32,0,0) in block (0,0,0)
=========     Access at 0x8d0130c00 is out of bounds
...
Traceback (most recent call last):
  File "~/mlx/repro_qmm_partial_n.py", line 33, in <module>
    mx.eval(y)
RuntimeError: cudaLaunchKernelExC(&config, func, params) failed: unspecified launch failure
========= Program hit cudaErrorLaunchFailure (error 719) due to "unspecified launch failure" on CUDA API call to cudaStreamSynchronize.
========= 
terminate called after throwing an instance of 'std::runtime_error'
  what():  cudaStreamSynchronize(stream_) failed: unspecified launch failure
========= Error: process didn't terminate successfully
========= Target application returned an error
========= ERROR SUMMARY: 42 errors

After:

========= COMPUTE-SANITIZER
shapes: x=(17, 16384) w_q=(1, 2048) scales=(1, 256) biases=(1, 256)
completed
========= ERROR SUMMARY: 0 errors

Collaborator

@zcbenz zcbenz left a comment

Looks good to me, thanks!

Development

Successfully merging this pull request may close these issues.

[BUG][CUDA] Illegal memory access in QMM when M/N is not multiple of tile size
