12 changes: 10 additions & 2 deletions mlx/backend/cuda/scaled_dot_product_attention.cpp
@@ -402,7 +402,6 @@ bool ScaledDotProductAttention::use_fallback(
     bool has_mask,
     bool has_arr_mask,
     bool do_causal,
-    bool is_training,
     bool output_logsumexp,
     Stream s) {
   if (s.device == Device::cpu) {
@@ -460,7 +459,16 @@ void ScaledDotProductAttention::eval_gpu(
   }
 }
 
-bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
+bool ScaledDotProductAttentionVJP::use_fallback(
+    const array& q,
+    Stream s,
+    bool has_mask,
+    bool has_sinks,
+    int /* n_kv_heads */) {
+  // Force unfused attention when masks/sinks present
+  if (has_mask || has_sinks) {
+    return true;
[Collaborator review comment on the line above: "This can be removed."]
+  }
   // The frontend adds a padding mask when sequence length is not a multiple of
   // tile size.
   if (q.shape(2) % 128 != 0) {
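For readers skimming the diff, here is a minimal standalone sketch of the VJP fallback rule this hunk introduces. The helper name `sdpa_vjp_needs_fallback` and the test harness are hypothetical; the mask/sink predicate and the 128-wide tile check come straight from the hunk above.

```cpp
#include <cassert>

// Hypothetical mirror of ScaledDotProductAttentionVJP::use_fallback above
// (device/stream handling omitted).
bool sdpa_vjp_needs_fallback(int seq_len, bool has_mask, bool has_sinks) {
  // Masks and attention sinks force the unfused path.
  if (has_mask || has_sinks) {
    return true;
  }
  // The frontend adds a padding mask when the sequence length is not a
  // multiple of the tile size, so non-multiples of 128 also fall back.
  return seq_len % 128 != 0;
}

int main() {
  assert(sdpa_vjp_needs_fallback(300, false, false));  // 300 % 128 == 44
  assert(!sdpa_vjp_needs_fallback(256, false, false)); // exact multiple of 128
  assert(sdpa_vjp_needs_fallback(256, true, false));   // mask forces fallback
  return 0;
}
```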
2 changes: 1 addition & 1 deletion mlx/backend/metal/kernels/CMakeLists.txt
@@ -53,7 +53,7 @@ build_kernel(layer_norm)
 build_kernel(random)
 build_kernel(rms_norm)
 build_kernel(rope)
-build_kernel(scaled_dot_product_attention sdpa_vector.h)
+build_kernel(scaled_dot_product_attention sdpa_vector.h sdpa_vector_vjp.h)
 if(MLX_METAL_VERSION GREATER_EQUAL 320)
   build_kernel(fence)
 endif()
39 changes: 39 additions & 0 deletions mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -3,6 +3,7 @@
 // clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/sdpa_vector.h"
+#include "mlx/backend/metal/kernels/sdpa_vector_vjp.h"
 
 using namespace metal;
 
@@ -41,4 +42,42 @@ using namespace metal;
 instantiate_sdpa_vector_heads(float)
 instantiate_sdpa_vector_heads(bfloat16_t)
 instantiate_sdpa_vector_heads(float16_t)
+
+// SDPA vector VJP instantiations
+#define instantiate_sdpa_vector_vjp(type, qk_dim, value_dim) \
+  instantiate_kernel(                                        \
+      "sdpa_vector_vjp_" #type "_" #qk_dim "_" #value_dim,   \
+      sdpa_vector_vjp,                                       \
+      type,                                                  \
+      qk_dim,                                                \
+      value_dim)
+
+// Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP kernel
+#define instantiate_sdpa_vector_vjp_heads(type) \
+  instantiate_sdpa_vector_vjp(type, 64, 64)     \
+  instantiate_sdpa_vector_vjp(type, 96, 96)     \
+  instantiate_sdpa_vector_vjp(type, 128, 128)
+
+instantiate_sdpa_vector_vjp_heads(float)
+instantiate_sdpa_vector_vjp_heads(bfloat16_t)
+instantiate_sdpa_vector_vjp_heads(float16_t)
+
+// SDPA vector VJP accumulate instantiations (for half/bfloat16 with float32 accumulators)
+#define instantiate_sdpa_vector_vjp_accumulate(type, qk_dim, value_dim) \
+  instantiate_kernel(                                                   \
+      "sdpa_vector_vjp_accumulate_" #type "_" #qk_dim "_" #value_dim,   \
+      sdpa_vector_vjp_accumulate,                                       \
+      type,                                                             \
+      qk_dim,                                                           \
+      value_dim)
+
+// Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP kernel
+#define instantiate_sdpa_vector_vjp_accumulate_heads(type) \
+  instantiate_sdpa_vector_vjp_accumulate(type, 64, 64)     \
+  instantiate_sdpa_vector_vjp_accumulate(type, 96, 96)     \
+  instantiate_sdpa_vector_vjp_accumulate(type, 128, 128)
+
+// Note: Only instantiate for half/bfloat16 since float32 doesn't need accumulate variant
+instantiate_sdpa_vector_vjp_accumulate_heads(bfloat16_t)
+instantiate_sdpa_vector_vjp_accumulate_heads(float16_t)
 // clang-format on
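The D=256 note can be sanity-checked with quick arithmetic. This is only a rough sketch: it assumes the VJP kernel stages one BN=32 block of D float32 values in threadgroup memory, which is a guess at the kernel's actual layout, but it shows why D=128 fits while D=256 cannot.

```cpp
#include <cstdio>

int main() {
  const int BN = 32;                 // assumed rows staged per threadgroup
  const int kLimitBytes = 32 * 1024; // the 32KB Metal limit cited above
  for (int D : {64, 96, 128, 256}) {
    // One float32 staging tile of shape BN x D (layout is an assumption).
    int bytes = BN * D * static_cast<int>(sizeof(float));
    std::printf("D=%3d -> %5d / %d bytes (%s)\n", D, bytes, kLimitBytes,
                bytes < kLimitBytes ? "fits" : "saturates the limit");
  }
  return 0;
}
```

At D=256 a single such tile is already 32768 bytes, the entire allotment, leaving nothing for the max/sum scratch buffers, which is consistent with capping the instantiations at D=128.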
41 changes: 25 additions & 16 deletions mlx/backend/metal/kernels/sdpa_vector.h
@@ -80,8 +80,10 @@ template <typename T, int D, int V = D>
   out += o_offset * V + simd_gid * v_per_thread;
 
   // Read the query and 0 the output accumulator
+  // Scale by M_LOG2E_F to match STEEL attention domain (exp2 instead of exp)
+  const U log2e_scale = static_cast<U>(scale * M_LOG2E_F);
   for (int i = 0; i < qk_per_thread; i++) {
-    q[i] = static_cast<U>(scale) * queries[i];
+    q[i] = log2e_scale * queries[i];
   }
   for (int i = 0; i < v_per_thread; i++) {
     o[i] = 0;

@@ -90,7 +92,9 @@ template <typename T, int D, int V = D>
   U max_score = Limits<U>::finite_min;
   U sum_exp_score = 0;
   if (has_sinks && simd_gid == 0) {
-    max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
+    // Scale sink by M_LOG2E_F to match log2 domain
+    max_score = static_cast<U>(M_LOG2E_F) *
+        static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
     sum_exp_score = 1;
   }
 
@@ -117,13 +121,14 @@ template <typename T, int D, int V = D>
     }
     score = simd_sum(score);
     if (float_mask) {
-      score += static_cast<U>(fmask[0]);
+      // Scale float mask by M_LOG2E_F to match log2 domain
+      score += static_cast<U>(M_LOG2E_F) * static_cast<U>(fmask[0]);
     }
 
-    // Update the accumulators
+    // Update the accumulators (using exp2 to match STEEL attention)
     U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+    U factor = fast::exp2(max_score - new_max);
+    U exp_score = fast::exp2(score - new_max);
 
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;

@@ -155,7 +160,7 @@ template <typename T, int D, int V = D>
   threadgroup_barrier(mem_flags::mem_threadgroup);
   max_score = max_scores[simd_lid];
   U new_max = simd_max(max_score);
-  U factor = fast::exp(max_score - new_max);
+  U factor = fast::exp2(max_score - new_max);
   sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
 
   // Now we need to aggregate all the outputs

@@ -252,8 +257,10 @@ template <typename T, int D, int V = D>
   maxs += o_offset * blocks + block_idx;
 
   // Read the query and 0 the output accumulator
+  // Scale by M_LOG2E_F to match STEEL attention domain (exp2 instead of exp)
+  const U log2e_scale = static_cast<U>(scale * M_LOG2E_F);
   for (int i = 0; i < qk_per_thread; i++) {
-    q[i] = static_cast<U>(scale) * queries[i];
+    q[i] = log2e_scale * queries[i];
   }
   for (int i = 0; i < v_per_thread; i++) {
     o[i] = 0;

@@ -263,7 +270,8 @@ template <typename T, int D, int V = D>
   U sum_exp_score = 0;
   if (has_sinks && block_idx == 0 && simd_gid == 0) {
     int q_head_idx = q_batch_head_idx % num_q_heads;
-    max_score = static_cast<U>(sinks[q_head_idx]);
+    // Scale sink by M_LOG2E_F to match log2 domain
+    max_score = static_cast<U>(M_LOG2E_F) * static_cast<U>(sinks[q_head_idx]);
     sum_exp_score = 1;
   }
 
@@ -291,13 +299,14 @@ template <typename T, int D, int V = D>
     score = simd_sum(score);
 
     if (float_mask) {
-      score += fmask[0];
+      // Scale float mask by M_LOG2E_F to match log2 domain
+      score += static_cast<U>(M_LOG2E_F) * static_cast<U>(fmask[0]);
     }
 
-    // Update the accumulators
+    // Update the accumulators (using exp2 to match STEEL attention)
     U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+    U factor = fast::exp2(max_score - new_max);
+    U exp_score = fast::exp2(score - new_max);
 
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;

@@ -329,7 +338,7 @@ template <typename T, int D, int V = D>
   threadgroup_barrier(mem_flags::mem_threadgroup);
   max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
   U new_max = simd_max(max_score);
-  U factor = fast::exp(max_score - new_max);
+  U factor = fast::exp2(max_score - new_max);
   sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
   sum_exp_score = simd_sum(sum_exp_score * factor);
 
@@ -342,7 +351,7 @@ template <typename T, int D, int V = D>
   // Now we need to aggregate all the outputs
   for (int i = 0; i < v_per_thread; i++) {
     outputs[simd_lid * BN + simd_gid] =
-        o[i] * fast::exp(max_scores[simd_gid] - new_max);
+        o[i] * fast::exp2(max_scores[simd_gid] - new_max);
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // And write the output
@@ -390,7 +399,7 @@ template <typename T, int D>
   // First everybody reads the max and sum_exp
   U max_score = maxs[simd_lid];
   U new_max = simd_max(max_score);
-  U factor = fast::exp(max_score - new_max);
+  U factor = fast::exp2(max_score - new_max);
   U sum_exp_score = simd_sum(sums[simd_lid] * factor);
 
   // Now read the block into registers and then use shared memory to transpose
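The common thread in these hunks: the streaming softmax moves from the natural-exp domain to base 2. Since exp(x) = exp2(x * log2(e)), folding M_LOG2E_F into the query scale (and into the sinks and float masks, which are added to scores rather than multiplied in) lets every fast::exp become the cheaper fast::exp2 without changing the result, matching the domain used by the STEEL attention kernels. Below is a minimal host-side sketch of that equivalence with made-up scores; the kernel names, simdgroup reductions, and tiling are elided.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Streaming (online) softmax denominator in a chosen base. exp_fn is either
// std::exp applied to raw scores, or std::exp2 applied to scores pre-scaled
// by log2(e) -- the two must agree.
template <typename F>
float stream_softmax_denom(const std::vector<float>& scores, F exp_fn) {
  float max_score = -INFINITY, sum = 0.0f;
  for (float s : scores) {
    // Same update as the kernel: rescale the running sum to the new max.
    float new_max = std::max(max_score, s);
    sum = sum * exp_fn(max_score - new_max) + exp_fn(s - new_max);
    max_score = new_max;
  }
  return exp_fn(max_score) * sum; // un-normalized total mass
}

int main() {
  const float kLog2e = 1.4426950408889634f; // M_LOG2E_F
  std::vector<float> raw = {0.3f, -1.2f, 2.5f, 0.9f}; // made-up scores
  std::vector<float> base2 = raw;
  for (float& s : base2) s *= kLog2e; // the scaling the kernel applies to q
  float a = stream_softmax_denom(raw, [](float x) { return std::exp(x); });
  float b = stream_softmax_denom(base2, [](float x) { return std::exp2(x); });
  assert(std::fabs(a - b) < 1e-4f * a); // identical up to rounding
  return 0;
}
```

The sink initialization fits the same scheme: seeding max_score = M_LOG2E_F * sink with sum_exp_score = 1 is exactly the running state after folding one score equal to the sink into the base-2 accumulator.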