Skip to content

refactor(dsv4 decode rope): switch to gather/scatter from matmul-reassemble#425

Open
wangqin1723-max wants to merge 1 commit into
hw-native-sys:mainfrom
wangqin1723-max:perf/decode-rope-tweaks
Open

refactor(dsv4 decode rope): switch to gather/scatter from matmul-reassemble#425
wangqin1723-max wants to merge 1 commit into
hw-native-sys:mainfrom
wangqin1723-max:perf/decode-rope-tweaks

Conversation

@wangqin1723-max
Copy link
Copy Markdown
Contributor

@wangqin1723-max wangqin1723-max commented May 30, 2026

Summary

  • decode_qkv_proj_rope.py: Q/KV rope use gather mask-form + scatter index-form (replaces cube matmul reassemble + even_select_t/odd_select_t one-hot matrices). Q has FP32 stage + separate writeback scope to avoid the in-place GM-view write that trips intermittent AICPU 507046. T-chunked at 32 to fit 192 KB Vec UB at T=128.
  • decode_sparse_attn.py: rope de-interleave uses gather mask-form (replaces matmul + even/odd select tiles). Rope re-interleave keeps matmul — scatter index-form caused ~11% attn_out precision FAIL vs baseline (even with an explicit BF16 round-trip on the rotated values). ratio128's gather+scatter pattern is safe because its golden compares FP32; sparse_attn's golden ends in a BF16 cast which amplifies the scatter-vs-matmul rounding delta. Will revisit when scatter mask-form becomes available end-to-end.
  • Aligns the de-interleave path of both files with the compressor ratio4/128 gather pattern.

