36 changes: 36 additions & 0 deletions custom_ops/gpu_ops/append_attention.cu
@@ -42,6 +42,9 @@ void AppendAttentionKernel(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
paddle::Tensor& tmp_workspace,
paddle::Tensor& tmp_m,
paddle::Tensor& tmp_d,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
@@ -137,6 +140,9 @@ void AppendAttentionKernel(
qkv_out,
key_cache,
value_cache,
tmp_workspace,
tmp_m,
tmp_d,
attn_mask,
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales
: cache_k_dequant_scales,
@@ -446,6 +452,9 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
paddle::Tensor& tmp_workspace,
paddle::Tensor& tmp_m,
paddle::Tensor& tmp_d,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
@@ -579,6 +588,9 @@ std::vector<paddle::Tensor> AppendAttention(
qkv,
key_cache,
value_cache,
tmp_workspace,
tmp_m,
tmp_d,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
@@ -655,6 +667,9 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
paddle::Tensor& tmp_workspace,
paddle::Tensor& tmp_m,
paddle::Tensor& tmp_d,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
@@ -735,6 +750,9 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
qkv,
key_cache,
value_cache,
tmp_workspace,
tmp_m,
tmp_d,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
@@ -825,6 +843,9 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& key_cache_shape,
const std::vector<int64_t>& value_cache_shape,
const std::vector<int64_t>& tmp_workspace_shape,
const std::vector<int64_t>& tmp_m_shape,
const std::vector<int64_t>& tmp_d_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
@@ -890,6 +911,9 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::DataType& qkv_dtype,
const paddle::DataType& key_cache_dtype,
const paddle::DataType& value_cache_dtype,
const paddle::DataType& tmp_workspace_dtype,
const paddle::DataType& tmp_m_dtype,
const paddle::DataType& tmp_d_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
@@ -975,6 +999,9 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& key_cache_shape,
const std::vector<int64_t>& value_cache_shape,
const std::vector<int64_t>& tmp_workspace_shape,
const std::vector<int64_t>& tmp_m_shape,
const std::vector<int64_t>& tmp_d_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
@@ -1033,6 +1060,9 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
const paddle::DataType& qkv_dtype,
const paddle::DataType& key_cache_dtype,
const paddle::DataType& value_cache_dtype,
const paddle::DataType& tmp_workspace_dtype,
const paddle::DataType& tmp_m_dtype,
const paddle::DataType& tmp_d_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
@@ -1091,6 +1121,9 @@ PD_BUILD_STATIC_OP(append_attention)
.Inputs({"qkv",
"key_cache",
"value_cache",
"tmp_workspace",
"tmp_m",
"tmp_d",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
@@ -1152,6 +1185,9 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
.Inputs({"qkv",
"key_cache",
"value_cache",
"tmp_workspace",
"tmp_m",
"tmp_d",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
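The change above threads three caller-provided scratch tensors, `tmp_workspace`, `tmp_m`, and `tmp_d`, through every `AppendAttention` entry point, infer-shape/infer-dtype function, and op registration, so the buffers can be allocated once by the caller instead of inside the kernel on each invocation. In FlashDecoding-style split-KV attention, each KV chunk produces a partial output plus a per-row running max and softmax denominator that a later reduction merges; `tmp_workspace` plausibly holds the partial outputs and `tmp_m`/`tmp_d` the per-row max and sum. A minimal caller-side allocation sketch follows; the function name, shapes, and `max_splits` parameter are assumptions for illustration, not this PR's actual sizing logic:

```cpp
// Hypothetical pre-allocation helper for the new workspace inputs.
// Shapes are assumptions: the real sizes depend on the split-KV
// scheduling inside AppendAttentionKernel.
#include "paddle/extension.h"

std::vector<paddle::Tensor> AllocAppendAttnWorkspace(
    int64_t token_num, int64_t num_heads, int64_t head_dim,
    int64_t max_splits, const paddle::Place& place) {
  // Partial attention outputs, one slice per KV split.
  auto tmp_workspace =
      paddle::empty({max_splits, token_num, num_heads, head_dim},
                    paddle::DataType::FLOAT32, place);
  // Per-row running max of the attention logits (for a stable softmax).
  auto tmp_m = paddle::empty({max_splits, token_num, num_heads},
                             paddle::DataType::FLOAT32, place);
  // Per-row softmax denominator: sum of exp(logit - max).
  auto tmp_d = paddle::empty({max_splits, token_num, num_heads},
                             paddle::DataType::FLOAT32, place);
  return {tmp_workspace, tmp_m, tmp_d};
}
```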
85 changes: 53 additions & 32 deletions custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
@@ -18,11 +18,15 @@
template <typename T, typename OutT>
void CascadeAppendAttentionC16Kernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
const paddle::Tensor&
qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
const paddle::Tensor&
cache_k, // [max_block_num, num_heads, block_size, head_dim]
const paddle::Tensor&
cache_v, // [max_block_num, num_heads, head_dim, block_size]
paddle::Tensor& tmp_workspace,
paddle::Tensor& tmp_m,
paddle::Tensor& tmp_d,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>&
cache_k_scale, // [num_kv_heads, head_dim]
@@ -35,9 +39,8 @@ void CascadeAppendAttentionC16Kernel(
const paddle::optional<paddle::Tensor>&
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -99,6 +102,9 @@ void CascadeAppendAttentionC16Kernel(
qkv,
cache_k,
cache_v,
tmp_workspace,
tmp_m,
tmp_d,
attn_mask,
shift_bias,
smooth_weight,
Expand Down Expand Up @@ -127,13 +133,17 @@ void CascadeAppendAttentionC16Kernel(
})})})})})})
}

template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16>(
template void
CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
const paddle::Tensor&
cache_k, // [max_block_num, num_heads, block_size, head_dim]
const paddle::Tensor&
cache_v, // [max_block_num, num_heads, head_dim, block_size]
paddle::Tensor& tmp_workspace,
paddle::Tensor& tmp_m,
paddle::Tensor& tmp_d,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>&
cache_k_scale, // [num_kv_heads, head_dim]
@@ -146,9 +156,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
const paddle::optional<paddle::Tensor>&
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -174,13 +183,17 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
paddle::Tensor* out,
const int sliding_window);

template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
template void
CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
const paddle::Tensor&
cache_k, // [max_block_num, num_heads, block_size, head_dim]
const paddle::Tensor&
cache_v, // [max_block_num, num_heads, head_dim, block_size]
paddle::Tensor& tmp_workspace,
paddle::Tensor& tmp_m,
paddle::Tensor& tmp_d,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>&
cache_k_scale, // [num_kv_heads, head_dim]
@@ -193,9 +206,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
const paddle::optional<paddle::Tensor>&
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -228,6 +240,9 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
cache_k, // [max_block_num, num_heads, block_size, head_dim]
const paddle::Tensor&
cache_v, // [max_block_num, num_heads, head_dim, block_size]
paddle::Tensor& tmp_workspace,
paddle::Tensor& tmp_m,
paddle::Tensor& tmp_d,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>&
cache_k_scale, // [num_kv_heads, head_dim]
@@ -240,9 +255,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const paddle::optional<paddle::Tensor>&
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -272,24 +286,26 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
const paddle::Tensor&
cache_k, // [max_block_num, num_heads, block_size, head_dim]
cache_k, // [max_block_num, num_heads, block_size, head_dim]
const paddle::Tensor&
cache_v, // [max_block_num, num_heads, head_dim, block_size]
cache_v, // [max_block_num, num_heads, head_dim, block_size]
paddle::Tensor& tmp_workspace,
paddle::Tensor& tmp_m,
paddle::Tensor& tmp_d,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>&
cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
cache_v_scale, // [num_kv_heads, head_dim]
cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
cache_k_zp, // [num_kv_heads, head_dim]
cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
cache_v_zp, // [num_kv_heads, head_dim]
cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
shift_bias, // [num_kv_heads, head_dim]
cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -315,13 +331,17 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
paddle::Tensor* out,
const int sliding_window);

template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4m3fn>(
template void
CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4m3fn>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
const paddle::Tensor&
cache_k, // [max_block_num, num_heads, block_size, head_dim]
const paddle::Tensor&
cache_v, // [max_block_num, num_heads, head_dim, block_size]
paddle::Tensor& tmp_workspace,
paddle::Tensor& tmp_m,
paddle::Tensor& tmp_d,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>&
cache_k_scale, // [num_kv_heads, head_dim]
@@ -334,9 +354,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
const paddle::optional<paddle::Tensor>&
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -369,6 +388,9 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
cache_k, // [max_block_num, num_heads, block_size, head_dim]
const paddle::Tensor&
cache_v, // [max_block_num, num_heads, head_dim, block_size]
paddle::Tensor& tmp_workspace,
paddle::Tensor& tmp_m,
paddle::Tensor& tmp_d,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>&
cache_k_scale, // [num_kv_heads, head_dim]
@@ -381,9 +403,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const paddle::optional<paddle::Tensor>&
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
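For reference, the pairing of a row max with a softmax denominator is what lets partial results held in `tmp_workspace`, `tmp_m`, and `tmp_d` be combined without revisiting the attention logits. A sketch of the standard log-sum-exp merge of two partials, not code from this PR and with illustrative names only:

```cpp
// Merge two partial attention results (o1, m1, d1) and (o2, m2, d2),
// where m is the per-row logit max and d the softmax denominator.
// Illustrative only; the repository's reduction kernel may differ.
#include <algorithm>
#include <cmath>

void MergePartial(const float* o1, float m1, float d1,
                  const float* o2, float m2, float d2,
                  float* out, int head_dim) {
  const float m = std::max(m1, m2);        // merged row max
  const float s1 = d1 * std::exp(m1 - m);  // rescaled denominators
  const float s2 = d2 * std::exp(m2 - m);
  const float d = s1 + s2;                 // merged denominator
  for (int i = 0; i < head_dim; ++i) {
    // Each partial output is normalized by its own d, so un-normalize,
    // rescale to the shared max, and renormalize by the merged d.
    out[i] = (o1[i] * s1 + o2[i] * s2) / d;
  }
}
```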