[TRTLLM-35237][feat] Add cute dsl FP4 paged MQA logits decode kernel #13929
limin2021 wants to merge 42 commits into NVIDIA:main from
Conversation
…M100 Replace DeepGEMM-based indexer logits with a CuTe DSL kernel on SM100+, gated by the `use_cute_dsl_logits` config flag. Includes kernel implementation, PyTorch custom op registration, config plumbing, and unit tests. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
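A minimal sketch of the config gating described in this commit (the dispatch helper and the DeepGEMM fallback name are placeholders; the flag and custom op name are from this PR):

import torch

# Sketch only: route indexer logits through the CuTe DSL custom op when the flag is
# set; otherwise keep the existing DeepGEMM path. is_sm_100f() is re-checked inside
# the op, so this is just the config-level switch.
def _dispatch_indexer_logits(config, *args, **kwargs):
    if config.use_cute_dsl_logits:
        return torch.ops.trtllm.cute_dsl_fp8_paged_mqa_logits(*args, **kwargs)
    return deep_gemm_fp8_paged_mqa_logits(*args, **kwargs)  # placeholder fallback name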
Replace stale development script name with the actual module filename. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
- Use current_stream() instead of creating a new stream to avoid data races - Remove unnecessary torch.cuda.synchronize() - Add is_sm_100f() check in custom op to fail fast on unsupported GPUs - Rename FP8MQALogitsDGFullKKernel to FP8MQALogitsKernel - Add __all__ to paged_mqa_logits __init__.py - Fix test skip logic to precisely match SM 100/103 only Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
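For the first bullet, a minimal sketch of launching on the caller's current stream instead of a private one (the kernel-launch step itself is only described in a comment):

import torch

# Use the stream the caller is already on, rather than creating a new stream: work
# previously enqueued on this stream is ordered before the kernel launch, so no extra
# torch.cuda.synchronize() is needed and no cross-stream data race is introduced.
stream = torch.cuda.current_stream()
raw_handle = stream.cuda_stream  # integer cudaStream_t handle passed to the compiled kernel launch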
…s enabled Broaden the cute_dsl_custom_ops import guard to also trigger when use_cute_dsl_logits is True, so the custom op is registered even when use_cute_dsl_topk is False. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…_logits Align config field name with the op name cute_dsl_fp8_paged_mqa_logits for clarity. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
The custom op already validates SM 100/103 via is_sm_100f(). Also replace getattr with direct attribute access to match use_cute_dsl_topk style. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…ser-facing strings Remove DeepGEMM and DG-FullK references from module docstring, class docstring, function docstrings, print statements, and argparse descriptions. Inline implementation comments retained as design cross-references. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Replace from_dlpack + mark_compact_shape_dynamic compile pattern with make_fake_compact_tensor + TVM FFI in CuteDSLPagedMQALogitsRunner. This eliminates real tensor data at compile time, removes per-call dlpack wrapping and manual CUDA stream passing at runtime. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…pper Add _check_fp8_paged_mqa_logits_dtypes helper that validates all tensor and compute dtypes upfront, collecting all errors into a single ValueError so callers see every mismatch at once rather than one at a time. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
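A minimal sketch of the collect-then-raise pattern this helper uses (the dtype expectations here are illustrative; the FP4 analogue of this helper appears in full in the review below):

import torch

def _check_dtypes_sketch(q: torch.Tensor, weights: torch.Tensor) -> None:
    errs = []
    if q.dtype != torch.float8_e4m3fn:
        errs.append(f"q must be float8_e4m3fn, got {q.dtype}")
    if weights.dtype != torch.float32:
        errs.append(f"weights must be float32, got {weights.dtype}")
    if errs:
        # One ValueError listing every mismatch, so callers see all problems at once.
        raise ValueError("FP8 Paged MQA Logits dtype errors:\n  " + "\n  ".join(errs))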
- Add epi_dtype param to PyTorch ref so epilogue matches kernel precision - Add use_int_data mode for fp16 tests to isolate bugs from FP16 rounding - Tighten tolerances: fp32 atol 5e-3→5e-5, fp16 atol 1e-2→1e-3 with rtol - Use torch.testing.assert_close instead of cosine similarity - Add varlen benchmark support for mixed-length serving workloads - Clean up comments and remove stale parametrize line Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
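A minimal sketch of the assert_close-based comparison mentioned above (tolerance pairs follow the values quoted in this PR; names are illustrative):

import torch

def _compare(out: torch.Tensor, ref: torch.Tensor) -> None:
    # fp32 outputs: atol 5e-5; fp16 outputs: atol 1e-3, with matching rtol.
    atol, rtol = (5e-5, 1e-5) if out.dtype == torch.float32 else (1e-3, 1e-3)
    # assert_close reports the worst element-wise mismatch and its index on failure,
    # unlike a single global cosine-similarity score.
    torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)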
…e map Remove broken run_test/parse_args/__main__ block and its helper functions (make_fused_kv, fused_kv_views, _prepare_inputs, _make_dynamic_dlpacks, _get_or_compile_kernel, dsl_fp8_paged_mqa_logits_dg_fullk) from fp8_paged_mqa_logits.py — all had no callers after the missing paged_mqa_logits_helpers module made them non-functional. Accuracy validation is covered by the unit tests. Reuse existing module-level _TORCH_TO_CUTLASS_DTYPE in CuteDSLPagedMQALogitsRunner instead of duplicating it. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…ubtiles Move the FMA unroll granularity check outside the num_epi_subtiles > 1 guard so it also applies when num_epi_subtiles == 1. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Remove is_umma_warp, num_q_stages (local copy), cons_group, num_tma_prod, num_mcast_a/b, and tCtAcc_fake (in __call__) which were assigned but never referenced. Remove unnecessary noqa F401 on pipeline_init_arrive/pipeline_init_wait which are actively used. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…nfig Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…arbitrary next_n Fix cudaErrorMisalignedAddress when num_heads=32 with fp16 output by padding sW stage stride to 128-byte boundary for TMA bulk copy. Update qw_per_stage SMEM budget to use padded stride and actual epi_bytes. In dsa.py, skip MTP batch expansion for DSL kernel (supports arbitrary next_n natively) and use correct scheduler_metadata buffer selection. Expand test coverage with num_heads and fix_length parametrization, pure PyTorch reference, and DG expanded benchmark for next_n=3. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
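The stride padding from the first part of this commit reduces to rounding up to the next multiple of 128 bytes; a small standalone sketch with illustrative values:

def pad_stage_stride(stride_bytes: int) -> int:
    # Round the per-stage SMEM stride up to a 128-byte boundary so each stage's base
    # address satisfies the alignment the TMA bulk copy expects.
    return (stride_bytes + 127) // 128 * 128

# Illustrative values: a 64-byte raw stage stride rounds up to 128 bytes.
assert pad_stage_stride(64) == 128
assert pad_stage_stride(128) == 128
assert pad_stage_stride(192) == 256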
Clamp start_q before reading mContextLens to prevent out-of-bounds access when the scheduler assigns batch_size as a sentinel for CTAs with no work. Move torch.manual_seed into _generate_test_data for reproducibility. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…aged MQA logits Decouple compute_block_kv (always 128) from phys_block_kv (physical page size). When phys_block_kv < 128, the kernel issues num_blocks_per_mma TMA copies per compute tile. Removes the tokens_per_block=128 constraint for the DSL kernel path. Signed-off-by: Mindy Li <lmin@nvidia.com> Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
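The decoupling above boils down to a fixed compute tile covering an integer number of physical pages; a tiny sketch (assuming phys_block_kv divides 128 evenly, as the supported sizes do):

COMPUTE_BLOCK_KV = 128  # compute tile along the KV dimension, fixed

for phys_block_kv in (32, 64, 128):
    num_blocks_per_mma = COMPUTE_BLOCK_KV // phys_block_kv
    # 32 -> 4 TMA copies per compute tile, 64 -> 2, 128 -> 1 (the old single-block path)
    print(f"phys_block_kv={phys_block_kv}: {num_blocks_per_mma} TMA copies per tile")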
…eduling Hoist shuffle_sync calls before producer_acquire in both TMA warpgroups. The previous ordering placed shuffle after barrier acquire, causing a dependency chain that regressed next_n=3 by 5-7%. SASS now matches baseline exactly. Signed-off-by: Mindy Li <lmin@nvidia.com> Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <lmin@nvidia.com> Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…and test Drop dead code flagged in review: - compute_schedule_metadata + cdiv (and now-unused torch import) from the kernel module — never called. - skip_if_unsupported decorator + has_deep_gemm helper + IS_CUTLASS_DSL_AVAILABLE import from the test — defined but never applied. SM100/103 gate already ensures DeepGEMM and CuTe DSL availability. Also refresh the test docstring: the reference is pure PyTorch, not DeepGEMM. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
The DSL variant of test_indexer_decode_with_paged_kv_cache calls torch.ops.trtllm.cute_dsl_fp8_paged_mqa_logits, which raises ValueError on SM != 100/103. Mark only the "dsl" parameter with a skipif so the "deepgemm" backend still runs on Hopper. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…xer_gemm Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> # Conflicts: # tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
…xer_gemm Resolve conflicts in DSA paged-MQA-logits dispatch and tests after DeepGEMM submodule bump (4ff3f54d -> c491439e via PR NVIDIA#13340 / DG NVIDIA#304): - dsa.py: take upstream's scheduler_metadata_buffer / _full_next_n selection (mtp3 buffer removed); add DSL early-branch using the existing scheduler_metadata_buffer (built with (num_gen, 1) shape, num_atoms=1, matching DSL's 1-atom-per-q design) and the 1D kv_lens_cuda_runtime slice for context_lens. - dsa.py: introduce module-level _DG_SCHEDULE_BLOCK_KV = 64, used by all 6 get_paged_mqa_logits_metadata calls (3 in on_update_kv_lens(), 3 in Indexer.prepare()) instead of cache tokens_per_block. Decouples schedule SPLIT_KV from cache page size and side-steps a SM100 + block_kv=32 latent regression in DG commit 7f2a703 (NVIDIA#304). - test_dsa_indexer.py: take upstream's scheduler buffer selection; DSL test branch reads scheduler_metadata_buffer + 1D kv_lens. - test_cute_dsl_fp8_paged_mqa_logits.py: 4 metadata calls now pass 2D context_lens via .unsqueeze(-1) and DG_METADATA_BLOCK_KV=64; DG bench drops cluster(2,1,1) for next_n=4 (SM100 always uses num_kv_multicast=1) and passes 2D context_lens to fp8_paged_mqa_logits. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
… MQA next_n coverage
test_dsa_indexer: separate scheduler_metadata_buffer for DSL backend (kNumNextNAtoms=1)
to avoid the next_n>1 alias used by DeepGEMM. Fixes test_indexer_decode_with_paged_kv_cache
across {deepgemm, dsl} x {(4,1),(2,2),(4,3),(4,4)}.
test_cute_dsl_fp8_paged_mqa_logits: extend multi_block next_n coverage to {1,2,3,4};
drop now-obsolete next_n==3 expansion in bench (newer DeepGEMM supports it natively);
update --block_kv default from 128 to 64 to match the new {32,64} assertion in
fp8_paged_mqa_logits.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…ress) Snapshot of FP4 MXFP4 block-scaled paged_mqa_logits kernel based on the FP8 kernel: 5 TMA atoms (KV + B + W + SF_KV + SF_Q), per-WG SFA TMEM region, in-place SMEM transpose for UTCCP, per-k_block SFA/SFB slice for cute.gemm. Stage 0 still hangs at tmem.wait_for_alloc bar.sync 1, 320; committing state for further debugging. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…ble next_n=3 Three independent fixes to the CuTe DSL FP4 paged MQA logits kernel and its host wrapper: 1. SF Q TMA barrier tx_count was undersized. The TMA descriptor for SF Q is built with tile = N_padded (from sf_q_smem_layout_staged), and the host wrapper pads GMEM to N_padded bytes. So TMA actually fetches N_padded * 4 bytes, not N * 4. Setting the barrier transaction count to N * 4 caused the consumer (UMMA warp) to start the in-place SMEM transpose + UTCCP before TMA finished writing the trailing pad bytes — a race that produced ~40% flaky hangs for next_n=1 (where N=64 < N_padded=128). Fix: tx_count = N_padded * 4. Verified with 10/10 hang-repro runs all PASS in ~3.7s each. 2. Enable next_n=3 by skipping the per-fragment power-of-2 round. utils.get_num_tmem_alloc_cols(..., rounding=True) rounded acc cols 192 -> 256 for next_n=3, pushing raw_total to 544 (>512 cap). The helper-level round is for bookkeeping only; the actual hardware alloc happens at the kernel-level tmem.allocate(num_tmem_alloc_cols_total) which has its own power-of-2 round on the total. Per-fragment counts only need 32-byte alignment (192 = 32*6 OK). Pass rounding=False. 3. Drop redundant .contiguous() in host wrapper to fix B>1 stride mismatch. The compile-time fake tensors use stride_order=(1, 0, 2) (Q) and (0, 1) (sf_q, weights), making the innermost stride 1 and outer strides B-independent. The runtime wrapper's .permute(1, 2, 0).contiguous() repacked memory into a contiguous layout where B becomes innermost, producing strides like (half_D*B, B, 1) — mismatching the fake's (half_D, 1, N*half_D). Removing .contiguous() keeps the permute as a view with the expected B-independent strides. Same change for sf_q and weights. SF Q padding restructured to pad in the (B, N) layout first (preserves transpose-then-view stride pattern). Mirrors the FP8 wrapper which already does it correctly. Result: full sweep over (batch_size, next_n, ctx) goes from 7 PASS / 20 FAIL to 17 PASS / 10 FAIL; all remaining failures are numerical mismatches at large ctx + B>1 + next_n>=2 (separate pre-existing bugs, to be investigated separately). Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
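Fix 3 rests on standard PyTorch stride semantics: permute returns a view that keeps the original (B-independent) strides, while .contiguous() repacks memory so the last permuted dim becomes innermost. A standalone illustration with made-up sizes (not the kernel's actual shapes):

import torch

B, N, D = 4, 64, 256
x = torch.randn(B, N, D)      # contiguous, strides (N*D, D, 1)

view = x.permute(1, 2, 0)     # shape (N, D, B); a view, strides (D, 1, N*D)
packed = view.contiguous()    # same shape, but repacked: strides (D*B, B, 1)

print(view.stride())          # (256, 1, 16384) -- inner strides independent of B
print(packed.stride())        # (1024, 4, 1)    -- B is now innermost; strides depend on B

The compile-time fake tensors bake in the view's stride pattern, which is why the repacked strides stop matching as soon as B changes.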
The DSL FP4 kernel splits DeepGEMM's single UMMA warp into two
(umma_warp_0 and umma_warp_1). Both warps' MMAs read the same SFB
(Q SF) TMEM region, but only warp 0 writes it on q_idx transitions.
The original code relied on q_pipeline.consumer_wait for cross-warp
visibility, which only orders TMA->SMEM and not warp-0->warp-1 TMEM
writes. Two distinct hazards followed:
1. Forward visibility. umma_warp_1 could fire its MMA before
umma_warp_0's async s2t to TMEM SFB landed, reading stale SFB
from the prior batch. Manifested as deterministic numerical
mismatches in B>1 cases at large ctx (e.g. ctx=32768 + B=4
always failed).
2. Backward overwrite. SFB TMEM is a single region (no staging).
Once the forward fix synced warp 1 against the new SFB, warp 0
could still race ahead to the next q transition and overwrite
SFB while warp 1's previous-batch MMA was still reading it.
Manifested as flaky B=16 failures in the extended sweep
(5/2/2 across three runs).
Fix: add NamedBarrier(id=2, 64 threads) used at two points per
umma warp loop iteration:
- After umma_warp_0's SFB s2t (with fence_view_async_tmem_store)
and before umma_warp_1's MMA on q transition: forward visibility.
- At the end of every iteration in both umma warps after
producer_commit: lock-steps the two warps so warp 0's next-iter
SFB write cannot precede warp 1's current-iter MMA commit.
Also expand the test sweep to exercise paths previously gated off:
- avg_ctx adds 8192 and 16384 (was {256, 4096, 32768})
- num_epi_subtiles enables 2 and 4 (was 1 only)
- fix_length covers both True and False (was True only)
Result: 270/270 pass x 3 consecutive runs in the main test;
90/90 pass in the multi-block test (phys_block_kv in {32, 64}).
Per-tile barrier overhead is negligible vs MMA + TMEM ops
(~12s -> ~12s in the 90-case sweep, ~35s for the 270-case sweep).
DeepGEMM avoids this entirely by issuing both groups' UMMAs from
a single UMMA warp; sm100_fp4_paged_mqa_logits.cuh:337-365 has the
'for i in kNumMathWarpGroups' inner loop in one warp.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
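A host-side analogy of the lock-step pattern the fix above introduces (Python threading, not the kernel's NamedBarrier; simplified in that the real kernel only performs the first sync on q transitions):

import threading

barrier = threading.Barrier(2)   # stands in for NamedBarrier(id=2) over the two umma warps
sfb = {"val": None}              # stands in for the single (unstaged) SFB TMEM region

def warp0(iters: int) -> None:
    for i in range(iters):
        sfb["val"] = i           # write the new SFB (s2t into TMEM)
        barrier.wait()           # forward visibility: publish before warp 1's MMA reads it
        barrier.wait()           # end-of-iteration lock-step: no overwrite until warp 1 is done

def warp1(iters: int) -> None:
    for i in range(iters):
        barrier.wait()           # wait for warp 0's SFB write to land
        assert sfb["val"] == i   # the MMA reads this iteration's SFB, never stale or overwritten data
        barrier.wait()           # signal the read is complete before warp 0's next write

t0 = threading.Thread(target=warp0, args=(8,))
t1 = threading.Thread(target=warp1, args=(8,))
t0.start(); t1.start(); t0.join(); t1.join()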
Add phys_block_kv parametrize ([32, 64, 128]) to the main test so the paged multi-block TMA path (phys_block_kv < compute tile = 128, NUM_BLOCKS_PER_MMA > 1) is exercised alongside the single-block path in one place. Remove the now-redundant test_cute_dsl_fp4_paged_mqa_logits_multi_block function — its only unique coverage was batch_size=32, which is scheduler stress rather than a new code path; the main test's [1, 4, 16] already exercises the same kernel logic. Also narrow the active epi_dtype/output_dtype set to only (fp32, bf16) for now to keep the sweep size manageable; the other four dtype combinations are list-commented and trivially re-enabled. Tested: 810/810 PASS in ~102s (5 ctx x 3 phys_block_kv x 1 dtype x 3 subtile x 3 next_n x 3 batch_size x 2 fix_length). Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Match DeepGEMM by passing real N (= next_n * num_heads) as the SF Q TMA descriptor tile shape instead of N_padded, so the host wrapper no longer needs to zero-pad sf_q from N to N_padded on every kernel call. This eliminates a per-launch torch.zeros + slice copy and aligns the DSL kernel's SF Q layout with DeepGEMM's reference (sm100_fp4_paged_mqa_logits.cuh:202, tma::copy<kRealNumSFQAtom, 1, 0>). The DSL TMA helper (cpasync.make_tiled_tma_atom) requires the smem_layout passed in to be cosize-symmetric with cta_tiler -- it rejects "SMEM tile > TMA tile" outright. Trick: build the TMA atom with a smaller (N, num_q_stages) smem layout that matches cta_tiler = (N,), but allocate the actual SMEM at (N_padded, num_q_stages) for UTCCP's 128-aligned atom requirement. The two layouts share the same SMEM iterator via a logical view (sSF_Q_for_tma) used only at tma_partition; UTCCP and downstream paths still see the full N_padded SMEM region. Positions [N, N_padded) in each stage are left as garbage post-launch -- never read because UMMA_N = N and the math epilogue only writes acc cols [0, N). Same invariant as DeepGEMM's "kNumSFQAtom-padded SMEM but kRealNumSFQAtom-fetching TMA" pattern. Also barrier tx_count drops from N_padded * 4 bytes to N * 4 bytes to match the actual TMA transfer. Test sweep updated: enable (fp32, bf16) Stage 2 cast path; the other four dtype combos remain list-commented. Stage 1 packed FMA paths (bf16/bf16, fp16/fp16) have a separate pre-existing accuracy issue unrelated to SF Q layout (looks like ELEM_TOL atol set tighter than the actual 64-head packed-FMA accumulation precision); to be investigated separately. Tested: 810/810 PASS in ~106s (5 ctx x 3 phys_block_kv x 3 num_epi_subtiles x 3 next_n x 3 batch_size x 2 fix_length x 1 dtype). Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Two small cleanups in the kernel constructor:
1. Replace the manual `(32, 64, 128, 256, 512)` tuple lookup for
power-of-two TMEM rounding with the same one-line formula used
inside utils.get_num_tmem_alloc_cols(rounding=True):
self.num_tmem_alloc_cols_total =
max(1 << math.ceil(math.log2(raw_total)), 32)
Functionally identical (raw_total <= 512 always for next_n <= 3),
but reads as the actual rounding semantics and stops needing
maintenance if the upper bound ever changes. The DSL helper
itself can't be called here because we don't yet have a tmem
tensor handle at this point; we only have the raw column count.
2. Replace the now-stale TMA SF Q comment that still talked about
"host pad while host pad still exists" (host pad was removed in
the previous commit) with a concise statement of the invariant:
SMEM positions [N, N_padded) are garbage, propagate through UTCCP
into TMEM SFB cols >= N, but are never read by MMA (UMMA_N = N)
nor written by the epilogue (acc cols [0, N) only).
No functional change. Smoke test on next_n in {1, 3} passes.
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
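The one-line formula from point 1, shown standalone with a few illustrative inputs:

import math

def round_tmem_cols(raw_total: int) -> int:
    # Round the raw TMEM column count up to the next power of two, with a floor of 32,
    # the same semantics as the old (32, 64, 128, 256, 512) tuple lookup.
    return max(1 << math.ceil(math.log2(raw_total)), 32)

for raw in (24, 192, 200, 512):
    print(raw, "->", round_tmem_cols(raw))   # 24 -> 32, 192 -> 256, 200 -> 256, 512 -> 512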
Previous bound `64 if next_n == 1 else 52` was overly conservative: SASS spill check (cuobjdump --dump-sass | grep LDL/STL) under the 240-reg math warpgroup shows next_n=1,2 fit 64 weights/slot with zero spill; only next_n=3 spills above 56. Update to `56 if next_n == 3 else 64`. next_n=2 now caches 64/64 heads in registers (was 52/64), eliminating the SMEM fallback loop entirely and reducing LDS.128 traffic in the epilogue. next_n=3 increases from 52 to 56 (max safe value, still a multiple of 4 for the packed-FMA h_g..h_g+3 layout). Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
context_lens must be 2D and block_kv must be 64 (SPLIT_KV alignment with the compute kernel's hardcoded 256, see DSA_DG_C491439E_MIGRATION_NOTES.md). Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…xer_gemm_fp4_post_merge Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> # Conflicts: # tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py # tensorrt_llm/_torch/cute_dsl_kernels/blackwell/paged_mqa_logits/__init__.py
📝 Walkthrough
This PR adds two PyTorch custom op implementations of paged MQA logits computation for Blackwell SM100: one for FP8-quantized queries and one for FP4-packed queries with per-token scaling. Each variant includes dtype validation, TVM-FFI kernel compilation with caching, and fake tensor support. A comprehensive test module validates the FP4 variant against a pure PyTorch reference with quantization utilities and numerical tolerances.
Changes
Custom Ops & Kernel Module
FP4 Test Suite
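A minimal sketch of the compile-and-cache pattern the walkthrough refers to (the key fields and the compile entry point are placeholders, not the runner's actual signature):

import torch

_kernel_cache: dict = {}

def _get_or_compile(next_n: int, num_heads: int, epi_dtype: torch.dtype,
                    output_dtype: torch.dtype, num_epi_subtiles: int):
    # Cache compiled kernels on compile-time-relevant parameters only; runtime-varying
    # quantities (batch size, context length, page count) are handled via fake/dynamic
    # tensor shapes and stay out of the key.
    key = (next_n, num_heads, epi_dtype, output_dtype, num_epi_subtiles)
    if key not in _kernel_cache:
        _kernel_cache[key] = compile_fp4_kernel(*key)   # placeholder compile entry point
    return _kernel_cache[key]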
🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 3 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
Actionable comments posted: 3
🧹 Nitpick comments (2)
tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py (2)
441-479: ⚡ Quick win: Gate verbose accuracy probe logs to avoid CI log bloat.
These prints run on most passing cases and can heavily spam logs with this parameter sweep. Consider printing only when an error threshold is exceeded.
Proposed guard
-        if elem_abs.numel() > 0:
+        if elem_abs.numel() > 0 and elem_abs.max().item() > max(5 * atol, 1.0):
             kernel_valid = logits_clean[valid]
             ref_valid = ref_clean[valid]
             print(
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py` around lines 441 - 479, The probe printouts (those printing logits_clean/ref_clean, max/mean stats, and the large_err buckets) should be gated so they only run when the error threshold is exceeded to avoid CI log spam; change the logic around diff_abs, large_err and the subsequent prints so that you compute diff_abs and large_err as now but only emit all the detailed prints (the avg_ctx/epi/out prints, kernel/ref snippets, and bucket breakdowns using mod32_buckets/mod8_buckets) when len(large_err) > 0 or len(large_err) exceeds a configurable threshold (e.g., >0 or >N), or when a test-verbose flag/env var is set; keep the minimal failing assert messages outside the gate and reference the existing identifiers logits_clean, ref_clean, diff_abs, large_err, mod32_buckets, and mod8_buckets when moving the prints into that conditional.
54-67: ⚡ Quick win: Add explicit return type annotations to helper and test functions.
A few defs still omit return types (ceil_to_ue8m0, pack_ue8m0_to_int, kv_cache_cast_to_fp4, calc_diff, _ref_paged_mqa_logits, and the test signature). Please annotate them for consistency with repo typing rules.
Proposed typing cleanup
-def ceil_to_ue8m0(x: torch.Tensor):
+def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
@@
-def pack_ue8m0_to_int(x: torch.Tensor):
+def pack_ue8m0_to_int(x: torch.Tensor) -> torch.Tensor:
@@
-def kv_cache_cast_to_fp4(x: torch.Tensor):
+def kv_cache_cast_to_fp4(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
@@
-def calc_diff(x: torch.Tensor, y: torch.Tensor):
+def calc_diff(x: torch.Tensor, y: torch.Tensor) -> float:
@@
-def _ref_paged_mqa_logits(
+def _ref_paged_mqa_logits(
@@
-):
+) -> torch.Tensor:
As per coding guidelines, "Python code should use type annotations for all function arguments and return types; return type None if function does not return".
Also applies to: 142-194, 280-290
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py` around lines 54 - 67, Add explicit return type annotations to the helper and test functions mentioned: annotate ceil_to_ue8m0, unpack_ue8m0_from_int, pack_ue8m0_to_int, kv_cache_cast_to_fp4, calc_diff, and _ref_paged_mqa_logits with appropriate torch.Tensor return types (e.g., -> torch.Tensor) and annotate the test function signature to return None (-> None) to satisfy repository typing rules; update any imports if needed to reference torch in type hints and keep signatures consistent with existing parameter types.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 2fbf6591-7936-4c31-b269-c0e62650d129
📒 Files selected for processing (4)
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/paged_mqa_logits/__init__.py
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/paged_mqa_logits/fp4_paged_mqa_logits.py
tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py
def _check_fp4_paged_mqa_logits_dtypes(q, sf_q, kv_fused, weights,
                                       context_lens, block_table,
                                       schedule_meta, epi_dtype,
                                       output_dtype):
    errs = []
    if q.dtype != torch.uint8:
        errs.append(f"q must be uint8 (FP4 packed), got {q.dtype}")
    if sf_q.dtype != torch.int32:
        errs.append(f"sf_q must be int32, got {sf_q.dtype}")
    if kv_fused.dtype != torch.uint8:
        errs.append(f"kv_fused must be uint8, got {kv_fused.dtype}")
    if weights.dtype != torch.float32:
        errs.append(f"weights must be float32, got {weights.dtype}")
    if context_lens.dim() != 1:
        errs.append(f"context_lens must be 1D, got {context_lens.dim()}D")
    if context_lens.dtype != torch.int32:
        errs.append(f"context_lens must be int32, got {context_lens.dtype}")
    if block_table.dtype != torch.int32:
        errs.append(f"block_table must be int32, got {block_table.dtype}")
    if schedule_meta.dtype != torch.int32:
        errs.append(
            f"schedule_meta must be int32, got {schedule_meta.dtype}")
    for name, dt in [("epi_dtype", epi_dtype),
                     ("output_dtype", output_dtype)]:
        if dt not in (torch.float16, torch.bfloat16, torch.float32):
            errs.append(
                f"{name} must be float16, bfloat16, or float32, got {dt}")
    if errs:
        raise ValueError("FP4 Paged MQA Logits dtype errors:\n  " +
                         "\n  ".join(errs))
Validate FP4 paged-MQA shapes before the reshapes.
CuteDSLFP4PagedMQALogitsRunner.forward() later reinterprets q, sf_q, and weights as [B, N, ...] purely via reshape(). Right now this helper only checks dtypes, so a same-numel but differently ordered tensor can pass validation and silently scramble the token/head mapping instead of failing fast. Please validate the expected ranks and cross-tensor shapes here (q.ndim == 4, sf_q.shape == q.shape[:3], weights.shape == (B * next_n, H), context_lens.shape == (B,), block_table.shape[0] == B, kv_fused.shape[-1] == q.shape[-1] + 4, etc.).
Suggested guard pattern
 def _check_fp4_paged_mqa_logits_dtypes(q, sf_q, kv_fused, weights,
                                        context_lens, block_table,
                                        schedule_meta, epi_dtype,
                                        output_dtype):
     errs = []
 @@
     if schedule_meta.dtype != torch.int32:
         errs.append(
             f"schedule_meta must be int32, got {schedule_meta.dtype}")
+    if q.dim() != 4:
+        errs.append(f"q must be 4D [B, next_n, H, D//2], got {q.dim()}D")
+    else:
+        batch_size, next_n, num_heads, half_head_dim = q.shape
+        if sf_q.shape != (batch_size, next_n, num_heads):
+            errs.append(
+                f"sf_q must have shape {(batch_size, next_n, num_heads)}, got {tuple(sf_q.shape)}"
+            )
+        if weights.shape != (batch_size * next_n, num_heads):
+            errs.append(
+                f"weights must have shape {(batch_size * next_n, num_heads)}, got {tuple(weights.shape)}"
+            )
+        if context_lens.shape != (batch_size,):
+            errs.append(
+                f"context_lens must have shape {(batch_size,)}, got {tuple(context_lens.shape)}"
+            )
+        if block_table.dim() != 2 or block_table.shape[0] != batch_size:
+            errs.append(
+                f"block_table must be 2D with batch dimension {batch_size}, got {tuple(block_table.shape)}"
+            )
+        if kv_fused.dim() != 4 or kv_fused.shape[2] != 1 or kv_fused.shape[3] != half_head_dim + 4:
+            errs.append(
+                "kv_fused must have shape [num_blocks, phys_block_kv, 1, D//2 + 4]"
+            )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py` around lines 6751 -
6780, The helper _check_fp4_paged_mqa_logits_dtypes currently only validates
dtypes; add explicit shape/rank checks to fail fast before the later reshapes in
CuteDSLFP4PagedMQALogitsRunner.forward: verify q.ndim == 4, sf_q.ndim == 3 and
sf_q.shape == q.shape[:3], kv_fused.ndim >= 1 and kv_fused.shape[-1] ==
q.shape[-1] + 4, weights.ndim == 2 and weights.shape == (q.shape[0] *
q.shape[1], q.shape[2]) or the intended (B * next_n, H) pattern,
context_lens.shape == (q.shape[0],) and context_lens.dim() == 1, and
block_table.shape[0] == q.shape[0]; append clear error messages to errs and
raise as done for dtype issues so mis-shaped but same-numel tensors fail early.
def cute_dsl_fp4_paged_mqa_logits(
    q: torch.Tensor,
    sf_q: torch.Tensor,
    kv_fused: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_table: torch.Tensor,
    schedule_meta: torch.Tensor,
    max_context_len: int,
    num_epi_subtiles: int = 1,
    epi_dtype: torch.dtype = torch.float32,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    if not is_sm_100f():
        raise ValueError(
            f"CuteDSL: SM version {get_sm_version()} is not supported. "
            f"CuteDSL FP4 Paged MQA Logits only supports SM 100 family.")
    _check_fp4_paged_mqa_logits_dtypes(q, sf_q, kv_fused, weights,
                                       context_lens, block_table,
                                       schedule_meta, epi_dtype,
                                       output_dtype)
    return CuteDSLFP4PagedMQALogitsRunner.forward(
        q,
        sf_q,
        kv_fused,
        weights,
        context_lens,
        block_table,
        schedule_meta,
        max_context_len,
        num_epi_subtiles=num_epi_subtiles,
        epi_dtype=epi_dtype,
        output_dtype=output_dtype)
Reject unsupported num_epi_subtiles values at the API boundary.
The docstring constrains num_epi_subtiles to 1, 2, or 4, but the custom op currently accepts any integer and forwards it straight into kernel compilation/cache lookup. Invalid values will surface as a much harder-to-diagnose JIT/kernel failure.
Suggested guard
     if not is_sm_100f():
         raise ValueError(
             f"CuteDSL: SM version {get_sm_version()} is not supported. "
             f"CuteDSL FP4 Paged MQA Logits only supports SM 100 family.")
+    if num_epi_subtiles not in (1, 2, 4):
+        raise ValueError(
+            f"num_epi_subtiles must be one of (1, 2, 4), got {num_epi_subtiles}"
+        )
     _check_fp4_paged_mqa_logits_dtypes(q, sf_q, kv_fused, weights,
                                        context_lens, block_table,
                                        schedule_meta, epi_dtype,
                                        output_dtype)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py` around lines 6984 -
7016, In cute_dsl_fp4_paged_mqa_logits validate the num_epi_subtiles parameter
at the API boundary (inside the function before calling
CuteDSLFP4PagedMQALogitsRunner.forward) and raise a clear ValueError if it is
not one of the allowed values {1, 2, 4}; include the offending value in the
error message to aid debugging and prevent invalid integers from being forwarded
into kernel compilation/cache lookup.
    batch_size, next_n, num_heads, dim = q.size()
    num_block, block_size, _, dim = kv_cache.size()
Remove unused unpacked variables in _ref_paged_mqa_logits.
num_heads and num_block are unpacked but never used (RUF059). Rename to underscored bindings (or drop them) to keep lint clean.
Proposed fix
- batch_size, next_n, num_heads, dim = q.size()
- num_block, block_size, _, dim = kv_cache.size()
+ batch_size, next_n, _num_heads, dim = q.size()
+ _num_block, block_size, _, dim = kv_cache.size()
🧰 Tools
🪛 Ruff (0.15.12)
[warning] 203-203: Unpacked variable num_heads is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
[warning] 204-204: Unpacked variable num_block is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py`
around lines 203 - 204, In _ref_paged_mqa_logits, the tuple unpacking binds
num_heads and num_block but they are unused; change those bindings to
underscore-prefixed names (e.g., _num_heads and _num_block) or use plain
underscores to avoid unused-variable lint warnings; update the unpacking lines
that deconstruct q.size() and kv_cache.size() accordingly so only needed names
(batch_size, next_n, dim, block_size, _) remain as meaningful identifiers.
…/scripts/ Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…n correctness check Move the remaining standalone-driver pieces from fp4_paged_mqa_logits.py to run_fp4.py (FP4 quant helpers, _compute_schedule_metadata, the compile cache + fake-tensor _compile_fp4_kernel, host wrapper fp4_paged_mqa_logits, and the pure-torch _ref_paged_mqa_logits). The kernel file now contains only the FP4MQALogitsKernel class. Replace the runner's cosine-only PASS criterion with the element-wise atol+rtol check used by the unit test (ELEM_TOL table). Cosine diff is kept as a supplementary global metric. Output now reports max_abs, mean_abs, atol, rtol alongside the diff. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
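A minimal sketch of the element-wise PASS criterion plus supplementary cosine diff described above (tolerance values as quoted in this PR; names and printout format are illustrative):

import torch

ELEM_TOL = {torch.float32: (5e-5, 1e-5), torch.float16: (1e-3, 1e-3)}  # illustrative table

def report(out: torch.Tensor, ref: torch.Tensor) -> None:
    atol, rtol = ELEM_TOL[out.dtype]
    abs_err = (out.float() - ref.float()).abs()
    # Cosine diff is kept only as a supplementary global metric.
    cos = torch.nn.functional.cosine_similarity(out.float().flatten(),
                                                ref.float().flatten(), dim=0)
    ok = torch.allclose(out.float(), ref.float(), atol=atol, rtol=rtol)
    print(f"max_abs={abs_err.max().item():.3e} mean_abs={abs_err.mean().item():.3e} "
          f"atol={atol} rtol={rtol} diff={1.0 - cos.item():.3e} "
          f"{'PASS' if ok else 'FAIL'}")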
…scripts/ Mirror run_fp4.py with FP8-specific data prep / reference inlined from the unit test (test_cute_dsl_fp8_paged_mqa_logits.py) and the compile + dispatch path inlined from cute_dsl_custom_ops.py (CuteDSLPagedMQALogitsRunner). Schedule metadata reuses the same pure Python implementation as run_fp4.py — both kernels share compute_block_kv=128 + NUM_MATH_WG=2 → SPLIT_KV=256. Correctness check matches the unit test: element-wise atol+rtol from output_dtype (fp32: 5e-5/1e-5, fp16: 1e-3/1e-3) plus cosine diff as a supplementary global metric. The runner prints diagnostic stats (max_abs, mean_abs, atol, rtol, diff) and PASS/FAIL without raising, matching the FP4 runner's "show all results" semantics. Not wired into CI — same convention as the other launch scripts under tests/scripts/cute_dsl_kernels/. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Summary by CodeRabbit
New Features
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.