Known issues / follow-ups

  • Mask-form scatter (HW native TSCATTER<mask>) is preferred over index-form on decode_qkv_proj_rope.py (drops INT32 index tiles + col_expand; lowers AICPU dispatch load) but blocked by ptoas parser: pto.tscatter only parses 2-operand ins(%src, %index) form, not the mask-form ins(%src) {maskPattern=...}. Mask-form preserved locally via git stash; switch over after ptoas is bumped.
  • 507046 (AICPU stream sync timeout, issue [Bug] decode_compressor_ratio4 rmsnorm_rope cannot be converted to pl.spmd (wrong values or AICPU stall) #419) fires intermittently on the qkv_proj_rope scatter pattern (~1/3 PASS observed). Same root cause as the indexer / compressor scatter notes — not a regression introduced here. Re-run until pass.
  • sparse_attn rope re-interleave: scatter index-form yields a ~11% mismatch on attn_out at the rotated-value rounding boundary. Investigated pl.tensor.scatter lowering (multi-step tscatter + tcmps + tsel) — the precision drift isn't reproduced by ratio128's identical pattern, which suggests the BF16-cast downstream in sparse_attn (vs ratio128's FP32-only golden) amplifies a sub-ulp scatter rounding issue. Kept matmul re-interleave; further investigation tracked separately.

Test plan

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 30, 2026

Review Change Stack

📝 Walkthrough

Walkthrough

This PR refactors RoPE rotation stages in DeepSeek QKV projection and sparse attention layers, replacing even/odd deinterleave–matmul–reinterleave pipelines with unified gather/rotate/scatter operations. New tiling constants control per-T-chunk processing; Q and KV paths now use gather indices and scatter buffers with fused normalization; sparse attention inverse RoPE merges rotated lanes into a single buffer and feeds that directly to output packing.

Changes

RoPE Rotation Refactoring

Layer / File(s) Summary
RoPE tiling constants and UB sizing
models/deepseek/v4/decode_qkv_proj_rope.py
Added Q_ROPE_T_TILE and KV_ROPE_T_TILE constants to control per-T-chunk buffering and keep working set within vector upper-bound constraints.
Q-path RoPE gather-rotate-scatter replacement
models/deepseek/v4/decode_qkv_proj_rope.py
Replaced pair-stage matmul reassembly with fused gather-rotate-scatter: extracts even/odd components via masks, applies per-chunk FP32 cos/sin rotation, scatters into q_rope_stage_fp32, and writes rotated ROPE slice back to q_flat in BF16.
KV-path RoPE CORE_GROUP gather-rotate-scatter fusion
models/deepseek/v4/decode_qkv_proj_rope.py
Replaced temporary-buffer reassembly with fused CORE_GROUP loop over T-chunks performing gather-rotate-scatter with inline gamma-scaled FP32 normalization and direct BF16 writeback to the kv ROPE slice.
Sparse attention inverse RoPE and output integration
models/deepseek/v4/decode_sparse_attn.py
Replaced inverse RoPE even/odd intermediate buffers with unified gather/rotate/scatter using mask-based lane extraction and per-row INT32 scatter indices into merged rope_full_buf; updated grouped packing to directly consume rotated ROPE from merged buffer.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#418: Both PRs rework RoPE lane reassembly to a gather/rotate/scatter-based flow (main PR in decode_qkv_proj_rope.py/decode_sparse_attn.py, retrieved PR in decode_compressor_ratio4.py), so the changes are tightly connected to the same RoPE transformation logic.

Poem

🐰 The old pairs have scattered far,
With gather-light and rotate's art,
Merged buffers glow FP32-bright,
Where even and odd dance as one,
RoPE flows smooth—the work is done! ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and concisely describes the main refactoring: switching from matmul-reassemble approach to gather/scatter patterns in DeepSeek v4 decode RoPE implementation, which aligns with the primary changes across both modified files.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description check ✅ Passed The PR description comprehensively explains the refactoring from matmul-reassemble to gather/scatter patterns in both files, detailing technical changes, known issues, and test plans.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the RoPE (Rotary Position Embedding) rotation and inverse rotation logic in DeepSeek v4's decode kernels (decode_qkv_proj_rope.py and decode_sparse_attn.py). It replaces the previous multi-stage, matmul-based reassembly of interleaved even/odd lanes with a more efficient fused gather-rotate-scatter approach using hardware-native gather masks and scatter operations. Feedback focuses on performance optimizations to avoid redundant tensor allocations and initializations (pl.full) inside the inner loops by allocating the temporary buffers (q_rope_buf, kv_rope_buf, and r_buf) once outside the loops.

Comment on lines +225 to +227
q_rope_buf = pl.full([Q_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32, value=0.0)
q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=even_idx_full, src=q_rot_even)
q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=odd_idx_full, src=q_rot_odd)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Performance Optimization: Allocating and initializing q_rope_buf with pl.full inside the tg_idx loop introduces unnecessary overhead in the kernel's inner loop. Since the subsequent scatter operations fully overwrite the entire ROPE_DIM (as even_idx_full and odd_idx_full together cover all indices), we do not need to initialize the buffer with 0.0 on every iteration.

We can optimize this by allocating q_rope_buf once outside the loop (e.g., right after odd_idx_full on line 208) using q_rope_buf = pl.create_tensor([Q_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32) and reusing it.

Suggested change
q_rope_buf = pl.full([Q_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32, value=0.0)
q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=even_idx_full, src=q_rot_even)
q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=odd_idx_full, src=q_rot_odd)
q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=even_idx_full, src=q_rot_even)
q_rope_buf = pl.tensor.scatter(q_rope_buf, dim=-1, index=odd_idx_full, src=q_rot_odd)

Comment on lines +308 to +310
kv_rope_buf = pl.full([KV_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32, value=0.0)
kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=even_idx_full, src=kv_rot_even)
kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=odd_idx_full, src=kv_rot_odd)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Performance Optimization: Allocating and initializing kv_rope_buf with pl.full inside the tg_idx loop introduces unnecessary overhead in the kernel's inner loop. Since the subsequent scatter operations fully overwrite the entire ROPE_DIM (as even_idx_full and odd_idx_full together cover all indices), we do not need to initialize the buffer with 0.0 on every iteration.

We can optimize this by allocating kv_rope_buf once outside the loop (e.g., right after odd_idx_full on line 290) using kv_rope_buf = pl.create_tensor([KV_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32) and reusing it.

Suggested change
kv_rope_buf = pl.full([KV_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32, value=0.0)
kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=even_idx_full, src=kv_rot_even)
kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=odd_idx_full, src=kv_rot_odd)
kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=even_idx_full, src=kv_rot_even)
kv_rope_buf = pl.tensor.scatter(kv_rope_buf, dim=-1, index=odd_idx_full, src=kv_rot_odd)

