Refactor(dsv4): full dynamic-shape for compressor_ratio128 #410 #412
bumble0918 wants to merge 1 commit into
bumble0918 commented on May 28, 2026
- Replace static dims (B, S, BLOCK_NUM, block-table sizes) with pl.dynamic(); derive all runtime sizes via pl.tensor.dim()
- Convert all compute scopes (kv_score_proj, state_scatter, softmax_pool, rmsnorm_rope, kv_finalize) from pl.spmd(static B) to pl.parallel(dynamic b_dim) + inner pl.spmd/pl.at
- Add bind_dynamic() for all new DynVars in compressor_test
Walkthrough: DeepSeek Compressor Dynamic Shape Migration
Estimated code review effort: 4 (Complex), ~45 minutes
Pre-merge checks: 4 passed, 1 failed (1 warning)
Code Review
This pull request introduces dynamic shape support to the DeepSeek v4 decode compressor by replacing static dimensions with dynamic variables and binding them in the test function. While the transition to dynamic shapes is a great improvement, the current implementation introduces several critical bugs where integer division in parallel loops (such as bs // B_TILE and b_dim // RMS_TILE) will truncate to zero for small batch sizes or sequence lengths, causing essential operations like projection, RMSNorm, RoPE, and KV finalization to be skipped entirely. Additionally, using pl.range with the dynamic variable b_dim may lead to compilation or optimization issues and should be replaced with a dynamic-friendly construct like pl.parallel.
kv_proj_scratch = pl.create_tensor([bs, OUT_DIM], dtype=pl.FP32)
score_proj_scratch = pl.create_tensor([bs, OUT_DIM], dtype=pl.FP32)

for row_idx in pl.parallel(0, bs // B_TILE):
The loop pl.parallel(0, bs // B_TILE) uses integer division which truncates the loop range. If bs (which is b_dim * s_dim) is less than B_TILE (64), bs // B_TILE evaluates to 0, and the projection loop is completely skipped. For example, with s_dim = 2 (MTP) and a dynamic batch size b_dim < 32, bs will be less than 64, resulting in no projection being computed at all. To support full dynamic shapes, you must handle cases where bs is not a multiple of B_TILE, typically by using ceiling division and adding boundary checks inside the loop to avoid out-of-bounds accesses on x_flat.
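The tiling arithmetic behind this comment can be sketched in plain Python rather than the `pl` DSL. The helper names `ceil_div` and `tile_bounds` are hypothetical and not part of the PR; `B_TILE = 64` is the value quoted in the comment above.

```python
B_TILE = 64  # tile size quoted in the review comment


def ceil_div(n: int, d: int) -> int:
    """Integer ceiling division for non-negative n and positive d."""
    return (n + d - 1) // d


def tile_bounds(total_rows: int, tile: int) -> list[tuple[int, int]]:
    """Return (start, end) row ranges covering total_rows, clamping the last tile."""
    bounds = []
    for tile_idx in range(ceil_div(total_rows, tile)):
        start = tile_idx * tile
        end = min(start + tile, total_rows)  # boundary check for a partial tile
        bounds.append((start, end))
    return bounds


# b_dim = 2 and s_dim = 2 give bs = 4, which is smaller than B_TILE.
bs = 2 * 2
floor_tiles = bs // B_TILE         # floor division: the loop body never runs
ceil_tiles = ceil_div(bs, B_TILE)  # ceiling division: one partial tile
bounds = tile_bounds(bs, B_TILE)   # row range that the partial tile must cover
```

Whether the kernel can slice `x_flat` with a variable-length final tile is not shown in the PR, so the clamped `end` here only illustrates the index math.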
  # batches per spmd block so all vec col-vectors hit ptoas's 32B-aligned
  # row stride without per-batch row padding.
- for batch_base_idx in pl.spmd(B // RMS_TILE, name_hint="rmsnorm_rope"):
+ for batch_base_idx in pl.parallel(0, b_dim // RMS_TILE):
The loop pl.parallel(0, b_dim // RMS_TILE) assumes that the dynamic batch dimension b_dim is always a multiple of RMS_TILE (which is 8). If b_dim is not a multiple of RMS_TILE (for example, if b_dim < 8 during dynamic batching), the integer division b_dim // RMS_TILE will truncate to 0, causing the remaining batches to be completely skipped during RMSNorm, RoPE, and KV finalization. To support truly arbitrary dynamic shapes, you should either assert/document that b_dim must be a multiple of RMS_TILE, or use ceiling division pl.ceil_div(b_dim, RMS_TILE) and add boundary guards (e.g., if global_c_idx < b_dim:) inside the loops to prevent out-of-bounds accesses when slicing or reading tensors.
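The two options named in this comment can be sketched in plain Python under stated assumptions. `require_tile_aligned` and `guarded_rows` are hypothetical helpers, not functions from the PR or from the `pl` DSL; `RMS_TILE = 8` is the value quoted in the comment.

```python
RMS_TILE = 8  # tile size quoted in the review comment


def require_tile_aligned(b_dim: int, tile: int = RMS_TILE) -> None:
    """Option 1: reject batch sizes that a floor-division loop cannot cover."""
    if b_dim <= 0 or b_dim % tile != 0:
        raise ValueError(f"b_dim={b_dim} must be a positive multiple of {tile}")


def guarded_rows(b_dim: int, tile: int = RMS_TILE) -> list[int]:
    """Option 2: ceiling-division outer loop with a per-row bounds guard."""
    visited = []
    num_tiles = (b_dim + tile - 1) // tile
    for batch_base_idx in range(num_tiles):
        for lane in range(tile):
            global_c_idx = batch_base_idx * tile + lane
            if global_c_idx < b_dim:  # skip lanes past the real batch
                visited.append(global_c_idx)
    return visited
```

Option 1 keeps the kernel body unchanged and moves the constraint to the caller. Option 2 covers every batch at the cost of a guard on each row access.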
  # carries (kv_flat / cmp_kv_cache_flat / compress_state_flat) on one task,
  # which hits pypto #1573 (orchestration phi cross-assignment).
- for batch_base_idx in pl.spmd(B // RMS_TILE, name_hint="kv_finalize"):
+ for batch_base_idx in pl.parallel(0, b_dim // RMS_TILE):

  with pl.at(level=pl.Level.CORE_GROUP, name_hint="state_scatter_pre"):
-     for global_c_idx in pl.range(B):
+     for global_c_idx in pl.range(b_dim):
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/deepseek/v4/decode_compressor_ratio128.py`:
- Around line 177-213: The loop over batches uses b_dim // RMS_TILE which drops
a trailing partial batch; change the batching to use a ceiling division (e.g.
iterate over math.ceil(b_dim / RMS_TILE)) and compute end = min(batch_base +
RMS_TILE, b_dim) and curr_tile = end - batch_base, then replace fixed RMS_TILE
uses inside the loop with curr_tile for slices and buffer shapes (affecting
normed_kv, cos_b, sin_b, partial_sq, idx_target, rope_buf, pooled_kv gathers,
and any pl.full/pl.reshape sizes) so the last partial batch is correctly
processed; keep all logic (RMSNorm and RoPE steps) the same but keyed to
curr_tile.
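The `end` / `curr_tile` bookkeeping described in this prompt can be sketched in plain Python. This is an illustration only, not the PR's kernel: `rmsnorm_rows` and `rmsnorm_tiled` are hypothetical names, the RoPE step is omitted, and the `EPS` value is an assumption because the PR's constant is not shown.

```python
import math

RMS_TILE = 8
EPS = 1e-6  # assumed epsilon; the PR's EPS value is not shown


def rmsnorm_rows(rows, gamma, eps=EPS):
    """Reference RMSNorm applied row by row."""
    out = []
    for row in rows:
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v * inv_rms * g for v, g in zip(row, gamma)])
    return out


def rmsnorm_tiled(pooled_kv, gamma, tile=RMS_TILE):
    """Process batches tile by tile, shrinking the last tile to curr_tile rows."""
    b_dim = len(pooled_kv)
    normed_kv = [None] * b_dim
    for batch_base_idx in range((b_dim + tile - 1) // tile):
        batch_base = batch_base_idx * tile
        end = min(batch_base + tile, b_dim)
        curr_tile = end - batch_base  # equals tile except possibly on the last pass
        chunk = pooled_kv[batch_base:end]
        assert len(chunk) == curr_tile
        normed_kv[batch_base:end] = rmsnorm_rows(chunk, gamma)
    return normed_kv
```

With `b_dim = 11`, the second pass handles a three-row tile, so no batch is left unnormalized. Whether the `pl` DSL accepts a variable-sized final tile, given the 32B row-stride alignment mentioned in the code comment, is not confirmed by the source.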
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 715fec11-05b9-4d84-9ab6-8294474c6645
📒 Files selected for processing (1)
models/deepseek/v4/decode_compressor_ratio128.py
  normed_kv = pl.create_tensor([b_dim, HEAD_DIM], dtype=pl.FP32)

  # Fused RMSNorm + gather/scatter-based RoPE, processing RMS_TILE real
  # batches per spmd block so all vec col-vectors hit ptoas's 32B-aligned
  # row stride without per-batch row padding.
- for batch_base_idx in pl.spmd(B // RMS_TILE, name_hint="rmsnorm_rope"):
+ for batch_base_idx in pl.parallel(0, b_dim // RMS_TILE):
      batch_base = batch_base_idx * RMS_TILE
-     cos_b = cos[batch_base : batch_base + RMS_TILE, 0 : ROPE_HEAD_DIM // 2]
-     sin_b = sin[batch_base : batch_base + RMS_TILE, 0 : ROPE_HEAD_DIM // 2]
-     partial_sq = pl.full([1, RMS_TILE], dtype=pl.FP32, value=0.0)
-     for rms_kb in pl.pipeline(HEAD_DIM // HEAD_TILE, stage=2):
-         kv_rms_chunk = pooled_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
-         kv_rms_sq = pl.mul(kv_rms_chunk, kv_rms_chunk)
-         kv_rms_rowsum = pl.reshape(pl.row_sum(kv_rms_sq), [1, RMS_TILE])
-         partial_sq = pl.add(partial_sq, kv_rms_rowsum)
-
-     variance = pl.reshape(pl.add(pl.mul(partial_sq, HEAD_DIM_INV), EPS), [RMS_TILE, 1])
-     inv_rms = pl.recip(pl.sqrt(variance))
-     for rms_kb in pl.pipeline(HEAD_DIM // HEAD_TILE, stage=2):
-         kv_norm_chunk = pooled_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
-         gamma = norm_w_2d[:, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
-         normed_chunk = pl.col_expand_mul(pl.row_expand_mul(kv_norm_chunk, inv_rms), gamma)
-         normed_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE] = normed_chunk
-
-     kv_rope = normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM]
-     even_tile = pl.gather(kv_rope, mask_pattern=pl.tile.MaskPattern.P0101)
-     odd_tile = pl.gather(kv_rope, mask_pattern=pl.tile.MaskPattern.P1010)
-     rope_even = pl.sub(pl.mul(even_tile, cos_b), pl.mul(odd_tile, sin_b))
-     rope_odd = pl.add(pl.mul(even_tile, sin_b), pl.mul(odd_tile, cos_b))
-     idx_target = pl.full([RMS_TILE, ROPE_HEAD_DIM // 2], dtype=pl.INT32, value=0)
-     even_idx_full = pl.col_expand(idx_target, even_idx)
-     odd_idx_full = pl.col_expand(idx_target, odd_idx)
-     rope_buf = pl.full([RMS_TILE, ROPE_HEAD_DIM], dtype=pl.FP32, value=0.0)
-     rope_buf = pl.tensor.scatter(rope_buf, dim=-1, index=even_idx_full, src=rope_even)
-     rope_buf = pl.tensor.scatter(rope_buf, dim=-1, index=odd_idx_full, src=rope_odd)
-     normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM] = rope_buf
+     with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm_rope"):
+         cos_b = cos[batch_base : batch_base + RMS_TILE, 0 : ROPE_HEAD_DIM // 2]
+         sin_b = sin[batch_base : batch_base + RMS_TILE, 0 : ROPE_HEAD_DIM // 2]
+         partial_sq = pl.full([1, RMS_TILE], dtype=pl.FP32, value=0.0)
+         for rms_kb in pl.pipeline(HEAD_DIM // HEAD_TILE, stage=2):
+             kv_rms_chunk = pooled_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
+             kv_rms_sq = pl.mul(kv_rms_chunk, kv_rms_chunk)
+             kv_rms_rowsum = pl.reshape(pl.row_sum(kv_rms_sq), [1, RMS_TILE])
+             partial_sq = pl.add(partial_sq, kv_rms_rowsum)
+
+         variance = pl.reshape(pl.add(pl.mul(partial_sq, HEAD_DIM_INV), EPS), [RMS_TILE, 1])
+         inv_rms = pl.recip(pl.sqrt(variance))
+         for rms_kb in pl.pipeline(HEAD_DIM // HEAD_TILE, stage=2):
+             kv_norm_chunk = pooled_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
+             gamma = norm_w_2d[:, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE]
+             normed_chunk = pl.col_expand_mul(pl.row_expand_mul(kv_norm_chunk, inv_rms), gamma)
+             normed_kv[batch_base : batch_base + RMS_TILE, rms_kb * HEAD_TILE : (rms_kb + 1) * HEAD_TILE] = normed_chunk
+
+         kv_rope = normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM]
+         even_tile = pl.gather(kv_rope, mask_pattern=pl.tile.MaskPattern.P0101)
+         odd_tile = pl.gather(kv_rope, mask_pattern=pl.tile.MaskPattern.P1010)
+         rope_even = pl.sub(pl.mul(even_tile, cos_b), pl.mul(odd_tile, sin_b))
+         rope_odd = pl.add(pl.mul(even_tile, sin_b), pl.mul(odd_tile, cos_b))
+         idx_target = pl.full([RMS_TILE, ROPE_HEAD_DIM // 2], dtype=pl.INT32, value=0)
+         even_idx_full = pl.col_expand(idx_target, even_idx)
+         odd_idx_full = pl.col_expand(idx_target, odd_idx)
+         rope_buf = pl.full([RMS_TILE, ROPE_HEAD_DIM], dtype=pl.FP32, value=0.0)
+         rope_buf = pl.tensor.scatter(rope_buf, dim=-1, index=even_idx_full, src=rope_even)
+         rope_buf = pl.tensor.scatter(rope_buf, dim=-1, index=odd_idx_full, src=rope_odd)
+         normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM] = rope_buf
Same divisibility concern: b_dim // RMS_TILE may skip trailing batches.
Line 182 iterates b_dim // RMS_TILE times. If b_dim is not divisible by RMS_TILE (8), the remaining batches are not processed, leading to incorrect or missing RMSNorm + RoPE output for those batches.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/deepseek/v4/decode_compressor_ratio128.py` around lines 177 - 213, The
loop over batches uses b_dim // RMS_TILE which drops a trailing partial batch;
change the batching to use a ceiling division (e.g. iterate over math.ceil(b_dim
/ RMS_TILE)) and compute end = min(batch_base + RMS_TILE, b_dim) and curr_tile =
end - batch_base, then replace fixed RMS_TILE uses inside the loop with
curr_tile for slices and buffer shapes (affecting normed_kv, cos_b, sin_b,
partial_sq, idx_target, rope_buf, pooled_kv gathers, and any pl.full/pl.reshape
sizes) so the last partial batch is correctly processed; keep all logic (RMSNorm
and RoPE steps) the same but keyed to curr_tile.