Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pa_lib.so
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
## Usage

```bash
bash ./compile.sh
python ./pa_accuracy.py
python ./bench_pa_performance.py
```

`bench_pa_performance.py` compares the standalone `pa_lib.so` kernel against
`torch_npu.npu_incre_flash_attention` with **paged KV** (`block_table`, `actual_seq_lengths`,
`block_size`), using the same calling pattern as `ifa/bench_ifa_gpa_paged.py`.

## Reference performance

Bandwidth utilization for GQA decoding shapes (standalone `pa_lib.so`, fp16, measured on one
910B-class device; numbers vary by card and load):

| Case | TFLOP/s | GiB/s |
|------|---------|--------|
| Qwen3-0.6B b1 h16/kv8 kv2048 | 0.22 | 105 |
| Qwen3-8B b1 h32/kv8 kv4096 | 0.97 | 226 |
| Qwen3-8B b1 h32/kv8 kv8192 | 1.80 | 420 |
| Qwen3-8B b64 h32/kv8 kv2048 | 4.50 | 1050 |

Compare to `torch_npu.npu_incre_flash_attention` (paged KV, same shapes):

| Case | Standalone ms | IFA ms | Speedup (IFA / standalone) |
|------|---------------|--------|------------------------------|
| Qwen3-0.6B b1 h16/kv8 kv2048 | 0.0747 | 0.0767 | 1.03× |
| Qwen3-1.7B b1 h16/kv8 kv4096 | 0.0687 | 0.0782 | 1.14× |
| Qwen3-4B b1 h32/kv8 kv2048 | 0.0736 | 0.0789 | 1.07× |
| Qwen3-8B b1 h32/kv8 kv4096 | 0.0693 | 0.0781 | 1.13× |
| Qwen3-8B b1 h32/kv8 kv8192 | 0.0745 | 0.0767 | 1.03× |
| Qwen3-14B b1 h40/kv8 kv2048 | 0.0698 | 0.0792 | 1.14× |
| Qwen3-32B b1 h64/kv8 kv2048 | 0.0689 | 0.0767 | 1.11× |
| MHA b1 h32/kv32 kv2048 | 0.0705 | 0.0752 | 1.07× |
| Qwen3-8B b4 h32/kv8 kv2048 | 0.0732 | 0.0784 | 1.07× |
| Qwen3-8B b8 h32/kv8 kv2048 | 0.0921 | 0.0818 | 0.89× |
| Qwen3-8B b16 h32/kv8 kv2048 | 0.1219 | 0.1024 | 0.84× |
| Qwen3-8B b32 h32/kv8 kv2048 | 0.2637 | 0.2725 | 1.03× |
| Qwen3-8B b64 h32/kv8 kv2048 | 0.4770 | 0.5337 | 1.12× |

Table source: `python bench_pa_performance.py --warmup 5 --iters 20`
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
#include <cstdint>

#ifdef __CPU_SIM
#include <pto/pto-inst.hpp>
#endif

#include "tensor.h"

#ifdef __CPU_SIM
#ifndef __gm__
#define __gm__
#endif
#ifndef __aicore__
#define __aicore__ [aicore] // NOLINT(whitespace/braces)
#endif
#endif

#ifdef __CPU_SIM

#include <algorithm>
#include <cmath>
#include <cstring>

static float half_to_float(uint16_t h) {
uint32_t sign = static_cast<uint32_t>(h & 0x8000) << 16;
uint32_t exp = (h >> 10) & 0x1f;
uint32_t mant = h & 0x03ff;
uint32_t bits;
if (exp == 0) {
if (mant == 0) {
bits = sign;
} else {
exp = 1;
while ((mant & 0x0400) == 0) {
mant <<= 1;
--exp;
}
mant &= 0x03ff;
bits = sign | ((exp + 112) << 23) | (mant << 13);
}
} else if (exp == 31) {
bits = sign | 0x7f800000 | (mant << 13);
} else {
bits = sign | ((exp + 112) << 23) | (mant << 13);
}
float out;
std::memcpy(&out, &bits, sizeof(out));
return out;
}