Comment on lines +257 to +259
r_buf = pl.full([H, ROPE_INTERLEAVE_TILE], dtype=pl.FP32, value=0.0)
r_buf = pl.tensor.scatter(r_buf, dim=-1, index=even_idx_full, src=r_even_rot)
r_buf = pl.tensor.scatter(r_buf, dim=-1, index=odd_idx_full, src=r_odd_rot)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Performance Optimization: Allocating and initializing r_buf with pl.full inside the nested loops introduces unnecessary overhead in the kernel's inner loop. Since the subsequent scatter operations fully overwrite the entire ROPE_INTERLEAVE_TILE (as even_idx_full and odd_idx_full together cover all indices), we do not need to initialize the buffer with 0.0 on every iteration.

We can optimize this by allocating r_buf once outside the loops (e.g., right after odd_idx_full on line 240) using r_buf = pl.create_tensor([H, ROPE_INTERLEAVE_TILE], dtype=pl.FP32) and reusing it.

Suggested change
r_buf = pl.full([H, ROPE_INTERLEAVE_TILE], dtype=pl.FP32, value=0.0)
r_buf = pl.tensor.scatter(r_buf, dim=-1, index=even_idx_full, src=r_even_rot)
r_buf = pl.tensor.scatter(r_buf, dim=-1, index=odd_idx_full, src=r_odd_rot)
r_buf = pl.tensor.scatter(r_buf, dim=-1, index=even_idx_full, src=r_even_rot)
r_buf = pl.tensor.scatter(r_buf, dim=-1, index=odd_idx_full, src=r_odd_rot)

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
models/deepseek/v4/decode_sparse_attn.py (1)

87-88: 💤 Low value

Unused even_select_local / odd_select_local in sparse_attn
In models/deepseek/v4/decode_sparse_attn.py, sparse_attn still takes even_select_local and odd_select_local, but the function body never references them (inverse RoPE uses inline pl.tensor.gather(..., mask_pattern=pl.tile.MaskPattern.P0101/P1010) instead). They’re only defined/initialized and passed via sparse_attn_test/TensorSpec, so consider removing them in a follow-up if API compatibility doesn’t require keeping them.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/decode_sparse_attn.py` around lines 87 - 88, The
parameters even_select_local and odd_select_local are unused in sparse_attn;
remove them from the sparse_attn function signature and from any
TensorSpec/initialization and calls (e.g., sparse_attn_test) so they are no
longer defined or passed; if API compatibility requires keeping the names,
instead forward a clear TODO comment and ensure the function actually uses them
(or document why unused). Update all callers and test fixtures that
constructed/expected these tensors (TensorSpec references) to stop providing
them.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@models/deepseek/v4/decode_sparse_attn.py`:
- Around line 87-88: The parameters even_select_local and odd_select_local are
unused in sparse_attn; remove them from the sparse_attn function signature and
from any TensorSpec/initialization and calls (e.g., sparse_attn_test) so they
are no longer defined or passed; if API compatibility requires keeping the
names, instead forward a clear TODO comment and ensure the function actually
uses them (or document why unused). Update all callers and test fixtures that
constructed/expected these tensors (TensorSpec references) to stop providing
them.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: d74a17be-823e-4842-bda9-cba5bcbef188

📥 Commits

Reviewing files that changed from the base of the PR and between 781c699 and 60a4cb3.

📒 Files selected for processing (2)
  • models/deepseek/v4/decode_qkv_proj_rope.py
  • models/deepseek/v4/decode_sparse_attn.py

@wangqin1723-max wangqin1723-max force-pushed the perf/decode-rope-tweaks branch from 60a4cb3 to 70e4ff6 Compare May 30, 2026 10:03
…semble

- decode_qkv_proj_rope.py: Q/KV rope use gather mask-form + scatter
  index-form. Q has FP32 stage + separate writeback scope to avoid the
  in-place GM-view write that trips intermittent AICPU 507046. T-chunked
  to 32 to fit 192 KB Vec UB at T=128. Aligns with compressor ratio4/128.
- decode_sparse_attn.py: rope de-interleave uses gather mask-form
  (replaces matmul + even/odd select). Re-interleave KEEPS matmul:
  scatter index-form caused ~11% attn_out precision FAIL vs baseline,
  even with an explicit BF16 round-trip on the rotated values. ratio128's
  gather+scatter is safe because its golden compares FP32; sparse_attn's
  golden ends in BF16 cast (cumulative scatter-vs-matmul rounding delta
  is amplified). Keep matmul re-interleave for now.

Mask-form scatter (HW native TSCATTER<mask>) is preferred over index-form
on qkv_proj_rope but blocked by ptoas tscatter parser; preserved in
`git stash` for later restore.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant