@@ -119,7 +119,7 @@ def test(
119119 print (f"--- Round { r + 1 } ---" )
120120
121121 # 1. 模拟调度与物理写入
122- new_lens_torch = torch .randint (max_step_len , max_step_len + 1 , (num_seqs ,), dtype = torch .int32 )
122+ new_lens_torch = torch .randint (1 , max_step_len + 1 , (num_seqs ,), dtype = torch .int32 )
123123 total_lens_list = []
124124 all_block_tables = []
125125
@@ -138,9 +138,9 @@ def test(
138138 v_new = torch .randn (cur_new_len , num_kv_heads , head_size )
139139 q_val = torch .randn (cur_new_len , num_heads , head_size )
140140
141- k_new = torch .ones_like (k_new )
142- v_new = torch .ones_like (v_new )
143- q_val = torch .ones_like (q_val )
141+ # k_new = torch.ones_like(k_new)
142+ # v_new = torch.ones_like(v_new)
143+ # q_val = torch.ones_like(q_val)
144144
145145 q_new_torch [i , :cur_new_len , :, :] = q_val
146146
@@ -155,9 +155,14 @@ def test(
155155 k_cache ._data_tensor .copy_ (k_cache ._torch_tensor )
156156 v_cache ._data_tensor .copy_ (v_cache ._torch_tensor )
157157
158- # 2. 准备算子 Tensor
158+ # 2. 准备 Q Tensor
159159 q_new = TestTensor .from_torch (q_new_torch , dtype , device )
160+
161+ # 3. 准备 out Tensor,确保初始值为 0
160162 out = TestTensor ((num_seqs , max_new_len , num_heads , head_size ), None , dtype , device )
163+ out .torch_tensor ().zero_ ()
164+ out ._data_tensor .zero_ ()
165+
161166 seq_lens = TestTensor .from_torch (torch .tensor (total_lens_list , dtype = torch .int32 ), InfiniDtype .I32 , device )
162167
163168 max_blocks = max (len (t ) for t in all_block_tables )
@@ -224,19 +229,8 @@ def test(
224229 # 5. 验证
225230 # ======================================================================
226231
227- print (f"[debug] ans: { ans [:, 0 , 0 , :5 ]} " )
228- print (f"[debug] out: { out .actual_tensor ()[:, 0 , 0 , :5 ]} " )
229-
230- diff = ans - out .actual_tensor ()
231- print (f"[debug] diff-shape: { diff .shape } " )
232- print (f"[debug] diff: { diff } " )
233-
234- print (f"[debug] max ans: { torch .max (ans )} " )
235- print (f"[debug] min ans: { torch .min (ans )} " )
236-
237- print (f"[debug] max out.actual_tensor(): { torch .max (out .actual_tensor ())} " )
238- print (f"[debug] min out.actual_tensor(): { torch .min (out .actual_tensor ())} " )
239-
232+ # print(f"[debug] ans: {ans[:, 0, 0, :5]}")
233+ # print(f"[debug] out: {out.actual_tensor()[:, 0, 0, :5]}")
240234
241235 atol , rtol = get_tolerance (_TOLERANCE_MAP , dtype )
242236 # compare out.actual_tensor() with reference result ans
@@ -251,9 +245,10 @@ def test(
251245# ==============================================================================
252246_TEST_CASES_ = [
253247 # (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len)
254- # (2, 8, 8, 128, 16, 32),
255- # (4, 16, 16, 64, 8, 64),
256- (2 , 1 , 1 , 128 , 8 , 16 ),
248+ (2 , 8 , 8 , 128 , 16 , 32 ),
249+ (4 , 16 , 16 , 128 , 8 , 64 ),
250+ (16 , 1 , 1 , 128 , 8 , 16 ),
251+ (1 , 1 , 1 , 128 , 8 , 16 ),
257252]
258253
259254_TENSOR_DTYPES = [InfiniDtype .F32 ]
0 commit comments