issue/834: add paged attention for nvidia gpu #836

Open: spike-zhu wants to merge 3 commits into main from issue/834.
The first new file (+88 lines) declares the Paged Attention C API:

```c
#ifndef __INFINIOP_PAGED_ATTENTION_API_H__
#define __INFINIOP_PAGED_ATTENTION_API_H__

#include "../operator_descriptor.h"

// Define an opaque handle for the Paged Attention descriptor.
typedef struct InfiniopDescriptor *infiniopPagedAttentionDescriptor_t;

/**
 * @brief Creates a descriptor for the Paged Attention v1 operation.
 *
 * This function initializes a descriptor that holds all the metadata needed
 * for the paged attention computation.
 *
 * @param handle The handle to the InfiniOP library context.
 * @param desc_ptr A pointer to store the created descriptor.
 * @param out_desc Descriptor for the output tensor.
 * @param q_desc Descriptor for the query tensor.
 * @param k_cache_desc Descriptor for the key cache tensor.
 * @param v_cache_desc Descriptor for the value cache tensor.
 * @param block_tables_desc Descriptor for the block tables tensor.
 * @param seq_lens_desc Descriptor for the sequence lengths tensor.
 * @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
 * @param scale The attention scaling factor.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopCreatePagedAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopPagedAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t alibi_slopes_desc,
    float scale);

/**
 * @brief Retrieves the workspace size required for the Paged Attention operation.
 *
 * @param desc The Paged Attention descriptor.
 * @param size A pointer to store the required workspace size in bytes.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
    infiniopPagedAttentionDescriptor_t desc, size_t *size);

/**
 * @brief Executes the Paged Attention v1 operation.
 *
 * @param desc The Paged Attention descriptor.
 * @param workspace Pointer to the workspace memory.
 * @param workspace_size The size of the workspace.
 * @param out Pointer to the output tensor data.
 * @param q Pointer to the query tensor data.
 * @param k_cache Pointer to the key cache data.
 * @param v_cache Pointer to the value cache data.
 * @param block_tables Pointer to the block tables data.
 * @param seq_lens Pointer to the sequence lengths data.
 * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
 * @param stream The CUDA stream for the operation. Can be NULL.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedAttention(
    infiniopPagedAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    const void *block_tables,
    const void *seq_lens,
    const void *alibi_slopes,
    void *stream);

/**
 * @brief Destroys a Paged Attention descriptor.
 *
 * @param desc The descriptor to be destroyed.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
    infiniopPagedAttentionDescriptor_t desc);

#endif // __INFINIOP_PAGED_ATTENTION_API_H__
```
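For orientation, here is a minimal host-side sketch of the intended call sequence: create, size the workspace, execute, destroy. It assumes the header above is included, that all tensor descriptors and device buffers were created elsewhere, and that `CHECK` is a placeholder for the caller's handling of `infiniStatus_t` return codes; none of this scaffolding is part of the PR itself.

```cpp
#include <cmath>
#include <cuda_runtime.h>

infiniStatus_t runPagedAttention(
    infiniopHandle_t handle,
    infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc, infiniopTensorDescriptor_t seq_lens_desc,
    float head_size,
    void *d_out, const void *d_q, const void *d_k_cache, const void *d_v_cache,
    const void *d_block_tables, const void *d_seq_lens,
    cudaStream_t stream) {
    infiniopPagedAttentionDescriptor_t desc;

    // 1. Create the descriptor (NULL ALiBi descriptor disables the bias);
    //    1/sqrt(head_size) is the usual attention scale.
    CHECK(infiniopCreatePagedAttentionDescriptor(
        handle, &desc, out_desc, q_desc, k_cache_desc, v_cache_desc,
        block_tables_desc, seq_lens_desc, /*alibi_slopes_desc=*/NULL,
        1.0f / sqrtf(head_size)));

    // 2. Query and allocate the workspace.
    size_t workspace_size = 0;
    CHECK(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size));
    void *workspace = NULL;
    if (workspace_size > 0) cudaMalloc(&workspace, workspace_size);

    // 3. Execute on the given stream (NULL would mean the default stream).
    CHECK(infiniopPagedAttention(desc, workspace, workspace_size,
                                 d_out, d_q, d_k_cache, d_v_cache,
                                 d_block_tables, d_seq_lens,
                                 /*alibi_slopes=*/NULL, stream));

    // 4. Tear down.
    if (workspace) cudaFree(workspace);
    return infiniopDestroyPagedAttentionDescriptor(desc);
}
```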
The second new file (+77 lines) declares the Paged Caching C API, which copies new key/value vectors into the paged cache pools:

```c
#ifndef __INFINIOP_PAGED_CACHING_API_H__
#define __INFINIOP_PAGED_CACHING_API_H__

#include "../operator_descriptor.h"

// Define an opaque handle for the Paged Caching descriptor.
typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t;

/**
 * @brief Creates a descriptor for the Paged Caching operation.
 *
 * This function initializes a descriptor that holds all the metadata needed
 * to copy key/value vectors into their respective cache pools.
 *
 * @param handle The handle to the InfiniOP library context.
 * @param desc_ptr A pointer to store the created descriptor.
 * @param k_desc Descriptor for the source key tensor.
 * @param v_desc Descriptor for the source value tensor.
 * @param k_cache_desc Descriptor for the key cache pool tensor.
 * @param v_cache_desc Descriptor for the value cache pool tensor.
 * @param slot_mapping_desc Descriptor for the slot mapping tensor.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopCreatePagedCachingDescriptor(
    infiniopHandle_t handle,
    infiniopPagedCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t slot_mapping_desc);

/**
 * @brief Retrieves the workspace size required for the Paged Caching operation.
 *
 * @param desc The Paged Caching descriptor.
 * @param size A pointer to store the required workspace size in bytes (typically 0).
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
    infiniopPagedCachingDescriptor_t desc, size_t *size);

/**
 * @brief Executes the Paged Caching operation.
 *
 * @param desc The Paged Caching descriptor.
 * @param workspace Pointer to the workspace memory.
 * @param workspace_size The size of the workspace.
 * @param k Pointer to the source key tensor data.
 * @param v Pointer to the source value tensor data.
 * @param k_cache Pointer to the key cache pool data.
 * @param v_cache Pointer to the value cache pool data.
 * @param slot_mapping Pointer to the slot mapping data.
 * @param stream The CUDA stream for the operation. Can be NULL.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedCaching(
    infiniopPagedCachingDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    const void *k,
    const void *v,
    void *k_cache,
    void *v_cache,
    const void *slot_mapping,
    void *stream);

/**
 * @brief Destroys a Paged Caching descriptor.
 *
 * @param desc The descriptor to be destroyed.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopDestroyPagedCachingDescriptor(
    infiniopPagedCachingDescriptor_t desc);

#endif // __INFINIOP_PAGED_CACHING_API_H__
```
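The slot mapping is what ties this operator to the paged layout. Below is a hedged host-side sketch of how a caller might build it for one decode step, assuming the vLLM-style flat encoding `slot = physical_block * block_size + offset_in_block`; the PR does not show the producer side, so both the encoding and the one-token-per-sequence decode scenario are assumptions.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical slot-mapping construction for a single decode step: each
// sequence appends one token, whose position is seq_lens[seq] - 1.
std::vector<int32_t> buildSlotMapping(
    const std::vector<std::vector<int32_t>> &block_tables, // per-sequence physical block ids
    const std::vector<int32_t> &seq_lens,                  // current length of each sequence
    int32_t block_size) {
    std::vector<int32_t> slot_mapping;
    for (size_t seq = 0; seq < seq_lens.size(); ++seq) {
        const int32_t pos = seq_lens[seq] - 1;                          // position of the new token
        const int32_t physical_block = block_tables[seq][pos / block_size];
        slot_mapping.push_back(physical_block * block_size + pos % block_size);
    }
    return slot_mapping; // one slot per new token; copied to the device before the call
}
```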
The third new file (+149 lines) implements the CUDA paged attention kernel:

```cpp
#ifndef __PAGED_ATTENTION_KERNEL_CUH__
#define __PAGED_ATTENTION_KERNEL_CUH__

#include <cstddef>
#include <cstdint>

// This kernel is refactored for high performance, adopting parallelism
// strategies from industry-standard implementations such as vLLM. It fixes
// functional and performance issues in the original draft.
//
// The op::common_cuda::reduce_op helpers used below are expected to be
// provided by a project header included ahead of this one.

namespace op::paged_attention::cuda {

template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
__device__ void pagedAttentionKernel(
    Tdata *out_,
    const Tdata *q_,
    const Tdata *k_cache_,
    const Tdata *v_cache_,
    const int32_t *block_tables_,
    const int32_t *seq_lens_,
    const float *alibi_slopes_,
    const size_t num_kv_heads,
    const float scale,
    const size_t max_num_blocks_per_seq,
    const size_t block_size,
    const ptrdiff_t q_stride,
    const ptrdiff_t kv_block_stride,
    const ptrdiff_t kv_head_stride,
    const ptrdiff_t o_stride) {
    //================================================================================
    // 1. Setup & Query Loading
    //================================================================================
    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int num_heads = gridDim.x;
    const int32_t seq_len = seq_lens_[seq_idx];
    if (seq_len == 0) {
        return;
    }

    const size_t num_queries_per_kv = num_heads / num_kv_heads;
    const size_t kv_head_idx = head_idx / num_queries_per_kv;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];

    const int32_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;

    const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
    Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;

    // Shared memory layout: the query vector (HEAD_SIZE values) followed by
    // one logit per token of the sequence.
    extern __shared__ char shared_mem_char[];
    Tcompute *shared_mem = reinterpret_cast<Tcompute *>(shared_mem_char);
    Tcompute *q_shared = shared_mem;
    Tcompute *logits = shared_mem + HEAD_SIZE;

    for (size_t i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
        q_shared[i] = static_cast<Tcompute>(q_ptr[i]);
    }
    __syncthreads();

    //================================================================================
    // 2. Compute QK Dot Product & Find Max Logit
    //================================================================================
    for (size_t token_idx = threadIdx.x; token_idx < seq_len; token_idx += NUM_THREADS) {
        const int32_t block_idx = token_idx / block_size;
        const int32_t token_in_block_idx = token_idx % block_size;
        const int32_t physical_block_num = block_table[block_idx];

        const Tdata *k_vec_ptr = k_cache_ + physical_block_num * kv_block_stride + kv_head_idx * kv_head_stride + token_in_block_idx * HEAD_SIZE;

        Tcompute qk = 0.0f;
        // Manually unrolled 8-wide multiply-accumulate; assumes HEAD_SIZE is
        // a multiple of 8.
#pragma unroll
        for (size_t i = 0; i < HEAD_SIZE / 8; ++i) {
            const size_t offset = i * 8;
            qk += q_shared[offset + 0] * static_cast<Tcompute>(k_vec_ptr[offset + 0]);
            qk += q_shared[offset + 1] * static_cast<Tcompute>(k_vec_ptr[offset + 1]);
            qk += q_shared[offset + 2] * static_cast<Tcompute>(k_vec_ptr[offset + 2]);
            qk += q_shared[offset + 3] * static_cast<Tcompute>(k_vec_ptr[offset + 3]);
            qk += q_shared[offset + 4] * static_cast<Tcompute>(k_vec_ptr[offset + 4]);
            qk += q_shared[offset + 5] * static_cast<Tcompute>(k_vec_ptr[offset + 5]);
            qk += q_shared[offset + 6] * static_cast<Tcompute>(k_vec_ptr[offset + 6]);
            qk += q_shared[offset + 7] * static_cast<Tcompute>(k_vec_ptr[offset + 7]);
        }

        qk *= scale;
        if (alibi_slope != 0.0f) {
            // Cast token_idx to signed before subtracting: in size_t
            // arithmetic, token_idx - seq_len + 1 would wrap around for
            // every token before the last position.
            qk += alibi_slope * (static_cast<int32_t>(token_idx) - seq_len + 1);
        }

        logits[token_idx] = qk;
    }
    __syncthreads();

    __shared__ Tcompute global_qk_max;
    Tcompute global_qk_max_0 = op::common_cuda::reduce_op::max<NUM_THREADS, Tcompute>(logits, seq_len);

    if (threadIdx.x == 0) {
        global_qk_max = global_qk_max_0;
    }
    __syncthreads();

    //================================================================================
    // 3. Compute Softmax
    //================================================================================
    for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) {
        logits[i] = expf(logits[i] - global_qk_max); // subtract the global max for numerical stability
    }
    __syncthreads();

    __shared__ Tcompute inv_sum;
    Tcompute exp_sum_0 = op::common_cuda::reduce_op::sum<NUM_THREADS, Tcompute, Tcompute>(logits, seq_len);
    if (threadIdx.x == 0) {
        inv_sum = 1.0f / (exp_sum_0 + 1e-6f);
    }
    __syncthreads();

    for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) {
        logits[i] *= inv_sum;
    }
    __syncthreads();

    //================================================================================
    // 4. Aggregate Values (V) weighted by probabilities
    //================================================================================
    for (size_t h_dim = threadIdx.x; h_dim < HEAD_SIZE; h_dim += NUM_THREADS) {
        Tcompute acc = 0.0f;

        for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) {
            const size_t block_idx = token_idx / block_size;
            const size_t token_in_block_idx = token_idx % block_size;
            const int32_t physical_block_num = block_table[block_idx];
            const Tcompute prob = logits[token_idx];

            const Tdata *v_vec_ptr = v_cache_
                + physical_block_num * kv_block_stride
                + kv_head_idx * kv_head_stride
                + token_in_block_idx * HEAD_SIZE;

            const Tdata v_val = v_vec_ptr[h_dim];
            acc += prob * static_cast<Tcompute>(v_val);
        }
        out_ptr[h_dim] = static_cast<Tdata>(acc);
    }
}

} // namespace op::paged_attention::cuda

#endif // __PAGED_ATTENTION_KERNEL_CUH__
```
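The diff contains only the `__device__` body; the `__global__` entry point and launch configuration are not shown. Below is a plausible host-side sketch, inferred from the kernel's use of `blockIdx.x` (head) and `blockIdx.y` (sequence) and from its shared-memory layout. Everything here is an assumption rather than part of the PR: the wrapper name `pagedAttentionV1Kernel`, the `half`/`float` type choices, and the `HEAD_SIZE`/`NUM_THREADS` values.

```cpp
#include <cstddef>
#include <cstdint>
#include <cuda_fp16.h>
// ...plus the kernel header above.

// Assumed __global__ wrapper that simply forwards to the __device__ function.
template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
__global__ void pagedAttentionV1Kernel(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const int32_t *block_tables, const int32_t *seq_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size,
    ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride) {
    op::paged_attention::cuda::pagedAttentionKernel<Tdata, Tcompute, HEAD_SIZE, NUM_THREADS>(
        out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, block_size,
        q_stride, kv_block_stride, kv_head_stride, o_stride);
}

// Host-side launcher; all arguments are placeholders for values the
// descriptor machinery would normally produce.
void launchPagedAttentionV1(
    half *d_out, const half *d_q, const half *d_k_cache, const half *d_v_cache,
    const int32_t *d_block_tables, const int32_t *d_seq_lens,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t block_size, size_t max_seq_len,
    ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride,
    ptrdiff_t o_stride, cudaStream_t stream) {
    constexpr size_t HEAD_SIZE = 128;   // must match the q/out head dimension
    constexpr size_t NUM_THREADS = 128;

    // One thread block per (head, sequence) pair.
    dim3 grid(num_heads, num_seqs);
    dim3 block(NUM_THREADS);

    // Dynamic shared memory: HEAD_SIZE query values plus one logit per token,
    // so max_seq_len must bound every entry of seq_lens.
    size_t smem_bytes = (HEAD_SIZE + max_seq_len) * sizeof(float);

    pagedAttentionV1Kernel<half, float, HEAD_SIZE, NUM_THREADS>
        <<<grid, block, smem_bytes, stream>>>(
            d_out, d_q, d_k_cache, d_v_cache, d_block_tables, d_seq_lens,
            /*alibi_slopes=*/nullptr, num_kv_heads, scale,
            max_num_blocks_per_seq, block_size,
            q_stride, kv_block_stride, kv_head_stride, o_stride);
}
```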
Review comments

A collaborator commented on the `const int32_t *block_tables_` parameter of the kernel (translated from Chinese): use a template parameter or a larger data type; int32 overflows easily.
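A minimal sketch of that suggestion, assuming the index element type becomes a template parameter; the name `Tindex` and the declaration-only form are illustrative, not from the PR:

```cpp
#include <cstddef>
#include <cstdint>

namespace op::paged_attention::cuda {

// Hypothetical revision along the lines of the review comment: the index
// element type is a template parameter (Tindex, e.g. int64_t) instead of a
// hard-coded int32_t, so caches with more than 2^31 - 1 blocks or tokens
// remain addressable. Declaration only; the body would be unchanged apart
// from the widened loads.
template <typename Tdata, typename Tcompute, typename Tindex,
          size_t HEAD_SIZE, size_t NUM_THREADS>
__device__ void pagedAttentionKernel(
    Tdata *out_,
    const Tdata *q_,
    const Tdata *k_cache_,
    const Tdata *v_cache_,
    const Tindex *block_tables_, // was const int32_t *
    const Tindex *seq_lens_,     // was const int32_t *
    const float *alibi_slopes_,
    const size_t num_kv_heads,
    const float scale,
    const size_t max_num_blocks_per_seq,
    const size_t block_size,
    const ptrdiff_t q_stride,
    const ptrdiff_t kv_block_stride,
    const ptrdiff_t kv_head_stride,
    const ptrdiff_t o_stride);

} // namespace op::paged_attention::cuda
```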
A collaborator also asked (translated from Chinese): please mark out the shape and meaning of each tensor in the description.
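For reference, the layouts that the kernel's stride arithmetic appears to imply, following the vLLM convention; these are inferred from the indexing and are not stated anywhere in the PR:

- q, out: [num_seqs, num_heads, HEAD_SIZE], with per-sequence strides q_stride / o_stride
- k_cache, v_cache: [num_blocks, num_kv_heads, block_size, HEAD_SIZE], with strides kv_block_stride and kv_head_stride
- block_tables: [num_seqs, max_num_blocks_per_seq], int32 physical block ids per logical block
- seq_lens: [num_seqs], int32 token count per sequence
- alibi_slopes: [num_heads], float, optional (NULL disables the ALiBi bias)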