static uint16_t float_to_half(float f) {
uint32_t bits;
std::memcpy(&bits, &f, sizeof(bits));
uint32_t sign = (bits >> 16) & 0x8000;
int32_t exp = static_cast<int32_t>((bits >> 23) & 0xff) - 127 + 15;
uint32_t mant = bits & 0x7fffff;
if (exp <= 0) {
if (exp < -10) return static_cast<uint16_t>(sign);
mant = (mant | 0x800000) >> (1 - exp);
return static_cast<uint16_t>(sign | ((mant + 0x1000) >> 13));
}
if (exp >= 31) return static_cast<uint16_t>(sign | 0x7c00);
return static_cast<uint16_t>(sign | (static_cast<uint32_t>(exp) << 10) | ((mant + 0x1000) >> 13));
}

extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {
#ifdef __DAV_VEC__
(void)args;
return;
#else
auto *query_t = reinterpret_cast<Tensor *>(args[0]);
auto *key_t = reinterpret_cast<Tensor *>(args[1]);
auto *value_t = reinterpret_cast<Tensor *>(args[2]);
auto *block_table_t = reinterpret_cast<Tensor *>(args[3]);
auto *out_t = reinterpret_cast<Tensor *>(args[4]);

auto *query = reinterpret_cast<uint16_t *>(query_t->buffer.addr) + query_t->start_offset;
auto *key = reinterpret_cast<uint16_t *>(key_t->buffer.addr) + key_t->start_offset;
auto *value = reinterpret_cast<uint16_t *>(value_t->buffer.addr) + value_t->start_offset;
auto *block_table = reinterpret_cast<int32_t *>(block_table_t->buffer.addr) + block_table_t->start_offset;
auto *out = reinterpret_cast<uint16_t *>(out_t->buffer.addr) + out_t->start_offset;

const int batch = static_cast<int>(query_t->shapes[0]);
const int num_heads = static_cast<int>(query_t->shapes[1]);
const int head_dim = static_cast<int>(query_t->shapes[2]);
const int block_size = static_cast<int>(key_t->shapes[1]);
const int num_kv_heads = static_cast<int>(key_t->shapes[2]);
const int blocks_per_batch = static_cast<int>(key_t->shapes[0]) / batch;
const int max_blocks_per_query = static_cast<int>(block_table_t->shapes[1]);
const int heads_per_kv = num_heads / num_kv_heads;
const int seq_len = blocks_per_batch * block_size;
const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
Comment on lines +83 to +92
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Defensively validate input dimensions to prevent potential division-by-zero or undefined behavior in CPU simulation if batch, num_kv_heads, head_dim, or block_size are zero or negative.

    const int batch = static_cast<int>(query_t->shapes[0]);
    const int num_heads = static_cast<int>(query_t->shapes[1]);
    const int head_dim = static_cast<int>(query_t->shapes[2]);
    const int block_size = static_cast<int>(key_t->shapes[1]);
    const int num_kv_heads = static_cast<int>(key_t->shapes[2]);
    if (batch <= 0 || num_kv_heads <= 0 || head_dim <= 0 || block_size <= 0) {
        return;
    }
    const int blocks_per_batch = static_cast<int>(key_t->shapes[0]) / batch;
    const int max_blocks_per_query = static_cast<int>(block_table_t->shapes[1]);
    const int heads_per_kv = num_heads / num_kv_heads;
    const int seq_len = blocks_per_batch * block_size;
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
References
  1. Always perform defensive validation and normalization/fixups on input parameters before deriving other dependent variables from them, even if the invalid input is theoretically unreachable in practice.


for (int b = 0; b < batch; ++b) {
for (int h = 0; h < num_heads; ++h) {
const int kv_head = h / heads_per_kv;
float max_score = -INFINITY;
for (int token = 0; token < seq_len; ++token) {
const int block_col = std::min(token / block_size, max_blocks_per_query - 1);
const int block_id = block_table[b * max_blocks_per_query + block_col];
const int block_token = token % block_size;
float score = 0.0f;
for (int d = 0; d < head_dim; ++d) {
const int q_idx = (b * num_heads + h) * head_dim + d;
const int k_idx = ((block_id * block_size + block_token) * num_kv_heads + kv_head) * head_dim + d;
score += half_to_float(query[q_idx]) * half_to_float(key[k_idx]);
}
max_score = std::max(max_score, score * scale);
}

float denom = 0.0f;
for (int d = 0; d < head_dim; ++d) {
float accum = 0.0f;
for (int token = 0; token < seq_len; ++token) {
const int block_col = std::min(token / block_size, max_blocks_per_query - 1);
const int block_id = block_table[b * max_blocks_per_query + block_col];
const int block_token = token % block_size;
float score = 0.0f;
for (int kd = 0; kd < head_dim; ++kd) {
const int q_idx = (b * num_heads + h) * head_dim + kd;
const int k_idx =
((block_id * block_size + block_token) * num_kv_heads + kv_head) * head_dim + kd;
score += half_to_float(query[q_idx]) * half_to_float(key[k_idx]);
}
const float weight = std::exp(score * scale - max_score);
if (d == 0) denom += weight;
const int v_idx = ((block_id * block_size + block_token) * num_kv_heads + kv_head) * head_dim + d;
accum += weight * half_to_float(value[v_idx]);
}
const int out_idx = (b * num_heads + h) * head_dim + d;
out[out_idx] = float_to_half(accum / denom);
}
}
}
#endif
}

