Skip to content

Commit ac19252

Browse files
committed
fix: n seqs and random length test pass
1 parent 96262fb commit ac19252

File tree

2 files changed

+35
-38
lines changed

2 files changed

+35
-38
lines changed

src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,30 +62,32 @@ __global__ void pagedAttentionPrefillKernel(
6262

6363
const int32_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
6464

65-
// Q ptr: [seq, new_len, head, dim]
65+
// 假设 q_stride 传入的是单个 Sequence 在内存中占据的 Tdata 数量 (即 max_new_len * num_heads * head_size)
6666
const Tdata *q_ptr_base = q_ + seq_idx * q_stride +
67-
(q_token_idx * num_heads + head_idx) * head_size_const;
67+
q_token_idx * (num_heads * head_size_const) +
68+
head_idx * head_size_const;
6869

69-
// Out ptr
70+
// --- 2. 修改 Out 的基地址计算 ---
7071
Tdata *out_ptr = out_ + seq_idx * o_stride +
71-
(q_token_idx * num_heads + head_idx) * head_size_const;
72+
q_token_idx * (num_heads * head_size_const) +
73+
head_idx * head_size_const;
7274

7375
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
7476

75-
// 只让第一个 Sequence, 第一个 Token, 第一个 Head 的第一个线程执行打印
76-
if (seq_idx == 0 && q_token_idx == 0 && head_idx == 0 && dim_idx == 0) {
77-
printf("DEBUG: Scale=%f, HeadSize=%zu, BlockSize=%zu\n", scale, head_size_const, block_size);
77+
// // 只让第一个 Sequence, 第一个 Token, 第一个 Head 的第一个线程执行打印
78+
// if (seq_idx == 0 && q_token_idx == 0 && head_idx == 0 && dim_idx == 0) {
79+
// printf("DEBUG: Scale=%f, HeadSize=%zu, BlockSize=%zu\n", scale, head_size_const, block_size);
7880

79-
// 检查 Q 的前 5 个元素
80-
for(int i=0; i<5; ++i) printf("Q[%d]=%f ", i, (float)q_ptr_base[i]);
81-
printf("\n");
82-
83-
// 检查第一个 KV Block 的前 5 个元素
84-
const int32_t first_physical_block = block_table[0];
85-
const Tdata *first_k = k_cache_ + first_physical_block * kv_block_stride;
86-
for(int i=0; i<5; ++i) printf("K_cache[0][%d]=%f ", i, (float)first_k[i]);
87-
printf("\n");
88-
}
81+
// // 检查 Q 的前 5 个元素
82+
// for(int i=0; i<5; ++i) printf("Q[%d]=%f ", i, (float)q_ptr_base[i]);
83+
// printf("\n");
84+
85+
// // 检查第一个 KV Block 的前 5 个元素
86+
// const int32_t first_physical_block = block_table[0];
87+
// const Tdata *first_k = k_cache_ + first_physical_block * kv_block_stride;
88+
// for(int i=0; i<5; ++i) printf("K_cache[0][%d]=%f ", i, (float)first_k[i]);
89+
// printf("\n");
90+
// }
8991

9092
// --- Pass 1: Find Global Max ---
9193
Tcompute max_score = -FLT_MAX;

test/infiniop/paged_attention_prefill.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test(
119119
print(f"--- Round {r+1} ---")
120120

121121
# 1. 模拟调度与物理写入
122-
new_lens_torch = torch.randint(max_step_len, max_step_len + 1, (num_seqs,), dtype=torch.int32)
122+
new_lens_torch = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int32)
123123
total_lens_list = []
124124
all_block_tables = []
125125

@@ -138,9 +138,9 @@ def test(
138138
v_new = torch.randn(cur_new_len, num_kv_heads, head_size)
139139
q_val = torch.randn(cur_new_len, num_heads, head_size)
140140

141-
k_new = torch.ones_like(k_new)
142-
v_new = torch.ones_like(v_new)
143-
q_val = torch.ones_like(q_val)
141+
# k_new = torch.ones_like(k_new)
142+
# v_new = torch.ones_like(v_new)
143+
# q_val = torch.ones_like(q_val)
144144

145145
q_new_torch[i, :cur_new_len, :, :] = q_val
146146

@@ -155,9 +155,14 @@ def test(
155155
k_cache._data_tensor.copy_(k_cache._torch_tensor)
156156
v_cache._data_tensor.copy_(v_cache._torch_tensor)
157157

158-
# 2. 准备算子 Tensor
158+
# 2. 准备 Q Tensor
159159
q_new = TestTensor.from_torch(q_new_torch, dtype, device)
160+
161+
# 3. 准备 out Tensor,确保初始值为 0
160162
out = TestTensor((num_seqs, max_new_len, num_heads, head_size), None, dtype, device)
163+
out.torch_tensor().zero_()
164+
out._data_tensor.zero_()
165+
161166
seq_lens = TestTensor.from_torch(torch.tensor(total_lens_list, dtype=torch.int32), InfiniDtype.I32, device)
162167

163168
max_blocks = max(len(t) for t in all_block_tables)
@@ -224,19 +229,8 @@ def test(
224229
# 5. 验证
225230
# ======================================================================
226231

227-
print(f"[debug] ans: {ans[:, 0, 0, :5]}")
228-
print(f"[debug] out: {out.actual_tensor()[:, 0, 0, :5]}")
229-
230-
diff = ans - out.actual_tensor()
231-
print(f"[debug] diff-shape: {diff.shape}")
232-
print(f"[debug] diff: {diff}")
233-
234-
print(f"[debug] max ans: {torch.max(ans)}")
235-
print(f"[debug] min ans: {torch.min(ans)}")
236-
237-
print(f"[debug] max out.actual_tensor(): {torch.max(out.actual_tensor())}")
238-
print(f"[debug] min out.actual_tensor(): {torch.min(out.actual_tensor())}")
239-
232+
# print(f"[debug] ans: {ans[:, 0, 0, :5]}")
233+
# print(f"[debug] out: {out.actual_tensor()[:, 0, 0, :5]}")
240234

241235
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
242236
# compare out.actual_tensor() with reference result ans
@@ -251,9 +245,10 @@ def test(
251245
# ==============================================================================
252246
_TEST_CASES_ = [
253247
# (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len)
254-
# (2, 8, 8, 128, 16, 32),
255-
# (4, 16, 16, 64, 8, 64),
256-
(2, 1, 1, 128, 8, 16),
248+
(2, 8, 8, 128, 16, 32),
249+
(4, 16, 16, 128, 8, 64),
250+
(16, 1, 1, 128, 8, 16),
251+
(1, 1, 1, 128, 8, 16),
257252
]
258253

259254
_TENSOR_DTYPES = [InfiniDtype.F32]

0 commit comments

Comments (0)