Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "infiniop/ops/zeros.h"
#include "infiniop/tensor_descriptor.h"
#include "infiniop/ops/paged_attention.h"
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/paged_caching.h"

#endif // __INFINIOP_API_H__
79 changes: 79 additions & 0 deletions include/infiniop/ops/paged_attention_prefill.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#ifndef __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
#define __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__

#include "../operator_descriptor.h"

// Opaque handle for a Paged Attention Prefill operation descriptor.
typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;

/**
 * @brief Creates a descriptor for the Paged Attention Prefill operation.
 *
 * @param handle The handle to the InfiniOP library context.
 * @param desc_ptr A pointer to store the created descriptor.
 * @param out_desc Descriptor for the output tensor.
 * @param q_desc Descriptor for the query tensor.
 * @param k_cache_desc Descriptor for the global physical key cache.
 * @param v_cache_desc Descriptor for the global physical value cache.
 * @param block_tables_desc Descriptor for the block tables mapping logical to physical blocks.
 * @param cache_lens_desc Descriptor for the total sequence lengths (history + current).
 * @param seq_lens_desc Descriptor for the current prefill sequence lengths.
 * @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
 * @param scale The attention scaling factor.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
    infiniopHandle_t handle,
    infiniopPagedAttentionPrefillDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t cache_lens_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t alibi_slopes_desc,
    float scale);

/**
 * @brief Retrieves the workspace size required for the Paged Attention Prefill operation.
 *
 * @param desc The Paged Attention Prefill descriptor.
 * @param size Output: required workspace size in bytes.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
    infiniopPagedAttentionPrefillDescriptor_t desc, size_t *size);

/**
 * @brief Executes the Paged Attention Prefill operation.
 *
 * @param desc The Paged Attention Prefill descriptor.
 * @param workspace Pointer to the workspace memory.
 * @param workspace_size The size of the workspace.
 * @param out Pointer to the output tensor data.
 * @param q Pointer to the query tensor data.
 * @param k_cache Pointer to the global key cache data.
 * @param v_cache Pointer to the global value cache data.
 * @param block_tables Pointer to the block tables data.
 * @param cache_lens Pointer to the total sequence lengths data.
 * @param seq_lens Pointer to the current prefill sequence lengths data.
 * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
 * @param stream The CUDA/device stream for the operation.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedAttentionPrefill(
    infiniopPagedAttentionPrefillDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    const void *block_tables,
    const void *cache_lens,
    const void *seq_lens,
    const void *alibi_slopes,
    void *stream);

/**
 * @brief Destroys a Paged Attention Prefill descriptor.
 *
 * @param desc The descriptor to destroy.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
    infiniopPagedAttentionPrefillDescriptor_t desc);

#endif // __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
161 changes: 161 additions & 0 deletions src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__

#include <cuda_fp16.h>
#include <float.h>
#include <math.h>

namespace op::paged_attention_prefill::cuda {

// =============================================================
// Naive reference implementation: correctness prioritized.
// Softmax uses the standard two-pass max/normalize scheme; the
// Q.K dot products are recomputed redundantly by every thread of
// the block (no shared-memory reduction yet).
// =============================================================

/**
 * Paged-attention kernel for the prefill phase with causal masking.
 *
 * Expected launch configuration:
 *   Grid : (num_heads, max_new_len, num_seqs)
 *   Block: (head_size, 1, 1) -- one thread per output element.
 *
 * Layout assumptions (NOTE(review): confirm against the host wrapper):
 *   q/out : [num_seqs, max_new_len, num_heads, head_size], with a token's
 *           heads contiguous (token stride = num_heads * head_size).
 *   k/v   : paged caches sharing the SAME block/head strides; each token
 *           occupies head_size contiguous elements inside a block.
 */
