High performance Paged Attention example #899
a305ea5 to 4f2365b
Code Review
This pull request introduces a high-performance SPMD paged attention implementation, including C++ kernels, tiling logic, benchmarks, and correctness tests, alongside compiler updates to dynamically include CANN directories. The review feedback recommends adding defensive input validation in the tiling and kernel code to prevent division-by-zero and null pointer dereferences, checking environment variables in the compilation script, and adding future annotations in the test script.
Returns:
    (tiling_tensor, effective_block_dim)
"""
kv_real = kv_heads if kv_heads > 0 else num_heads
Always perform defensive validation on input parameters before deriving other dependent variables from them to prevent potential division-by-zero or out-of-bounds errors (e.g., if batch, block_dim, num_heads, or block_size are zero or negative).
if batch <= 0 or block_dim <= 0 or num_heads <= 0 or block_size <= 0 or head_dim <= 0 or head_dim_v <= 0:
    raise ValueError("Input dimensions (batch, block_dim, num_heads, block_size, head_dim, head_dim_v) must be strictly positive.")
kv_real = kv_heads if kv_heads > 0 else num_heads
References
- Always perform defensive validation and normalization/fixups on input parameters before deriving other dependent variables from them, even if the invalid input is theoretically unreachable in practice.
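The validate-then-derive ordering recommended above can be sketched as a small standalone function. This is only an illustration: the helper name `resolve_kv_heads` and the wording of its error message are assumptions, not code from the PR, and the real tiling function derives more than `kv_real`.

```python
def resolve_kv_heads(batch, block_dim, num_heads, kv_heads,
                     block_size, head_dim, head_dim_v):
    """Hypothetical helper: validate every dimension, then derive kv_real."""
    required = {
        "batch": batch,
        "block_dim": block_dim,
        "num_heads": num_heads,
        "block_size": block_size,
        "head_dim": head_dim,
        "head_dim_v": head_dim_v,
    }
    # Collect all offending names so the caller sees every problem at once.
    invalid = [name for name, value in required.items() if value <= 0]
    if invalid:
        raise ValueError("must be strictly positive: " + ", ".join(invalid))
    # kv_heads <= 0 is treated as "not specified", matching the original fallback.
    return kv_heads if kv_heads > 0 else num_heads
```

Naming the offending parameters in the exception is a design choice of this sketch; the suggestion above uses a single fixed message instead.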
const int batch = static_cast<int>(query_t->shapes[0]);
const int num_heads = static_cast<int>(query_t->shapes[1]);
const int head_dim = static_cast<int>(query_t->shapes[2]);
const int block_size = static_cast<int>(key_t->shapes[1]);
const int num_kv_heads = static_cast<int>(key_t->shapes[2]);
const int blocks_per_batch = static_cast<int>(key_t->shapes[0]) / batch;
const int max_blocks_per_query = static_cast<int>(block_table_t->shapes[1]);
const int heads_per_kv = num_heads / num_kv_heads;
const int seq_len = blocks_per_batch * block_size;
const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
Defensively validate input dimensions to prevent potential division-by-zero or undefined behavior in CPU simulation if batch, num_kv_heads, head_dim, or block_size are zero or negative.
const int batch = static_cast<int>(query_t->shapes[0]);
const int num_heads = static_cast<int>(query_t->shapes[1]);
const int head_dim = static_cast<int>(query_t->shapes[2]);
const int block_size = static_cast<int>(key_t->shapes[1]);
const int num_kv_heads = static_cast<int>(key_t->shapes[2]);
if (batch <= 0 || num_kv_heads <= 0 || head_dim <= 0 || block_size <= 0) {
    return;
}
const int blocks_per_batch = static_cast<int>(key_t->shapes[0]) / batch;
const int max_blocks_per_query = static_cast<int>(block_table_t->shapes[1]);
const int heads_per_kv = num_heads / num_kv_heads;
const int seq_len = blocks_per_batch * block_size;
const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
References
- Always perform defensive validation and normalization/fixups on input parameters before deriving other dependent variables from them, even if the invalid input is theoretically unreachable in practice.
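To make the dependency between the guarded values and the derived ones explicit, the same derivations can be mirrored in Python. This is a sketch under assumptions: the function name `derive_attention_dims` and its tuple-based shape arguments are hypothetical, the shape layout is inferred only from the indices used in the snippet above, and the sketch raises `ValueError` where the suggested C++ change returns early.

```python
import math


def derive_attention_dims(query_shape, key_shape, block_table_shape):
    """Hypothetical mirror of the derivations in the C++ snippet above."""
    batch, num_heads, head_dim = query_shape[0], query_shape[1], query_shape[2]
    block_size, num_kv_heads = key_shape[1], key_shape[2]

    # Guard every value that is later used as a divisor or under a square root.
    if batch <= 0 or num_kv_heads <= 0 or head_dim <= 0 or block_size <= 0:
        raise ValueError("batch, num_kv_heads, head_dim and block_size must be > 0")

    blocks_per_batch = key_shape[0] // batch      # divides by batch
    heads_per_kv = num_heads // num_kv_heads      # divides by num_kv_heads
    return {
        "blocks_per_batch": blocks_per_batch,
        "max_blocks_per_query": block_table_shape[1],
        "heads_per_kv": heads_per_kv,
        "seq_len": blocks_per_batch * block_size,
        "scale": 1.0 / math.sqrt(head_dim),       # needs head_dim > 0
    }
```

For example, `derive_attention_dims((2, 16, 128), (8, 64, 4, 128), (2, 4))` gives 4 blocks per batch, 4 query heads per KV head, and a sequence length of 256.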
static __aicore__ __attribute__((always_inline)) __gm__ uint8_t *tensor_data(__gm__ int64_t *args, int idx) {
    __gm__ Tensor *tensor = reinterpret_cast<__gm__ Tensor *>(args[idx]);
    return reinterpret_cast<__gm__ uint8_t *>(tensor->buffer.addr);
}
Defensively check if tensor is null before dereferencing tensor->buffer.addr to prevent potential null pointer dereferences.
static __aicore__ __attribute__((always_inline)) __gm__ uint8_t *tensor_data(__gm__ int64_t *args, int idx) {
    __gm__ Tensor *tensor = reinterpret_cast<__gm__ Tensor *>(args[idx]);
    if (tensor == nullptr) {
        return nullptr;
    }
    return reinterpret_cast<__gm__ uint8_t *>(tensor->buffer.addr);
}
#!/usr/bin/env bash
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
Defensively check if ASCEND_TOOLKIT_HOME is set before executing bisheng to provide a clear and actionable error message.
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
if [ -z "${ASCEND_TOOLKIT_HOME:-}" ]; then
    echo "Error: ASCEND_TOOLKIT_HOME environment variable is not set." >&2
    exit 1
fi
#!/usr/bin/env python3
"""High-performance SPMD paged attention using the scene-test calling interface."""
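Include from __future__ import annotations at the top of the file so that annotations are not evaluated at runtime. This lets type hints use PEP 585 generic collections (like tuple[...], built in from Python 3.9) and PEP 604 unions (like int | None, built in from Python 3.10) without runtime errors on earlier Python versions.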
#!/usr/bin/env python3
"""High-performance SPMD paged attention using the scene-test calling interface."""
from __future__ import annotations
References
- In projects targeting Python versions earlier than 3.10 (such as Python 3.9), include 'from __future__ import annotations' to enable the use of PEP 604 union type hints (e.g., 'int | None') in annotations and avoid runtime errors.
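The effect of the import can be observed directly: with postponed evaluation, annotations are stored as strings instead of being evaluated when the function is defined. The function `pad_dims` below is a hypothetical illustration and is not part of the PR.

```python
from __future__ import annotations


def pad_dims(shape: tuple[int, ...], pad: int | None = None) -> list[int]:
    """Hypothetical example: append one padding entry to a shape if given."""
    dims = list(shape)
    if pad is not None:
        dims.append(pad)
    return dims


# The annotations are kept as plain strings, so neither "tuple[int, ...]"
# nor "int | None" is evaluated when the function is defined.
print(pad_dims.__annotations__)
```

The import only affects annotations. An expression such as `int | None` used at runtime outside an annotation, for example as an `isinstance` argument, still requires Python 3.10 or later.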
4f2365b to 621782f
No description provided.