Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
b8fd401
block_fp8: mode plumbing + 2D-block-aware validator
yohann-bearzi May 25, 2026
7fb220d
block_fp8: mode plumbing + 2D-block-aware validator + dim math
yohann-bearzi May 25, 2026
d9bf9db
block_fp8: minimal qmv_fast kernel
yohann-bearzi May 25, 2026
1b856df
block_fp8: gather_qmv_fast for MoE decode
yohann-bearzi May 25, 2026
b62be5a
block_fp8: _skip_init plumbing for QuantizedLinear/SwitchLinear
yohann-bearzi May 25, 2026
9b00e38
block_fp8: relax scale-shape validator for fused-QKV padding
yohann-bearzi May 26, 2026
9b6e76c
block_fp8: skip nax routing for block_fp8 mode
yohann-bearzi May 26, 2026
3200bb7
block_fp8: qmm_t kernel + nax/splitk routing for prefill
yohann-bearzi May 26, 2026
849e2d0
debug: MLX_TRACE_KERNELS env var to log kernel dispatches
yohann-bearzi May 26, 2026
53957a4
block_fp8: gather_qmm_t kernel for MoE prefill
yohann-bearzi May 27, 2026
072ced0
block_fp8: tiled qmm_t using BlockFp8QuantizedLoader
yohann-bearzi May 27, 2026
23c18b2
block_fp8: tiled gather_qmm_t and gather_qmm_rhs kernels
yohann-bearzi May 27, 2026
f5cf82c
sdpa: enable fused vector kernel for asymmetric Q/V head_dim (192, 128)
yohann-bearzi May 27, 2026
df26871
block_fp8: qmm_t_splitk kernel for short-M matmul
yohann-bearzi May 28, 2026
6edc5da
block_fp8: implement quantize/dequantize for the real 2D-block F32-sc…
yohann-bearzi May 28, 2026
92cb277
tests: block_fp8 quantize/dequantize round-trip + kernel coverage
yohann-bearzi May 28, 2026
dc76957
block_fp8: per-shard scale-row indexing for TP-interleaved fused qkv
yohann-bearzi May 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,062 changes: 932 additions & 130 deletions mlx/backend/metal/kernels/fp_quantized.h

Large diffs are not rendered by default.

113 changes: 113 additions & 0 deletions mlx/backend/metal/kernels/fp_quantized.metal
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,117 @@
instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)

// ----- block_fp8 (DeepSeek-V3 / MiMo) instantiations -----
// Only qmv_fast is implemented in the first pass. qmv_quad, qmv, qvm, qmm,
// gather variants all produce "kernel not found" until subsequent patches.
#define instantiate_block_fp8_qmv_fast(type) \
instantiate_kernel( \
"block_fp8_qmv_fast_" #type "_gs_128_b_8_batch_0", \
block_fp8_qmv_fast, type, 128, 8, false) \
instantiate_kernel( \
"block_fp8_qmv_fast_" #type "_gs_128_b_8_batch_1", \
block_fp8_qmv_fast, type, 128, 8, true)

instantiate_block_fp8_qmv_fast(float)
instantiate_block_fp8_qmv_fast(bfloat16_t)
instantiate_block_fp8_qmv_fast(float16_t)

#define instantiate_block_fp8_gather_qmv_fast(type) \
instantiate_kernel( \
"block_fp8_gather_qmv_fast_" #type "_gs_128_b_8", \
block_fp8_gather_qmv_fast, type, 128, 8)

instantiate_block_fp8_gather_qmv_fast(float)
instantiate_block_fp8_gather_qmv_fast(bfloat16_t)
instantiate_block_fp8_gather_qmv_fast(float16_t)

// qmm_t: prefill matmul. Per-row qmv math, 8 N rows per threadgroup.
// 4 variants per type: (aligned_N x batched). Dispatcher selects the right
// suffix based on N%32 and batch size.

