High performance Paged Attention example #899
Open: MirkoDeVita98 wants to merge 1 commit into hw-native-sys:main from MirkoDeVita98:pr-655-work
1 change: 1 addition & 0 deletions
tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/.gitignore
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| pa_lib.so |
43 changes: 43 additions & 0 deletions
...t/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/README.md
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| ## Usage | ||
|
|
||
| ```bash | ||
| bash ./compile.sh | ||
| python ./pa_accuracy.py | ||
| python ./bench_pa_performance.py | ||
| ``` | ||
|
|
||
| `bench_pa_performance.py` compares the standalone `pa_lib.so` kernel against | ||
| `torch_npu.npu_incre_flash_attention` with **paged KV** (`block_table`, `actual_seq_lengths`, | ||
| `block_size`), using the same calling pattern as `ifa/bench_ifa_gpa_paged.py`. | ||
|
|
||
| ## Reference performance | ||
|
|
||
| Achieved compute throughput and memory bandwidth for GQA decoding shapes (standalone `pa_lib.so`, fp16, measured on one | ||
| 910B-class device; numbers vary by card and load): | ||
|
|
||
| | Case | TFLOP/s | GiB/s | | ||
| |------|---------|--------| | ||
| | Qwen3-0.6B b1 h16/kv8 kv2048 | 0.22 | 105 | | ||
| | Qwen3-8B b1 h32/kv8 kv4096 | 0.97 | 226 | | ||
| | Qwen3-8B b1 h32/kv8 kv8192 | 1.80 | 420 | | ||
| | Qwen3-8B b64 h32/kv8 kv2048 | 4.50 | 1050 | | ||
|
|
||
| Compare to `torch_npu.npu_incre_flash_attention` (paged KV, same shapes): | ||
|
|
||
| | Case | Standalone ms | IFA ms | Speedup (IFA / standalone) | | ||
| |------|---------------|--------|------------------------------| | ||
| | Qwen3-0.6B b1 h16/kv8 kv2048 | 0.0747 | 0.0767 | 1.03× | | ||
| | Qwen3-1.7B b1 h16/kv8 kv4096 | 0.0687 | 0.0782 | 1.14× | | ||
| | Qwen3-4B b1 h32/kv8 kv2048 | 0.0736 | 0.0789 | 1.07× | | ||
| | Qwen3-8B b1 h32/kv8 kv4096 | 0.0693 | 0.0781 | 1.13× | | ||
| | Qwen3-8B b1 h32/kv8 kv8192 | 0.0745 | 0.0767 | 1.03× | | ||
| | Qwen3-14B b1 h40/kv8 kv2048 | 0.0698 | 0.0792 | 1.14× | | ||
| | Qwen3-32B b1 h64/kv8 kv2048 | 0.0689 | 0.0767 | 1.11× | | ||
| | MHA b1 h32/kv32 kv2048 | 0.0705 | 0.0752 | 1.07× | | ||
| | Qwen3-8B b4 h32/kv8 kv2048 | 0.0732 | 0.0784 | 1.07× | | ||
| | Qwen3-8B b8 h32/kv8 kv2048 | 0.0921 | 0.0818 | 0.89× | | ||
| | Qwen3-8B b16 h32/kv8 kv2048 | 0.1219 | 0.1024 | 0.84× | | ||
| | Qwen3-8B b32 h32/kv8 kv2048 | 0.2637 | 0.2725 | 1.03× | | ||
| | Qwen3-8B b64 h32/kv8 kv2048 | 0.4770 | 0.5337 | 1.12× | | ||
|
|
||
| Table source: `python bench_pa_performance.py --warmup 5 --iters 20` |
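The README does not state how the TFLOP/s and GiB/s columns are derived. The sketch below is one accounting that reproduces the reported figures to within rounding. It is an assumption, not taken from `bench_pa_performance.py`. It assumes `head_dim = 128`, 2-byte fp16 elements, memory traffic counted as K and V reads only, and 4 FLOPs per (query head, KV token, head-dim element): 2 for QK^T and 2 for PV. All function names are hypothetical.

```cpp
#include <cstdint>

// Assumed accounting (not confirmed by the benchmark script):
// only K and V reads count as memory traffic, and each element is fp16 (2 bytes).
double kv_bytes(int64_t batch, int64_t kv_len, int64_t num_kv_heads, int64_t head_dim) {
    const double bytes_per_elem = 2.0;  // fp16
    return 2.0 /* K and V */ * batch * kv_len * num_kv_heads * head_dim * bytes_per_elem;
}

// 2 FLOPs per multiply-add for QK^T plus 2 for PV; softmax cost is ignored.
double attention_flops(int64_t batch, int64_t num_heads, int64_t kv_len, int64_t head_dim) {
    return 4.0 * batch * num_heads * kv_len * head_dim;
}

// Convert a byte count and a latency in milliseconds to GiB/s.
double gib_per_s(double bytes, double latency_ms) {
    return bytes / (1024.0 * 1024.0 * 1024.0) / (latency_ms * 1e-3);
}

// Convert a FLOP count and a latency in milliseconds to TFLOP/s.
double tflop_per_s(double flops, double latency_ms) {
    return flops / 1e12 / (latency_ms * 1e-3);
}
```

For the `Qwen3-8B b64 h32/kv8 kv2048` row at 0.4770 ms, this accounting gives about 1048 GiB/s and 4.50 TFLOP/s, in line with the 1050 and 4.50 reported in the table.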
198 changes: 198 additions & 0 deletions
...map_and_ringbuffer/spmd_paged_attention_highperf/kernels/aic/paged_attention_highperf.cpp
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,198 @@ | ||
| #include <cstdint> | ||
|
|
||
| #ifdef __CPU_SIM | ||
| #include <pto/pto-inst.hpp> | ||
| #endif | ||
|
|
||
| #include "tensor.h" | ||
|
|
||
| #ifdef __CPU_SIM | ||
| #ifndef __gm__ | ||
| #define __gm__ | ||
| #endif | ||
| #ifndef __aicore__ | ||
| #define __aicore__ [aicore] // NOLINT(whitespace/braces) | ||
| #endif | ||
| #endif | ||
|
|
||
| #ifdef __CPU_SIM | ||
|
|
||
| #include <algorithm> | ||
| #include <cmath> | ||
| #include <cstring> | ||
|
|
||
| static float half_to_float(uint16_t h) { | ||
| uint32_t sign = static_cast<uint32_t>(h & 0x8000) << 16; | ||
| uint32_t exp = (h >> 10) & 0x1f; | ||
| uint32_t mant = h & 0x03ff; | ||
| uint32_t bits; | ||
| if (exp == 0) { | ||
| if (mant == 0) { | ||
| bits = sign; | ||
| } else { | ||
| exp = 1; | ||
| while ((mant & 0x0400) == 0) { | ||
| mant <<= 1; | ||
| --exp; | ||
| } | ||
| mant &= 0x03ff; | ||
| bits = sign | ((exp + 112) << 23) | (mant << 13); | ||
| } | ||
| } else if (exp == 31) { | ||
| bits = sign | 0x7f800000 | (mant << 13); | ||
| } else { | ||
| bits = sign | ((exp + 112) << 23) | (mant << 13); | ||
| } | ||
| float out; | ||
| std::memcpy(&out, &bits, sizeof(out)); | ||
| return out; | ||
| } | ||
|
|
||
| static uint16_t float_to_half(float f) { | ||
| uint32_t bits; | ||
| std::memcpy(&bits, &f, sizeof(bits)); | ||
| uint32_t sign = (bits >> 16) & 0x8000; | ||
| int32_t exp = static_cast<int32_t>((bits >> 23) & 0xff) - 127 + 15; | ||
| uint32_t mant = bits & 0x7fffff; | ||
| if (exp <= 0) { | ||
| if (exp < -10) return static_cast<uint16_t>(sign); | ||
| mant = (mant | 0x800000) >> (1 - exp); | ||
| return static_cast<uint16_t>(sign | ((mant + 0x1000) >> 13)); | ||
| } | ||
| if (exp >= 31) return static_cast<uint16_t>(sign | 0x7c00); | ||
| return static_cast<uint16_t>(sign | (static_cast<uint32_t>(exp) << 10) | ((mant + 0x1000) >> 13)); | ||
| } | ||
|
|
||
| extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { | ||
| #ifdef __DAV_VEC__ | ||
| (void)args; | ||
| return; | ||
| #else | ||
| auto *query_t = reinterpret_cast<Tensor *>(args[0]); | ||
| auto *key_t = reinterpret_cast<Tensor *>(args[1]); | ||
| auto *value_t = reinterpret_cast<Tensor *>(args[2]); | ||
| auto *block_table_t = reinterpret_cast<Tensor *>(args[3]); | ||
| auto *out_t = reinterpret_cast<Tensor *>(args[4]); | ||
|
|
||
| auto *query = reinterpret_cast<uint16_t *>(query_t->buffer.addr) + query_t->start_offset; | ||
| auto *key = reinterpret_cast<uint16_t *>(key_t->buffer.addr) + key_t->start_offset; | ||
| auto *value = reinterpret_cast<uint16_t *>(value_t->buffer.addr) + value_t->start_offset; | ||
| auto *block_table = reinterpret_cast<int32_t *>(block_table_t->buffer.addr) + block_table_t->start_offset; | ||
| auto *out = reinterpret_cast<uint16_t *>(out_t->buffer.addr) + out_t->start_offset; | ||
|
|
||
| const int batch = static_cast<int>(query_t->shapes[0]); | ||
| const int num_heads = static_cast<int>(query_t->shapes[1]); | ||
| const int head_dim = static_cast<int>(query_t->shapes[2]); | ||
| const int block_size = static_cast<int>(key_t->shapes[1]); | ||
| const int num_kv_heads = static_cast<int>(key_t->shapes[2]); | ||
| const int blocks_per_batch = static_cast<int>(key_t->shapes[0]) / batch; | ||
| const int max_blocks_per_query = static_cast<int>(block_table_t->shapes[1]); | ||
| const int heads_per_kv = num_heads / num_kv_heads; | ||
| const int seq_len = blocks_per_batch * block_size; | ||
| const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim)); | ||
|
|
||
| for (int b = 0; b < batch; ++b) { | ||
| for (int h = 0; h < num_heads; ++h) { | ||
| const int kv_head = h / heads_per_kv; | ||
| float max_score = -INFINITY; | ||
| for (int token = 0; token < seq_len; ++token) { | ||
| const int block_col = std::min(token / block_size, max_blocks_per_query - 1); | ||
| const int block_id = block_table[b * max_blocks_per_query + block_col]; | ||
| const int block_token = token % block_size; | ||
| float score = 0.0f; | ||
| for (int d = 0; d < head_dim; ++d) { | ||
| const int q_idx = (b * num_heads + h) * head_dim + d; | ||
| const int k_idx = ((block_id * block_size + block_token) * num_kv_heads + kv_head) * head_dim + d; | ||
| score += half_to_float(query[q_idx]) * half_to_float(key[k_idx]); | ||
| } | ||
| max_score = std::max(max_score, score * scale); | ||
| } | ||
|
|
||
| float denom = 0.0f; | ||
| for (int d = 0; d < head_dim; ++d) { | ||
| float accum = 0.0f; | ||
| for (int token = 0; token < seq_len; ++token) { | ||
| const int block_col = std::min(token / block_size, max_blocks_per_query - 1); | ||
| const int block_id = block_table[b * max_blocks_per_query + block_col]; | ||
| const int block_token = token % block_size; | ||
| float score = 0.0f; | ||
| for (int kd = 0; kd < head_dim; ++kd) { | ||
| const int q_idx = (b * num_heads + h) * head_dim + kd; | ||
| const int k_idx = | ||
| ((block_id * block_size + block_token) * num_kv_heads + kv_head) * head_dim + kd; | ||
| score += half_to_float(query[q_idx]) * half_to_float(key[k_idx]); | ||
| } | ||
| const float weight = std::exp(score * scale - max_score); | ||
| if (d == 0) denom += weight; | ||
| const int v_idx = ((block_id * block_size + block_token) * num_kv_heads + kv_head) * head_dim + d; | ||
| accum += weight * half_to_float(value[v_idx]); | ||
| } | ||
| const int out_idx = (b * num_heads + h) * head_dim + d; | ||
| out[out_idx] = float_to_half(accum / denom); | ||
| } | ||
| } | ||
| } | ||
| #endif | ||
| } | ||
|
|
||
| #else | ||
|
|
||
| #define block_idx get_block_idx() | ||
| #define block_num get_block_num() | ||
| #define PTO_PA_NO_GLOBAL_ENTRY | ||
| #include "../kernel/pa_entry.cce" | ||
| #undef PTO_PA_NO_GLOBAL_ENTRY | ||
| #undef block_num | ||
| #undef block_idx | ||
|
|
||
| static __aicore__ __attribute__((always_inline)) __gm__ uint8_t *tensor_data(__gm__ int64_t *args, int idx) { | ||
| __gm__ Tensor *tensor = reinterpret_cast<__gm__ Tensor *>(args[idx]); | ||
| return reinterpret_cast<__gm__ uint8_t *>(tensor->buffer.addr); | ||
| } | ||
Comment on lines +148 to +151:

Defensively check if `tensor` is null before dereferencing it:

    static __aicore__ __attribute__((always_inline)) __gm__ uint8_t *tensor_data(__gm__ int64_t *args, int idx) {
        __gm__ Tensor *tensor = reinterpret_cast<__gm__ Tensor *>(args[idx]);
        if (tensor == nullptr) {
            return nullptr;
        }
        return reinterpret_cast<__gm__ uint8_t *>(tensor->buffer.addr);
    }
|
|
||
| extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { | ||
| __gm__ uint8_t *q_gm = tensor_data(args, 0); | ||
| __gm__ uint8_t *k_gm = tensor_data(args, 1); | ||
| __gm__ uint8_t *v_gm = tensor_data(args, 2); | ||
| __gm__ uint8_t *block_tables_gm = tensor_data(args, 3); | ||
| __gm__ uint8_t *o_gm = tensor_data(args, 4); | ||
| __gm__ uint8_t *s_gm = tensor_data(args, 5); | ||
| __gm__ uint8_t *p_gm = tensor_data(args, 6); | ||
| __gm__ uint8_t *o_tmp_gm = tensor_data(args, 7); | ||
| __gm__ uint8_t *go_gm = tensor_data(args, 8); | ||
| __gm__ uint8_t *o_core_tmp_gm = tensor_data(args, 9); | ||
| __gm__ uint8_t *l_gm = tensor_data(args, 10); | ||
| __gm__ uint8_t *gm_k16 = tensor_data(args, 11); | ||
| __gm__ uint8_t *gm_v16 = tensor_data(args, 12); | ||
| __gm__ uint8_t *tiling_para_gm = tensor_data(args, 13); | ||
| __gm__ uint8_t *null_gm = tensor_data(args, 14); | ||
|
|
||
| paged_attention_mask_body( | ||
| nullptr, | ||
| q_gm, | ||
| k_gm, | ||
| v_gm, | ||
| block_tables_gm, | ||
| null_gm, | ||
| null_gm, | ||
| null_gm, | ||
| null_gm, | ||
| null_gm, | ||
| null_gm, | ||
| null_gm, | ||
| null_gm, | ||
| null_gm, | ||
| o_gm, | ||
| s_gm, | ||
| p_gm, | ||
| o_tmp_gm, | ||
| go_gm, | ||
| o_core_tmp_gm, | ||
| l_gm, | ||
| gm_k16, | ||
| gm_v16, | ||
| tiling_para_gm | ||
| ); | ||
| } | ||
|
|
||
| #endif | ||
Review comment: Defensively validate input dimensions to prevent potential division-by-zero or undefined behavior in CPU simulation if `batch`, `num_kv_heads`, `head_dim`, or `block_size` are zero or negative.
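One way to act on this comment is to check the dimensions before the CPU-simulation path computes `blocks_per_batch`, `heads_per_kv`, and `scale`. The following is a minimal sketch. `PaDims` and `dims_are_valid` are hypothetical names, and the divisibility checks are an added assumption that goes beyond what the review comment asks for.

```cpp
#include <cstdint>

// Hypothetical container for the shape values read in the CPU-simulation path.
struct PaDims {
    int64_t batch;
    int64_t num_heads;
    int64_t num_kv_heads;
    int64_t head_dim;
    int64_t block_size;
    int64_t num_blocks;            // key_t->shapes[0] in the diff
    int64_t max_blocks_per_query;  // block_table_t->shapes[1] in the diff
};

// Returns true only if all sizes are positive and the kernel's integer divisions are exact.
bool dims_are_valid(const PaDims &d) {
    if (d.batch <= 0 || d.num_heads <= 0 || d.num_kv_heads <= 0) return false;
    if (d.head_dim <= 0 || d.block_size <= 0) return false;
    if (d.num_blocks <= 0 || d.max_blocks_per_query <= 0) return false;
    // Assumption: query heads split evenly across KV heads. This also rejects
    // num_heads < num_kv_heads, where heads_per_kv would be 0 and h / heads_per_kv would trap.
    if (d.num_heads % d.num_kv_heads != 0) return false;
    // Assumption: cache blocks split evenly across the batch.
    if (d.num_blocks % d.batch != 0) return false;
    return true;
}
```

In the CPU-simulation `kernel_entry`, an early `return` when `dims_are_valid` is false would leave `out` untouched instead of dividing by zero.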