Skip to content

High performance Paged Attention example#899

Open
MirkoDeVita98 wants to merge 1 commit into
hw-native-sys:mainfrom
MirkoDeVita98:pr-655-work
Open

High performance Paged Attention example#899
MirkoDeVita98 wants to merge 1 commit into
hw-native-sys:mainfrom
MirkoDeVita98:pr-655-work

Conversation

@MirkoDeVita98
Copy link
Copy Markdown

No description provided.

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 29, 2026

Review Change Stack

Warning

Review limit reached

@MirkoDeVita98, we couldn't start this review because you've reached your PR review rate limit.

More reviews will be available in 56 minutes and 38 seconds. Learn how PR review limits work.

Your organization has run out of usage credits. Purchase more in the billing tab.

⌛ How to resolve this issue?

After more reviews become available, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans include higher PR review limits than trial, open-source, and free plans. In all cases, reviews become available again over time. During sustained high-volume PR review activity, CodeRabbit may temporarily slow when the next review becomes available.

Please see our Fair Usage Limits Policy for further information.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 7b1bd167-c740-4710-8872-7301c29e27a4

📥 Commits

Reviewing files that changed from the base of the PR and between 22538de and 621782f.

📒 Files selected for processing (13)
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/.gitignore
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/README.md
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/aic/paged_attention_highperf.cpp
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/bench_pa_performance.py
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/compile.sh
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/kernel/pa_entry.cce
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/kernel/pa_kernel.cce
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/orchestration/paged_attention_highperf_orch.cpp
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/pa_accuracy.py
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/pa_tiling.py
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/paged_attention_wrapper.cpp
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernels/tiling/pa_tiling_struct.h
  • tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/test_spmd_paged_attention_highperf.py

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a high-performance SPMD paged attention implementation, including C++ kernels, tiling logic, benchmarks, and correctness tests, alongside compiler updates to dynamically include CANN directories. The review feedback recommends adding defensive input validation in the tiling and kernel code to prevent division-by-zero and null pointer dereferences, checking environment variables in the compilation script, and adding future annotations in the test script.

Returns:
(tiling_tensor, effective_block_dim)
"""
kv_real = kv_heads if kv_heads > 0 else num_heads
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Always perform defensive validation on input parameters before deriving other dependent variables from them to prevent potential division-by-zero or out-of-bounds errors (e.g., if batch, block_dim, num_heads, or block_size are zero or negative).

Suggested change
kv_real = kv_heads if kv_heads > 0 else num_heads
if batch <= 0 or block_dim <= 0 or num_heads <= 0 or block_size <= 0 or head_dim <= 0 or head_dim_v <= 0:
raise ValueError("Input dimensions (batch, block_dim, num_heads, block_size, head_dim, head_dim_v) must be strictly positive.")
kv_real = kv_heads if kv_heads > 0 else num_heads
References
  1. Always perform defensive validation and normalization/fixups on input parameters before deriving other dependent variables from them, even if the invalid input is theoretically unreachable in practice.

Comment on lines +83 to +92
const int batch = static_cast<int>(query_t->shapes[0]);
const int num_heads = static_cast<int>(query_t->shapes[1]);
const int head_dim = static_cast<int>(query_t->shapes[2]);
const int block_size = static_cast<int>(key_t->shapes[1]);
const int num_kv_heads = static_cast<int>(key_t->shapes[2]);
const int blocks_per_batch = static_cast<int>(key_t->shapes[0]) / batch;
const int max_blocks_per_query = static_cast<int>(block_table_t->shapes[1]);
const int heads_per_kv = num_heads / num_kv_heads;
const int seq_len = blocks_per_batch * block_size;
const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Defensively validate input dimensions to prevent potential division-by-zero or undefined behavior in CPU simulation if batch, num_kv_heads, head_dim, or block_size are zero or negative.

    const int batch = static_cast<int>(query_t->shapes[0]);
    const int num_heads = static_cast<int>(query_t->shapes[1]);
    const int head_dim = static_cast<int>(query_t->shapes[2]);
    const int block_size = static_cast<int>(key_t->shapes[1]);
    const int num_kv_heads = static_cast<int>(key_t->shapes[2]);
    if (batch <= 0 || num_kv_heads <= 0 || head_dim <= 0 || block_size <= 0) {
        return;
    }
    const int blocks_per_batch = static_cast<int>(key_t->shapes[0]) / batch;
    const int max_blocks_per_query = static_cast<int>(block_table_t->shapes[1]);
    const int heads_per_kv = num_heads / num_kv_heads;
    const int seq_len = blocks_per_batch * block_size;
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
References
  1. Always perform defensive validation and normalization/fixups on input parameters before deriving other dependent variables from them, even if the invalid input is theoretically unreachable in practice.

Comment on lines +148 to +151
static __aicore__ __attribute__((always_inline)) __gm__ uint8_t *tensor_data(__gm__ int64_t *args, int idx) {
__gm__ Tensor *tensor = reinterpret_cast<__gm__ Tensor *>(args[idx]);
return reinterpret_cast<__gm__ uint8_t *>(tensor->buffer.addr);
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Defensively check if tensor is null before dereferencing tensor->buffer.addr to prevent potential null pointer dereferences.

static __aicore__ __attribute__((always_inline)) __gm__ uint8_t *tensor_data(__gm__ int64_t *args, int idx) {
    __gm__ Tensor *tensor = reinterpret_cast<__gm__ Tensor *>(args[idx]);
    if (tensor == nullptr) {
        return nullptr;
    }
    return reinterpret_cast<__gm__ uint8_t *>(tensor->buffer.addr);
}

Comment on lines +1 to +5
#!/usr/bin/env bash
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Defensively check if ASCEND_TOOLKIT_HOME is set before executing bisheng to provide a clear and actionable error message.

Suggested change
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
if [ -z "${ASCEND_TOOLKIT_HOME:-}" ]; then
echo "Error: ASCEND_TOOLKIT_HOME environment variable is not set." >&2
exit 1
fi

Comment on lines +1 to +3
#!/usr/bin/env python3
"""High-performance SPMD paged attention using the scene-test calling interface."""

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Include from __future__ import annotations at the top of the file to enable PEP 585 generic collections (like tuple[...]) and avoid runtime errors on Python versions earlier than 3.10.

Suggested change
#!/usr/bin/env python3
"""High-performance SPMD paged attention using the scene-test calling interface."""
#!/usr/bin/env python3
"""High-performance SPMD paged attention using the scene-test calling interface."""
from __future__ import annotations
References
  1. In projects targeting Python versions earlier than 3.10 (such as Python 3.9), include 'from future import annotations' to enable the use of PEP 604 union type hints (e.g., 'int | None') and avoid runtime errors.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant