Refactor(dsv4): full dynamic-shape for compressor_ratio128 #410 #412

Open
bumble0918 wants to merge 1 commit into hw-native-sys:main from bumble0918:feature/2026-05-29

@bumble0918
Contributor

  • Replace static dims (B, S, BLOCK_NUM, block-table sizes) with pl.dynamic(); derive all runtime sizes via pl.tensor.dim()
  • Convert all compute scopes (kv_score_proj, state_scatter, softmax_pool, rmsnorm_rope, kv_finalize) from pl.spmd(static B) to pl.parallel(dynamic b_dim) + inner pl.spmd/pl.at
  • Add bind_dynamic() for all new DynVars in compressor_test

…sys#410

- Replace static dims (B, S, BLOCK_NUM, block-table sizes) with
  pl.dynamic(); derive all runtime sizes via pl.tensor.dim()
- Convert all compute scopes (kv_score_proj, state_scatter,
  softmax_pool, rmsnorm_rope, kv_finalize) from pl.spmd(static
  B) to pl.parallel(dynamic b_dim) + inner pl.spmd/pl.at
- Add bind_dynamic() for all new DynVars in compressor_test
@coderabbitai

coderabbitai Bot commented May 28, 2026

📝 Walkthrough

The decode_compressor_ratio128.py kernel is refactored to use PyPTO dynamic shape variables instead of static compile-time constants for batch, sequence, and paging dimensions, enabling the kernel to accept variable-shaped tensors at runtime. The test harness is updated with matching dynamic signatures and dimension bindings.

Changes

DeepSeek Compressor Dynamic Shape Migration

| Layer / File(s) | Summary |
|---|---|
| Dynamic Shape Dimension Contract (models/deepseek/v4/decode_compressor_ratio128.py) | Introduces dynamic-shape PyPTO symbols (B_DYN, S_DYN, COMPRESS_STATE_*_DYN, CMP_*_DYN) that parameterize the compressor and test signatures. |
| Compressor Kernel Initialization and State Scatter (models/deepseek/v4/decode_compressor_ratio128.py) | Updates compressor kernel signature to accept dynamic tensors, derives runtime dimensions (b_dim, s_dim, paging sizes), allocates scratch buffers using dynamic sizes, and rewrites state-scatter pre-phase to iterate over runtime extents instead of static constants. |
| Compressor Pooling, Normalization, Writeback, and Test Validation (models/deepseek/v4/decode_compressor_ratio128.py) | Reworks softmax pooling across state slots, fused RMSNorm + RoPE (including even/odd KV gathering and scatter), and cache writeback using dynamic iteration/indexing. Updates compressor_test signature to accept and bind all dynamic tensor dimensions via bind_dynamic calls. |
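
The even/odd gather and scatter in the RoPE step can be illustrated outside PyPTO. The following is a minimal NumPy sketch, not the kernel's code: it assumes the rotation arithmetic visible in the diff (`even * cos - odd * sin` and `even * sin + odd * cos`) and uses strided slicing in place of `pl.gather` and `pl.tensor.scatter`. The function name `rope_interleaved` and the test shapes are hypothetical.

```python
import numpy as np

def rope_interleaved(x, cos, sin):
    """Rotate interleaved (even, odd) pairs of x by per-pair angles.

    x:   [rows, rope_dim]       values laid out as e0, o0, e1, o1, ...
    cos: [rows, rope_dim // 2]  cosine for each pair
    sin: [rows, rope_dim // 2]  sine for each pair
    """
    even = x[:, 0::2]  # "gather" the even-position elements
    odd = x[:, 1::2]   # "gather" the odd-position elements
    out = np.empty_like(x)
    out[:, 0::2] = even * cos - odd * sin  # "scatter" rotated evens back
    out[:, 1::2] = even * sin + odd * cos  # "scatter" rotated odds back
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 64))        # hypothetical [rows, rope_dim]
theta = rng.standard_normal((8, 32))    # hypothetical rotation angles
y = rope_interleaved(x, np.cos(theta), np.sin(theta))
```

Because each pair undergoes a plain 2D rotation, the row norms of `y` match those of `x`, which makes a convenient sanity check when comparing a kernel output against a reference.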

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Poem

🐰 A kernel once frozen in compile-time stone,
Now dances with shapes that are barely known!
From static to dynamic, the dimensions take flight,
While pooling and normalizing—all baked just right!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title 'Refactor(dsv4): full dynamic-shape for compressor_ratio128 #410' clearly summarizes the main change: converting the compressor_ratio128 module to use full dynamic shapes instead of static compile-time dimensions. |
| Description check | ✅ Passed | The description is directly related to the changeset, detailing the specific refactoring changes: replacing static dims with dynamic ones, converting compute scopes, and adding bind_dynamic() calls in tests. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |



@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces dynamic shape support to the DeepSeek v4 decode compressor by replacing static dimensions with dynamic variables and binding them in the test function. While the transition to dynamic shapes is a great improvement, the current implementation introduces several critical bugs where integer division in parallel loops (such as bs // B_TILE and b_dim // RMS_TILE) will truncate to zero for small batch sizes or sequence lengths, causing essential operations like projection, RMSNorm, RoPE, and KV finalization to be skipped entirely. Additionally, using pl.range with the dynamic variable b_dim may lead to compilation or optimization issues and should be replaced with a dynamic-friendly construct like pl.parallel.

kv_proj_scratch = pl.create_tensor([bs, OUT_DIM], dtype=pl.FP32)
score_proj_scratch = pl.create_tensor([bs, OUT_DIM], dtype=pl.FP32)

for row_idx in pl.parallel(0, bs // B_TILE):

critical

The loop pl.parallel(0, bs // B_TILE) uses integer division which truncates the loop range. If bs (which is b_dim * s_dim) is less than B_TILE (64), bs // B_TILE evaluates to 0, and the projection loop is completely skipped. For example, with s_dim = 2 (MTP) and a dynamic batch size b_dim < 32, bs will be less than 64, resulting in no projection being computed at all. To support full dynamic shapes, you must handle cases where bs is not a multiple of B_TILE, typically by using ceiling division and adding boundary checks inside the loop to avoid out-of-bounds accesses on x_flat.
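
The trip-count arithmetic behind this comment can be checked in plain Python. This is a sketch of the arithmetic only, not PyPTO code: `floor_trips`, `ceil_trips`, and `tile_bounds` are hypothetical helper names, `B_TILE = 64` is the value quoted above, and the batch sizes are made-up examples.

```python
def floor_trips(n, tile):
    """Loop trip count with floor division: drops any partial tile."""
    return n // tile

def ceil_trips(n, tile):
    """Loop trip count with ceiling division: includes a partial tile."""
    return (n + tile - 1) // tile

def tile_bounds(n, tile):
    """Yield (start, end) row ranges, clamping the last tile to n."""
    for i in range(ceil_trips(n, tile)):
        start = i * tile
        yield start, min(start + tile, n)

B_TILE = 64            # value quoted in the review comment
b_dim, s_dim = 4, 2    # hypothetical small decode batch with s_dim = 2
bs = b_dim * s_dim     # 8 rows

print(floor_trips(bs, B_TILE))         # 0: the projection loop never runs
print(ceil_trips(bs, B_TILE))          # 1
print(list(tile_bounds(bs, B_TILE)))   # [(0, 8)]
print(list(tile_bounds(130, B_TILE)))  # [(0, 64), (64, 128), (128, 130)]
```

The clamped `end` is what a boundary check inside the kernel loop would have to enforce so that the last tile does not read past the end of `x_flat`.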

  # batches per spmd block so all vec col-vectors hit ptoas's 32B-aligned
  # row stride without per-batch row padding.
- for batch_base_idx in pl.spmd(B // RMS_TILE, name_hint="rmsnorm_rope"):
+ for batch_base_idx in pl.parallel(0, b_dim // RMS_TILE):

high

The loop pl.parallel(0, b_dim // RMS_TILE) assumes that the dynamic batch dimension b_dim is always a multiple of RMS_TILE (which is 8). Integer division rounds down, so when b_dim is not a multiple of RMS_TILE the trailing b_dim % RMS_TILE batches are skipped during RMSNorm, RoPE, and KV finalization; when b_dim < 8 (possible with dynamic batching) the trip count is 0 and every batch is skipped. To support truly arbitrary dynamic shapes, you should either assert/document that b_dim must be a multiple of RMS_TILE, or use ceiling division pl.ceil_div(b_dim, RMS_TILE) and add boundary guards (e.g., if global_c_idx < b_dim:) inside the loops to prevent out-of-bounds accesses when slicing or reading tensors.
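
The first option, asserting the multiple-of-tile contract, could be enforced on the host before the kernel is launched. The following is a minimal sketch under that assumption; `require_tile_multiple` is a hypothetical helper, not part of PyPTO or this PR, and `RMS_TILE = 8` is the value quoted above.

```python
RMS_TILE = 8  # value quoted in the review comment

def require_tile_multiple(b_dim: int, tile: int = RMS_TILE) -> int:
    """Return the number of full tiles, or raise if b_dim breaks the contract."""
    if b_dim <= 0 or b_dim % tile != 0:
        raise ValueError(
            f"b_dim={b_dim} must be a positive multiple of {tile}"
        )
    return b_dim // tile

print(require_tile_multiple(16))  # 2

try:
    require_tile_multiple(5)
except ValueError as err:
    print(err)  # b_dim=5 must be a positive multiple of 8
```

Failing loudly on the host keeps the floor-division loop correct by construction, at the cost of pushing the padding responsibility to the caller.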

  # carries (kv_flat / cmp_kv_cache_flat / compress_state_flat) on one task,
  # which hits pypto #1573 (orchestration phi cross-assignment).
- for batch_base_idx in pl.spmd(B // RMS_TILE, name_hint="kv_finalize"):
+ for batch_base_idx in pl.parallel(0, b_dim // RMS_TILE):

high

Similar to the RMSNorm loop, pl.parallel(0, b_dim // RMS_TILE) for kv_finalize will truncate to 0 if b_dim < 8, skipping the finalization of KV cache updates entirely. Consider using ceiling division combined with boundary checks to ensure correctness for all dynamic batch sizes.


  with pl.at(level=pl.Level.CORE_GROUP, name_hint="state_scatter_pre"):
-     for global_c_idx in pl.range(B):
+     for global_c_idx in pl.range(b_dim):

medium

Using pl.range(b_dim) with a dynamic variable b_dim may fail compilation or prevent loop optimization/unrolling in the compiler, as pl.range typically expects a compile-time constant. Consider using pl.parallel(0, b_dim) or another dynamic-friendly loop construct instead.


@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/deepseek/v4/decode_compressor_ratio128.py`:
- Around line 177-213: The loop over batches uses b_dim // RMS_TILE which drops
a trailing partial batch; change the batching to use a ceiling division (e.g.
iterate over math.ceil(b_dim / RMS_TILE)) and compute end = min(batch_base +
RMS_TILE, b_dim) and curr_tile = end - batch_base, then replace fixed RMS_TILE
uses inside the loop with curr_tile for slices and buffer shapes (affecting
normed_kv, cos_b, sin_b, partial_sq, idx_target, rope_buf, pooled_kv gathers,
and any pl.full/pl.reshape sizes) so the last partial batch is correctly
processed; keep all logic (RMSNorm and RoPE steps) the same but keyed to
curr_tile.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 715fec11-05b9-4d84-9ab6-8294474c6645

📥 Commits

Reviewing files that changed from the base of the PR and between 64d2e13 and d7bb8bf.

📒 Files selected for processing (1)
  • models/deepseek/v4/decode_compressor_ratio128.py

Comment on lines +177 to +213
+ normed_kv = pl.create_tensor([b_dim, HEAD_DIM], dtype=pl.FP32)

  # Fused RMSNorm + gather/scatter-based RoPE, processing RMS_TILE real
  # batches per spmd block so all vec col-vectors hit ptoas's 32B-aligned
  # row stride without per-batch row padding.
- for batch_base_idx in pl.spmd(B // RMS_TILE, name_hint="rmsnorm_rope"):
+ for batch_base_idx in pl.parallel(0, b_dim // RMS_TILE):
      batch_base = batch_base_idx * RMS_TILE
-     cos_b = cos[batch_base : batch_base + RMS_TILE, 0 : ROPE_HEAD_DIM // 2]
-     sin_b = sin[batch_base : batch_base + RMS_TILE, 0 : ROPE_HEAD_DIM // 2]
-     partial_sq = pl.full([1, RMS_TILE], dtype=pl.FP32, value=0.0)
-     for rms_kb in pl.pipeline(HEAD_DIM // HEAD_TILE, stage=2):
-         kv_rms_chunk = pooled_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
-         kv_rms_sq = pl.mul(kv_rms_chunk, kv_rms_chunk)
-         kv_rms_rowsum = pl.reshape(pl.row_sum(kv_rms_sq), [1, RMS_TILE])
-         partial_sq = pl.add(partial_sq, kv_rms_rowsum)
-
-     variance = pl.reshape(pl.add(pl.mul(partial_sq, HEAD_DIM_INV), EPS), [RMS_TILE, 1])
-     inv_rms = pl.recip(pl.sqrt(variance))
-     for rms_kb in pl.pipeline(HEAD_DIM // HEAD_TILE, stage=2):
-         kv_norm_chunk = pooled_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
-         gamma = norm_w_2d[:, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
-         normed_chunk = pl.col_expand_mul(pl.row_expand_mul(kv_norm_chunk, inv_rms), gamma)
-         normed_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE] = normed_chunk
-
-     kv_rope = normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM]
-     even_tile = pl.gather(kv_rope, mask_pattern=pl.tile.MaskPattern.P0101)
-     odd_tile = pl.gather(kv_rope, mask_pattern=pl.tile.MaskPattern.P1010)
-     rope_even = pl.sub(pl.mul(even_tile, cos_b), pl.mul(odd_tile, sin_b))
-     rope_odd = pl.add(pl.mul(even_tile, sin_b), pl.mul(odd_tile, cos_b))
-     idx_target = pl.full([RMS_TILE, ROPE_HEAD_DIM // 2], dtype=pl.INT32, value=0)
-     even_idx_full = pl.col_expand(idx_target, even_idx)
-     odd_idx_full = pl.col_expand(idx_target, odd_idx)
-     rope_buf = pl.full([RMS_TILE, ROPE_HEAD_DIM], dtype=pl.FP32, value=0.0)
-     rope_buf = pl.tensor.scatter(rope_buf, dim=-1, index=even_idx_full, src=rope_even)
-     rope_buf = pl.tensor.scatter(rope_buf, dim=-1, index=odd_idx_full, src=rope_odd)
-     normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM] = rope_buf
+     with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm_rope"):
+         cos_b = cos[batch_base : batch_base + RMS_TILE, 0 : ROPE_HEAD_DIM // 2]
+         sin_b = sin[batch_base : batch_base + RMS_TILE, 0 : ROPE_HEAD_DIM // 2]
+         partial_sq = pl.full([1, RMS_TILE], dtype=pl.FP32, value=0.0)
+         for rms_kb in pl.pipeline(HEAD_DIM // HEAD_TILE, stage=2):
+             kv_rms_chunk = pooled_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
+             kv_rms_sq = pl.mul(kv_rms_chunk, kv_rms_chunk)
+             kv_rms_rowsum = pl.reshape(pl.row_sum(kv_rms_sq), [1, RMS_TILE])
+             partial_sq = pl.add(partial_sq, kv_rms_rowsum)
+
+         variance = pl.reshape(pl.add(pl.mul(partial_sq, HEAD_DIM_INV), EPS), [RMS_TILE, 1])
+         inv_rms = pl.recip(pl.sqrt(variance))
+         for rms_kb in pl.pipeline(HEAD_DIM // HEAD_TILE, stage=2):
+             kv_norm_chunk = pooled_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
+             gamma = norm_w_2d[:, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
+             normed_chunk = pl.col_expand_mul(pl.row_expand_mul(kv_norm_chunk, inv_rms), gamma)
+             normed_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE] = normed_chunk
+
+         kv_rope = normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM]
+         even_tile = pl.gather(kv_rope, mask_pattern=pl.tile.MaskPattern.P0101)
+         odd_tile = pl.gather(kv_rope, mask_pattern=pl.tile.MaskPattern.P1010)
+         rope_even = pl.sub(pl.mul(even_tile, cos_b), pl.mul(odd_tile, sin_b))
+         rope_odd = pl.add(pl.mul(even_tile, sin_b), pl.mul(odd_tile, cos_b))
+         idx_target = pl.full([RMS_TILE, ROPE_HEAD_DIM // 2], dtype=pl.INT32, value=0)
+         even_idx_full = pl.col_expand(idx_target, even_idx)
+         odd_idx_full = pl.col_expand(idx_target, odd_idx)
+         rope_buf = pl.full([RMS_TILE, ROPE_HEAD_DIM], dtype=pl.FP32, value=0.0)
+         rope_buf = pl.tensor.scatter(rope_buf, dim=-1, index=even_idx_full, src=rope_even)
+         rope_buf = pl.tensor.scatter(rope_buf, dim=-1, index=odd_idx_full, src=rope_odd)
+         normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM] = rope_buf

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Same divisibility concern: b_dim // RMS_TILE may skip trailing batches.

Line 182 iterates b_dim // RMS_TILE times. If b_dim is not divisible by RMS_TILE (8), the remaining batches are not processed, leading to incorrect or missing RMSNorm + RoPE output for those batches.
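
The effect of the missing tail tile, and of the ceiling-division fix, can be reproduced with a NumPy reference. This sketch is not PyPTO code and does not establish whether PyPTO tiles accept a runtime `curr_tile` extent. It assumes the RMSNorm arithmetic shown in the diff (mean of squares plus `EPS`, then reciprocal square root, then scaling by `gamma`), with a hypothetical `EPS` value and hypothetical function names.

```python
import numpy as np

RMS_TILE = 8   # value quoted in the review comments
EPS = 1e-6     # hypothetical; the kernel's EPS value is not shown in the PR

def rmsnorm_reference(x, gamma, eps=EPS):
    """Untiled RMSNorm over the last axis, used as ground truth."""
    inv_rms = 1.0 / np.sqrt(np.mean(x * x, axis=1, keepdims=True) + eps)
    return x * inv_rms * gamma

def rmsnorm_tiled(x, gamma, tile=RMS_TILE, eps=EPS):
    """RMSNorm in row tiles, using ceiling division and a clamped last tile."""
    b_dim = x.shape[0]
    out = np.zeros_like(x)
    n_tiles = (b_dim + tile - 1) // tile       # ceiling division
    for i in range(n_tiles):
        batch_base = i * tile
        end = min(batch_base + tile, b_dim)    # clamp the trailing tile
        curr_tile = end - batch_base           # rows actually in this tile
        chunk = x[batch_base:end]
        variance = np.mean(chunk * chunk, axis=1, keepdims=True) + eps
        assert variance.shape == (curr_tile, 1)
        out[batch_base:end] = chunk / np.sqrt(variance) * gamma
    return out

rng = np.random.default_rng(0)
gamma = rng.standard_normal((1, 32))           # hypothetical [1, HEAD_DIM]
for b_dim in (1, 7, 8, 13):
    x = rng.standard_normal((b_dim, 32))
    assert np.allclose(rmsnorm_tiled(x, gamma), rmsnorm_reference(x, gamma))
```

Replacing `n_tiles` with `b_dim // tile` leaves the last `b_dim % tile` rows of `out` at zero, which is the failure mode this comment describes.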

