Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}

SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024]
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
KV_MEMORY_LAYOUT_ENUM_MAP = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ template <typename OffsetVecType,
index_t kLoopStride,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
bool kIsKcache,
bool kIsVTileFitsInPage,
index_t kVectorSize>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
const index_t& stride_token,
Expand Down Expand Up @@ -64,7 +65,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
kv_offset_vec[k0] = page_base_offset;
});
}
else
else if constexpr(kIsVTileFitsInPage)
{
// This path handles page_size > 1 and/or non-linear KV layout, where page_idx is
// indexed by page_id (token_idx >> log2_page_size) with an in-page offset.
Expand All @@ -78,7 +79,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,

static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t token_idx_in_page =
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
(global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value) &
kInPageOffsetMask;

if constexpr(kKVMemoryLayout ==
Expand All @@ -103,6 +104,39 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
}
});
}
else
{
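// Page size > 1 and the K/V tile spans multiple pages (e.g., page_size < kN0).
// This branch is instantiated for both K and V offsets (selected by kIsVTileFitsInPage),
// so the page_id must be computed per token to avoid cross-page aliasing.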
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t page_id = global_token_idx >> kLog2PageSize;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
// page_idx is indexed by page_id for page_size > 1.

const long_index_t page_base_offset =
static_cast<long_index_t>(page_idx[page_id]) * stride_page_block;

if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// Vectorized layout uses a packed [token/kVectorSize, head_dim, kVectorSize]
// address pattern.
const long_index_t token_offset =
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
(stride_token * kVectorSize)) +
(token_idx_in_page % kVectorSize);

kv_offset_vec[k0] = page_base_offset + token_offset;
}
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
{
kv_offset_vec[k0] = page_base_offset +
static_cast<long_index_t>(token_idx_in_page) * stride_token;
}
});
}
}
}

Expand Down Expand Up @@ -144,15 +178,14 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
static constexpr index_t kLog2PageSize = Problem::kLog2PageSize;
static constexpr index_t kVectorSize = Problem::kVectorSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
// For page_size < kN0 we must use per-token page lookup instead.
static constexpr bool kIsVTileFitsInPage = (kLog2PageSize > 0) && (kPageBlockSize % kN0 == 0);
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};

static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(kPageBlockSize % kN0 == 0 || kLog2PageSize == 0,
"Page size must be 1, or a multiple of the tile size (kN0).");

static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
Expand Down Expand Up @@ -462,6 +495,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kN0 / NRepeat,
kKVMemoryLayout,
true,
kIsVTileFitsInPage,
kVectorSize>(
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);

Expand Down Expand Up @@ -507,6 +541,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
1,
kKVMemoryLayout,
false,
kIsVTileFitsInPage,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);

Expand Down Expand Up @@ -593,6 +628,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
1,
kKVMemoryLayout,
false,
kIsVTileFitsInPage,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
Expand Down Expand Up @@ -767,6 +803,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
1,
kKVMemoryLayout,
false,
kIsVTileFitsInPage,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
Expand Down Expand Up @@ -906,6 +943,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
1,
kKVMemoryLayout,
false,
kIsVTileFitsInPage,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
Expand Down Expand Up @@ -963,6 +1001,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kN0 / NRepeat,
kKVMemoryLayout,
true,
kIsVTileFitsInPage,
kVectorSize>(
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);
Expand Down