refactor(dsv4 decode rope): switch to gather/scatter from matmul-reassemble#425
refactor(dsv4 decode rope): switch to gather/scatter from matmul-reassemble#425wangqin1723-max wants to merge 1 commit into
Conversation
📝 WalkthroughWalkthroughThis PR refactors RoPE rotation stages in DeepSeek QKV projection and sparse attention layers, replacing even/odd deinterleave–matmul–reinterleave pipelines with unified gather/rotate/scatter operations. New tiling constants control per-T-chunk processing; Q and KV paths now use gather indices and scatter buffers with fused normalization; sparse attention inverse RoPE merges rotated lanes into a single buffer and feeds that directly to output packing. ChangesRoPE Rotation Refactoring
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request refactors the RoPE (Rotary Position Embedding) rotation and inverse rotation logic in DeepSeek v4's decode kernels (decode_qkv_proj_rope.py and decode_sparse_attn.py). It replaces the previous multi-stage, matmul-based reassembly of interleaved even/odd lanes with a more efficient fused gather-rotate-scatter approach using hardware-native gather masks and scatter operations. Feedback focuses on performance optimizations to avoid redundant tensor allocations and initializations (pl.full) inside the inner loops by allocating the temporary buffers (q_rope_buf, kv_rope_buf, and r_buf) once outside the loops.
| q_rope_buf = pl.full([Q_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32, value=0.0) | ||
| q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=even_idx_full, src=q_rot_even) | ||
| q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=odd_idx_full, src=q_rot_odd) |
There was a problem hiding this comment.
Performance Optimization: Allocating and initializing q_rope_buf with pl.full inside the tg_idx loop introduces unnecessary overhead in the kernel's inner loop. Since the subsequent scatter operations fully overwrite the entire ROPE_DIM (as even_idx_full and odd_idx_full together cover all indices), we do not need to initialize the buffer with 0.0 on every iteration.
We can optimize this by allocating q_rope_buf once outside the loop (e.g., right after odd_idx_full on line 208) using q_rope_buf = pl.create_tensor([Q_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32) and reusing it.
| q_rope_buf = pl.full([Q_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32, value=0.0) | |
| q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=even_idx_full, src=q_rot_even) | |
| q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=odd_idx_full, src=q_rot_odd) | |
| q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=even_idx_full, src=q_rot_even) | |
| q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=odd_idx_full, src=q_rot_odd) |
| kv_rope_buf = pl.full([KV_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32, value=0.0) | ||
| kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=even_idx_full, src=kv_rot_even) | ||
| kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=odd_idx_full, src=kv_rot_odd) |
There was a problem hiding this comment.
Performance Optimization: Allocating and initializing kv_rope_buf with pl.full inside the tg_idx loop introduces unnecessary overhead in the kernel's inner loop. Since the subsequent scatter operations fully overwrite the entire ROPE_DIM (as even_idx_full and odd_idx_full together cover all indices), we do not need to initialize the buffer with 0.0 on every iteration.
We can optimize this by allocating kv_rope_buf once outside the loop (e.g., right after odd_idx_full on line 290) using kv_rope_buf = pl.create_tensor([KV_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32) and reusing it.
| kv_rope_buf = pl.full([KV_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32, value=0.0) | |
| kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=even_idx_full, src=kv_rot_even) | |
| kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=odd_idx_full, src=kv_rot_odd) | |
| kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=even_idx_full, src=kv_rot_even) | |
| kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=odd_idx_full, src=kv_rot_odd) |
| r_buf = pl.full([H, ROPE_INTERLEAVE_TILE], dtype=pl.FP32, value=0.0) | ||
| r_buf = pl.tensor.scatter(r_buf, dim=-1, index=even_idx_full, src=r_even_rot) | ||
| r_buf = pl.tensor.scatter(r_buf, dim=-1, index=odd_idx_full, src=r_odd_rot) |
There was a problem hiding this comment.
Performance Optimization: Allocating and initializing r_buf with pl.full inside the nested loops introduces unnecessary overhead in the kernel's inner loop. Since the subsequent scatter operations fully overwrite the entire ROPE_INTERLEAVE_TILE (as even_idx_full and odd_idx_full together cover all indices), we do not need to initialize the buffer with 0.0 on every iteration.
We can optimize this by allocating r_buf once outside the loops (e.g., right after odd_idx_full on line 240) using r_buf = pl.create_tensor([H, ROPE_INTERLEAVE_TILE], dtype=pl.FP32) and reusing it.
| r_buf = pl.full([H, ROPE_INTERLEAVE_TILE], dtype=pl.FP32, value=0.0) | |
| r_buf = pl.tensor.scatter(r_buf, dim=-1, index=even_idx_full, src=r_even_rot) | |
| r_buf = pl.tensor.scatter(r_buf, dim=-1, index=odd_idx_full, src=r_odd_rot) | |
| r_buf = pl.tensor.scatter(r_buf, dim=-1, index=even_idx_full, src=r_even_rot) | |
| r_buf = pl.tensor.scatter(r_buf, dim=-1, index=odd_idx_full, src=r_odd_rot) |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
models/deepseek/v4/decode_sparse_attn.py (1)
87-88: 💤 Low valueUnused
even_select_local/odd_select_localinsparse_attn
Inmodels/deepseek/v4/decode_sparse_attn.py,sparse_attnstill takeseven_select_localandodd_select_local, but the function body never references them (inverse RoPE uses inlinepl.tensor.gather(..., mask_pattern=pl.tile.MaskPattern.P0101/P1010)instead). They’re only defined/initialized and passed viasparse_attn_test/TensorSpec, so consider removing them in a follow-up if API compatibility doesn’t require keeping them.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/deepseek/v4/decode_sparse_attn.py` around lines 87 - 88, The parameters even_select_local and odd_select_local are unused in sparse_attn; remove them from the sparse_attn function signature and from any TensorSpec/initialization and calls (e.g., sparse_attn_test) so they are no longer defined or passed; if API compatibility requires keeping the names, instead forward a clear TODO comment and ensure the function actually uses them (or document why unused). Update all callers and test fixtures that constructed/expected these tensors (TensorSpec references) to stop providing them.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@models/deepseek/v4/decode_sparse_attn.py`:
- Around line 87-88: The parameters even_select_local and odd_select_local are
unused in sparse_attn; remove them from the sparse_attn function signature and
from any TensorSpec/initialization and calls (e.g., sparse_attn_test) so they
are no longer defined or passed; if API compatibility requires keeping the
names, instead forward a clear TODO comment and ensure the function actually
uses them (or document why unused). Update all callers and test fixtures that
constructed/expected these tensors (TensorSpec references) to stop providing
them.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: d74a17be-823e-4842-bda9-cba5bcbef188
📒 Files selected for processing (2)
models/deepseek/v4/decode_qkv_proj_rope.pymodels/deepseek/v4/decode_sparse_attn.py
60a4cb3 to
70e4ff6
Compare
…semble - decode_qkv_proj_rope.py: Q/KV rope use gather mask-form + scatter index-form. Q has FP32 stage + separate writeback scope to avoid the in-place GM-view write that trips intermittent AICPU 507046. T-chunked to 32 to fit 192 KB Vec UB at T=128. Aligns with compressor ratio4/128. - decode_sparse_attn.py: rope de-interleave uses gather mask-form (replaces matmul + even/odd select). Re-interleave KEEPS matmul: scatter index-form caused ~11% attn_out precision FAIL vs baseline, even with an explicit BF16 round-trip on the rotated values. ratio128's gather+scatter is safe because its golden compares FP32; sparse_attn's golden ends in BF16 cast (cumulative scatter-vs-matmul rounding delta is amplified). Keep matmul re-interleave for now. Mask-form scatter (HW native TSCATTER<mask>) is preferred over index-form on qkv_proj_rope but blocked by ptoas tscatter parser; preserved in `git stash` for later restore.
Summary
decode_qkv_proj_rope.py: Q/KV rope use gather mask-form + scatter index-form (replaces cube matmul reassemble +even_select_t/odd_select_tone-hot matrices). Q has FP32 stage + separate writeback scope to avoid the in-place GM-view write that trips intermittent AICPU 507046. T-chunked at 32 to fit 192 KB Vec UB at T=128.decode_sparse_attn.py: rope de-interleave uses gather mask-form (replaces matmul + even/odd select tiles). Rope re-interleave keeps matmul — scatter index-form caused ~11%attn_outprecision FAIL vs baseline (even with an explicit BF16 round-trip on the rotated values). ratio128's gather+scatter pattern is safe because its golden compares FP32; sparse_attn's golden ends in a BF16 cast which amplifies the scatter-vs-matmul rounding delta. Will revisit when scatter mask-form becomes available end-to-end.Known issues / follow-ups
TSCATTER<mask>) is preferred over index-form ondecode_qkv_proj_rope.py(drops INT32 index tiles + col_expand; lowers AICPU dispatch load) but blocked by ptoas parser:pto.tscatteronly parses 2-operandins(%src, %index)form, not the mask-formins(%src) {maskPattern=...}. Mask-form preserved locally viagit stash; switch over after ptoas is bumped.attn_outat the rotated-value rounding boundary. Investigatedpl.tensor.scatterlowering (multi-step tscatter + tcmps + tsel) — the precision drift isn't reproduced by ratio128's identical pattern, which suggests the BF16-cast downstream in sparse_attn (vs ratio128's FP32-only golden) amplifies a sub-ulp scatter rounding issue. Kept matmul re-interleave; further investigation tracked separately.Test plan
decode_qkv_proj_rope.py -p a2a3 -d <id>— precision PASS (q/kv/qr/qr_scale), 1/3 PASS rate due to [Bug] decode_compressor_ratio4 rmsnorm_rope cannot be converted to pl.spmd (wrong values or AICPU stall) #419decode_sparse_attn.py -p a2a3 -d <id>—attn_outPASS 3/3git stashand re-verify