#define instantiate_block_fp8_qmm_t_one(type, alN, alN_str, batched, batched_str) \
template [[host_name("block_fp8_qmm_t_" #type "_gs_128_b_8_alN_" #alN_str "_batch_" #batched_str)]] \
[[kernel]] void block_fp8_qmm_t<type, 128, 8, alN, batched>( \
const device uint8_t* w, const device float* scales, \
const device type* x, device type* y, \
const constant int& K, const constant int& N, const constant int& M, \
const constant int& scale_rows, \
const constant int& x_batch_ndims, const constant int* x_shape, \
const constant int64_t* x_strides, const constant int& w_batch_ndims, \
const constant int* w_shape, const constant int64_t* w_strides, \
const constant int64_t* s_strides, \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_block_fp8_qmm_t(type) \
instantiate_block_fp8_qmm_t_one(type, true, true, false, 0) \
instantiate_block_fp8_qmm_t_one(type, true, true, true, 1) \
instantiate_block_fp8_qmm_t_one(type, false, false, false, 0) \
instantiate_block_fp8_qmm_t_one(type, false, false, true, 1)

instantiate_block_fp8_qmm_t(float)
instantiate_block_fp8_qmm_t(bfloat16_t)
instantiate_block_fp8_qmm_t(float16_t)

#define instantiate_block_fp8_gather_qmm_t_one(type, alN, alN_str) \
template [[host_name("block_fp8_gather_qmm_t_" #type "_gs_128_b_8_alN_" #alN_str)]] \
[[kernel]] void block_fp8_gather_qmm_t<type, 128, 8, alN>( \
const device uint8_t* w, const device float* scales, \
const device type* x, \
const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, \
device type* y, \
const constant int& K, const constant int& N, const constant int& M, \
const constant int& x_batch_ndims, const constant int* x_shape, \
const constant int64_t* x_strides, const constant int& w_batch_ndims, \
const constant int* w_shape, const constant int64_t* w_strides, \
const constant int64_t* s_strides, const constant int& batch_ndims, \
const constant int* batch_shape, const constant int64_t* lhs_strides, \
const constant int64_t* rhs_strides, \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_block_fp8_gather_qmm_t(type) \
instantiate_block_fp8_gather_qmm_t_one(type, true, true) \
instantiate_block_fp8_gather_qmm_t_one(type, false, false)

instantiate_block_fp8_gather_qmm_t(float)
instantiate_block_fp8_gather_qmm_t(bfloat16_t)
instantiate_block_fp8_gather_qmm_t(float16_t)

#define instantiate_block_fp8_qmm_t_splitk_one(type, alN, alN_str) \
template [[host_name("block_fp8_qmm_t_splitk_" #type "_gs_128_b_8_alN_" #alN_str)]] \
[[kernel]] void block_fp8_qmm_t_splitk<type, 128, 8, alN>( \
const device uint8_t* w [[buffer(0)]], \
const device float* scales [[buffer(1)]], \
const device type* x [[buffer(2)]], \
device type* y [[buffer(3)]], \
const constant int& K [[buffer(4)]], \
const constant int& N [[buffer(5)]], \
const constant int& M [[buffer(6)]], \
const constant int& k_partition_size [[buffer(7)]], \
const constant int& split_k_partition_stride [[buffer(8)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_block_fp8_qmm_t_splitk(type) \
instantiate_block_fp8_qmm_t_splitk_one(type, true, true) \
instantiate_block_fp8_qmm_t_splitk_one(type, false, false)

instantiate_block_fp8_qmm_t_splitk(float)
instantiate_block_fp8_qmm_t_splitk(bfloat16_t)
instantiate_block_fp8_qmm_t_splitk(float16_t)

instantiate_gather_qmm_rhs(block_fp8_gather_qmm_rhs, gather_qmm_rhs_nt, float, 16, 32, 32, 1, 2, true, block_fp8, 128, 8)
instantiate_gather_qmm_rhs(block_fp8_gather_qmm_rhs, gather_qmm_rhs_nn, float, 16, 32, 32, 1, 2, false, block_fp8, 128, 8)
instantiate_gather_qmm_rhs(block_fp8_gather_qmm_rhs, gather_qmm_rhs_nt, bfloat16_t, 16, 32, 32, 1, 2, true, block_fp8, 128, 8)
instantiate_gather_qmm_rhs(block_fp8_gather_qmm_rhs, gather_qmm_rhs_nn, bfloat16_t, 16, 32, 32, 1, 2, false, block_fp8, 128, 8)
instantiate_gather_qmm_rhs(block_fp8_gather_qmm_rhs, gather_qmm_rhs_nt, float16_t, 16, 32, 32, 1, 2, true, block_fp8, 128, 8)
instantiate_gather_qmm_rhs(block_fp8_gather_qmm_rhs, gather_qmm_rhs_nn, float16_t, 16, 32, 32, 1, 2, false, block_fp8, 128, 8)
// clang-format on
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ using namespace metal;
instantiate_sdpa_vector(type, 64, 64) \
instantiate_sdpa_vector(type, 96, 96) \
instantiate_sdpa_vector(type, 128, 128) \
instantiate_sdpa_vector(type, 192, 128) \
instantiate_sdpa_vector(type, 256, 256) \
instantiate_sdpa_vector_aggregation(type, 64) \
instantiate_sdpa_vector_aggregation(type, 96) \
Expand Down
44 changes: 39 additions & 5 deletions mlx/backend/metal/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,19 @@ auto get_quantized_kernel_wrapped(
int bits,
Args... args) {
std::string template_def;
std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func;
std::string fname;
if (mode == "affine") {
fname = "affine_" + func;
} else if (mode == "block_fp8") {
fname = "block_fp8_" + func;
} else {
fname = "fp_" + func;
}
template_def = get_template_definition(
name, fname, type, group_size, bits, std::forward<Args>(args)...);
if (std::getenv("MLX_TRACE_KERNELS")) {
fprintf(stderr, "[mlx-trace] %s\n", name.c_str());
}
return get_quantized_kernel(d, name, template_def, mode);
}

Expand All @@ -44,9 +54,19 @@ auto get_qmm_nax_kernel_wrapped(
int bits,
Args... args) {
std::string template_def;
std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func;
std::string fname;
if (mode == "affine") {
fname = "affine_" + func;
} else if (mode == "block_fp8") {
fname = "block_fp8_" + func;
} else {
fname = "fp_" + func;
}
template_def = get_template_definition(
name, fname, type, group_size, bits, std::forward<Args>(args)...);
if (std::getenv("MLX_TRACE_KERNELS")) {
fprintf(stderr, "[mlx-trace] %s\n", name.c_str());
}
return get_qmm_nax_kernel(d, name, template_def, mode);
}

Expand Down Expand Up @@ -290,6 +310,9 @@ void qmv(
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
if (mode == "block_fp8") {
compute_encoder.set_bytes((int)scales.shape(-2), c++);
}
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);

compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
Expand Down Expand Up @@ -693,7 +716,8 @@ void qmm(
const Stream& s,
const std::string& mode) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
(env::enable_tf32() || x.dtype() != float32) &&
mode != "block_fp8") {
return qmm_nax(
/* const array& x = */ x,
/* const array& w = */ w,
Expand All @@ -720,6 +744,8 @@ void qmm(
MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);

// block_fp8 qmm_t v2 uses the standard tiled geometry (no override needed).

std::string kname;
kname.reserve(64);
bool aligned = N % 32 == 0;
Expand Down Expand Up @@ -766,6 +792,9 @@ void qmm(
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(M, c++);
if (mode == "block_fp8") {
compute_encoder.set_bytes((int)scales.shape(-2), c++);
}
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);

compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
Expand Down Expand Up @@ -827,6 +856,9 @@ void qmm_splitk(
std::string type_string = get_type_string(x.dtype());
std::string kname;
kname.reserve(64);
if (mode == "block_fp8" && (int)scales.shape(-2) != (N + 127) / 128) {
throw std::runtime_error("block_fp8: per-shard-padded scales unsupported on split-k path");
}
concatenate(
kname,
mode + "_qmm_t_splitk_",
Expand Down Expand Up @@ -884,7 +916,8 @@ void gather_qmm(
const Stream& s,
const std::string& mode) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
(env::enable_tf32() || x.dtype() != float32) &&
mode != "block_fp8") {
return gather_qmm_nax(
/* const array& x = */ x,
/* const array& w = */ w,
Expand Down Expand Up @@ -1229,7 +1262,8 @@ void gather_qmm_rhs(
const Stream& s,
const std::string mode) {
if (metal::is_nax_available() && transpose &&
(env::enable_tf32() || x_.dtype() != float32)) {
(env::enable_tf32() || x_.dtype() != float32) &&
mode != "block_fp8") {
return gather_qmm_rhs_nax(
/* const array& x_ = */ x_,
/* const array& w_ = */ w_,
Expand Down
7 changes: 4 additions & 3 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,9 +619,10 @@ bool ScaledDotProductAttention::use_fallback(
const int gqa_factor = num_query_heads / num_kv_heads;

const bool sdpa_vector_supported_head_dim =
query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
query_head_dim == 256);
(query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 ||
query_head_dim == 128 || query_head_dim == 256)) ||
(query_head_dim == 192 && value_head_dim == 128);
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);

Expand Down
Loading