#else

#define block_idx get_block_idx()
#define block_num get_block_num()
#define PTO_PA_NO_GLOBAL_ENTRY
#include "../kernel/pa_entry.cce"
#undef PTO_PA_NO_GLOBAL_ENTRY
#undef block_num
#undef block_idx

static __aicore__ __attribute__((always_inline)) __gm__ uint8_t *tensor_data(__gm__ int64_t *args, int idx) {
__gm__ Tensor *tensor = reinterpret_cast<__gm__ Tensor *>(args[idx]);
return reinterpret_cast<__gm__ uint8_t *>(tensor->buffer.addr);
}
Comment on lines +148 to +151
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Defensively check if tensor is null before dereferencing tensor->buffer.addr to prevent potential null pointer dereferences.

static __aicore__ __attribute__((always_inline)) __gm__ uint8_t *tensor_data(__gm__ int64_t *args, int idx) {
    __gm__ Tensor *tensor = reinterpret_cast<__gm__ Tensor *>(args[idx]);
    if (tensor == nullptr) {
        return nullptr;
    }
    return reinterpret_cast<__gm__ uint8_t *>(tensor->buffer.addr);
}


extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {
__gm__ uint8_t *q_gm = tensor_data(args, 0);
__gm__ uint8_t *k_gm = tensor_data(args, 1);
__gm__ uint8_t *v_gm = tensor_data(args, 2);
__gm__ uint8_t *block_tables_gm = tensor_data(args, 3);
__gm__ uint8_t *o_gm = tensor_data(args, 4);
__gm__ uint8_t *s_gm = tensor_data(args, 5);
__gm__ uint8_t *p_gm = tensor_data(args, 6);
__gm__ uint8_t *o_tmp_gm = tensor_data(args, 7);
__gm__ uint8_t *go_gm = tensor_data(args, 8);
__gm__ uint8_t *o_core_tmp_gm = tensor_data(args, 9);
__gm__ uint8_t *l_gm = tensor_data(args, 10);
__gm__ uint8_t *gm_k16 = tensor_data(args, 11);
__gm__ uint8_t *gm_v16 = tensor_data(args, 12);
__gm__ uint8_t *tiling_para_gm = tensor_data(args, 13);
__gm__ uint8_t *null_gm = tensor_data(args, 14);

paged_attention_mask_body(
nullptr,
q_gm,
k_gm,
v_gm,
block_tables_gm,
null_gm,
null_gm,
null_gm,
null_gm,
null_gm,
null_gm,
null_gm,
null_gm,
null_gm,
o_gm,
s_gm,
p_gm,
o_tmp_gm,
go_gm,
o_core_tmp_gm,
l_gm,
gm_k16,
gm_v16,
tiling_para_gm
);
}

#endif
Loading
Loading