Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}

SUPPORTED_PAGE_SIZE = [128, 256, 1024]
SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024]
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
KV_MEMORY_LAYOUT_ENUM_MAP = {
Expand Down Expand Up @@ -737,6 +737,8 @@ def get_fwd_blobs(

# Generate kernels for every entry in SUPPORTED_PAGE_SIZE
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ template <typename OffsetVecType,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
bool kIsKcache,
index_t kVectorSize>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
const index_t& stride_kv,
const index_t& page_stride_kv,
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
const index_t& stride_token,
const index_t& stride_page_block,
const CoordVecType& coord_vec,
OffsetVecType& kv_offset_vec,
index_t global_seq_offset = 0)
Expand All @@ -39,47 +39,70 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t page_id = global_token_idx >> kLog2PageSize;
const index_t page_offset = global_token_idx & kInPageOffsetMask;
kv_offset_vec[k0] = static_cast<long_index_t>(page_vec[page_id]) * page_stride_kv +
static_cast<long_index_t>(page_offset) * stride_kv;
const index_t page_id = global_token_idx >> kLog2PageSize;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
kv_offset_vec[k0] = static_cast<long_index_t>(page_idx[page_id]) * stride_page_block +
static_cast<long_index_t>(token_idx_in_page) * stride_token;
});
}
else
{
// for v offsets
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
const index_t lane0_page_id =
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;

const long_index_t page_loc =
static_cast<long_index_t>(page_vec[lane0_page_id]) * page_stride_kv;
if constexpr(kLog2PageSize == 0 &&
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
{
// page size = 1, per-token page lookup.
// Here page_idx maps token_idx -> physical_page_id, so global_seq_offset must be
// the absolute token index within the batch's kv_page_indices slice.
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;

static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t page_offset =
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
kInPageOffsetMask;
const long_index_t page_base_offset =
static_cast<long_index_t>(page_idx[global_token_idx]) * stride_page_block;

if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// Vectorized layout offset
// Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
// Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize)
const index_t s = page_offset;
const index_t D = stride_kv;
kv_offset_vec[k0] = page_base_offset;
});
}
else
{
// This path handles page_size > 1 and/or non-linear KV layout, where page_idx is
// indexed by page_id (token_idx >> log2_page_size) with an in-page offset.
// Assumes the V tile stays within a single page so lane0 can broadcast the page id.
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
const index_t lane0_page_id =
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;

const long_index_t page_base_offset =
static_cast<long_index_t>(page_idx[lane0_page_id]) * stride_page_block;

static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t token_idx_in_page =
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
kInPageOffsetMask;

if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// Vectorized layout offset
// Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
// Offset = (token_idx_in_page / kVectorSize) * (HeadDim * kVectorSize) +
// (token_idx_in_page % kVectorSize)

const long_index_t s_offset =
static_cast<long_index_t>((s / kVectorSize) * (D * kVectorSize)) +
(s % kVectorSize);
const long_index_t token_offset =
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
(stride_token * kVectorSize)) +
(token_idx_in_page % kVectorSize);

kv_offset_vec[k0] = page_loc + s_offset;
}
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
{
kv_offset_vec[k0] = page_loc + static_cast<long_index_t>(page_offset) * stride_kv;
}
});
kv_offset_vec[k0] = page_base_offset + token_offset;
}
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
{
kv_offset_vec[k0] = page_base_offset +
static_cast<long_index_t>(token_idx_in_page) * stride_token;
}
});
}
}
}

Expand Down Expand Up @@ -127,9 +150,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr auto I3 = number<3>{};

static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(kPageBlockSize % kN0 == 0,
"V offset assumes each tile stays within a page; kPageBlockSize must be "
"divisible by kN0.");
static_assert(kPageBlockSize % kN0 == 0 || kLog2PageSize == 0,
"Page size must be 1, or a multiple of the tile size (kN0).");

static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
Expand Down