template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
    Tdata *out_,
    const Tdata *q_,
    const Tdata *k_cache_,
    const Tdata *v_cache_,
    const int64_t *block_tables_,
    const int64_t *cache_lens_,
    const int64_t *seq_lens_,
    const float *alibi_slopes_,
    const size_t num_heads,
    const size_t num_kv_heads,
    const float scale,
    const size_t max_num_blocks_per_seq,
    const size_t block_size,
    const size_t max_new_len,
    const ptrdiff_t q_stride,
    const ptrdiff_t kv_block_stride,
    const ptrdiff_t kv_head_stride,
    const ptrdiff_t o_stride,
    const size_t head_size_const) {

    // --- 1. Coordinate setup ---
    const int seq_idx = blockIdx.z;
    const int q_token_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int dim_idx = threadIdx.x;

    // Sequences shorter than max_new_len leave padded query slots; skip them.
    const int64_t cur_new_len = seq_lens_[seq_idx];
    if (q_token_idx >= cur_new_len) {
        return;
    }
    // Guard against blockDim.x exceeding the head size.
    if (dim_idx >= head_size_const) {
        return;
    }

    // --- 2. Derived dimensions ---
    const int64_t total_seq_len = cache_lens_[seq_idx];
    const int64_t history_len = total_seq_len - cur_new_len;
    const int64_t global_token_idx = history_len + q_token_idx;
    // Causal mask: this query attends to cache tokens [0, global_token_idx].
    const int64_t kv_len = (global_token_idx + 1 < total_seq_len)
                               ? (global_token_idx + 1)
                               : total_seq_len;

    // GQA/MQA: several query heads share one KV head.
    const size_t num_queries_per_kv = num_heads / num_kv_heads;
    const size_t kv_head_idx = head_idx / num_queries_per_kv;

    const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;

    const Tdata *q_ptr_base = q_ + seq_idx * q_stride + q_token_idx * (num_heads * head_size_const) + head_idx * head_size_const;
    Tdata *out_ptr = out_ + seq_idx * o_stride + q_token_idx * (num_heads * head_size_const) + head_idx * head_size_const;

    // Degenerate/invalid lengths (cache_lens < seq_lens): emit zeros instead
    // of dividing 0/0 into NaNs.
    if (kv_len <= 0) {
        out_ptr[dim_idx] = static_cast<Tdata>(0.0f);
        return;
    }

    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];

    // --- Pass 1: global maximum of the masked scores (stable softmax) ---
    Tcompute max_score = -FLT_MAX;
    for (int64_t t = 0; t < kv_len; ++t) {
        const int64_t b_idx = t / static_cast<int64_t>(block_size);
        const int64_t t_off = t % static_cast<int64_t>(block_size);
        const int64_t physical_block = block_table[b_idx];
        const Tdata *k_vec = k_cache_ + physical_block * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size_const;

        Tcompute score = 0.0f;
        for (size_t d = 0; d < head_size_const; ++d) {
            score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
        }
        score *= scale;
        if (alibi_slope != 0.0f) {
            // ALiBi bias grows with distance from the last cache position.
            score += alibi_slope * (t - total_seq_len + 1);
        }
        max_score = (score > max_score) ? score : max_score;
    }

    // --- Pass 2: fused denominator + weighted-V accumulation ---
    // The un-normalized weights exp(score - max) are summed and, in the same
    // sweep, folded into the output accumulator; one divide at the end
    // replaces the original separate third pass over the K cache.
    Tcompute sum_exp = 0.0f;
    Tcompute acc = 0.0f;
    for (int64_t t = 0; t < kv_len; ++t) {
        const int64_t b_idx = t / static_cast<int64_t>(block_size);
        const int64_t t_off = t % static_cast<int64_t>(block_size);
        const int64_t physical_block = block_table[b_idx];
        const Tdata *k_vec = k_cache_ + physical_block * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size_const;

        Tcompute score = 0.0f;
        for (size_t d = 0; d < head_size_const; ++d) {
            score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
        }
        score *= scale;
        if (alibi_slope != 0.0f) {
            score += alibi_slope * (t - total_seq_len + 1);
        }

        const Tcompute weight = expf(score - max_score);
        sum_exp += weight;

        const Tdata *v_vec = v_cache_ + physical_block * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size_const;
        acc += weight * static_cast<Tcompute>(v_vec[dim_idx]);
    }

    // max_score equals one of the scores, so sum_exp >= expf(0) = 1 and the
    // division is always safe. (The original '+1e-6f' epsilon biased every
    // output slightly toward zero and has been dropped.)
    out_ptr[dim_idx] = static_cast<Tdata>(acc / sum_exp);
}

} // namespace op::paged_attention_prefill::cuda

#endif // __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
105 changes: 105 additions & 0 deletions src/infiniop/ops/paged_attention_prefill/info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#ifndef __PAGED_ATTENTION_PREFILL_INFO_H__
#define __PAGED_ATTENTION_PREFILL_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"
#include <iostream>
#include <optional>
#include <vector>

namespace op::paged_attention_prefill {

class PagedAttentionPrefillInfo {
    PagedAttentionPrefillInfo() = default;

public:
    // --- Data type and scale ---
    infiniDtype_t dtype; // element type shared by q / k / v / out
    float scale;         // attention scaling factor

    // --- Shape dimensions ---
    size_t num_seqs;
    size_t num_heads;
    size_t num_kv_heads;
    size_t head_size;
    size_t block_size;
    size_t max_num_blocks_per_seq;
    size_t max_new_len;

    // --- Strides for memory layout ---
    ptrdiff_t q_stride;        // stride between sequences in q
    ptrdiff_t kv_block_stride; // stride between physical blocks in the caches
    ptrdiff_t kv_head_stride;  // stride between kv heads inside a block
    ptrdiff_t o_stride;        // stride between sequences in out

    /**
     * Validates the tensor descriptors and collects launch parameters for
     * the Paged Attention Prefill operator.
     *
     * Expected shapes (NOTE(review): cache layout assumed
     * [num_blocks, num_kv_heads, block_size, head_size] — confirm):
     *   q            : [num_seqs, max_new_len, num_heads, head_size]
     *   block_tables : [num_seqs, max_num_blocks_per_seq], int64
     *   cache_lens   : [num_seqs] total lengths (history + current), int64
     *   seq_lens     : [num_seqs] current prefill lengths, int64
     */
    static utils::Result<PagedAttentionPrefillInfo> create(
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t q_desc,
        infiniopTensorDescriptor_t k_cache_desc,
        infiniopTensorDescriptor_t v_cache_desc,
        infiniopTensorDescriptor_t block_tables_desc,
        infiniopTensorDescriptor_t cache_lens_desc, // Total lengths
        infiniopTensorDescriptor_t seq_lens_desc,   // New lengths (prefill length)
        const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
        float scale) {

        auto dtype = q_desc->dtype();
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
        if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        // q must be exactly 4-D: q_shape[1..3] are read below. (The previous
        // `ndim() < 3` check admitted a 3-D q and then read q_shape[3] out
        // of bounds.)
        if (q_desc->ndim() != 4 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4 || block_tables_desc->ndim() != 2 || cache_lens_desc->ndim() != 1 || seq_lens_desc->ndim() != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        if (block_tables_desc->dtype() != INFINI_DTYPE_I64 || cache_lens_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        // --- Extract shape dimensions ---
        // Q shape: [num_seqs, max_new_len, num_heads, head_size]
        auto q_shape = q_desc->shape();
        auto k_cache_shape = k_cache_desc->shape();

        size_t num_seqs = q_shape[0];
        size_t max_new_len = q_shape[1];
        size_t num_heads = q_shape[2];
        size_t head_size = q_shape[3];

        if (head_size != 128) {
            std::cerr << "[Error] PagedAttentionPrefill now only supports head_size = 128, but got "
                      << head_size << "." << std::endl;
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // The kernel addresses v_cache with the strides taken from k_cache,
        // so both caches must share a single layout.
        if (v_cache_desc->shape() != k_cache_shape) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        size_t num_kv_heads = k_cache_shape[1];
        size_t block_size = k_cache_shape[2];
        size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];

        // --- Extract strides ---
        ptrdiff_t q_stride = q_desc->stride(0);           // between sequences in Q
        ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
        ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
        ptrdiff_t o_stride = out_desc->stride(0);

        return utils::Result<PagedAttentionPrefillInfo>(PagedAttentionPrefillInfo{
            dtype,
            scale,
            num_seqs,
            num_heads,
            num_kv_heads,
            head_size,
            block_size,
            max_num_blocks_per_seq,
            max_new_len,
            q_stride,
            kv_block_stride,
            kv_head_stride,
            o_stride});
    }
};

} // namespace op::paged_attention_prefill

#endif // __PAGED_ATTENTION_PREFILL_INFO_H__
Loading