
[TRTLLM-35237][feat] Add cute dsl FP4 paged MQA logits decode kernel#13929

Open
limin2021 wants to merge 42 commits into NVIDIA:main from limin2021:add_dsl_indexer_gemm_fp4_post_merge

Conversation


@limin2021 limin2021 commented May 9, 2026

Summary by CodeRabbit

  • New Features

    • Added optimized paged MQA (Multi-Query Attention) logits kernel implementations for Blackwell SM100 GPUs, enabling efficient execution with both FP8-quantized and FP4-quantized model weights.
  • Tests

    • Added comprehensive test coverage for FP4-quantized paged MQA logits computations with numerical validation and tolerance verification across multiple batch configurations.

# accuracy ut
python -m pytest tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py -v
# performance ut
xx
# standalone runner [doesn't depend on trtllm env]
python tests/scripts/cute_dsl_kernels/paged_mqa_logits/run_fp4.py

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

…M100

Replace DeepGEMM-based indexer logits with a CuTE DSL kernel on SM100+,
gated by `use_cute_dsl_logits` config flag. Includes kernel implementation,
PyTorch custom op registration, config plumbing, and unit tests.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Replace stale development script name with the actual module filename.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
- Use current_stream() instead of creating a new stream to avoid data races
- Remove unnecessary torch.cuda.synchronize()
- Add is_sm_100f() check in custom op to fail fast on unsupported GPUs
- Rename FP8MQALogitsDGFullKKernel to FP8MQALogitsKernel
- Add __all__ to paged_mqa_logits __init__.py
- Fix test skip logic to precisely match SM 100/103 only

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…s enabled

Broaden the cute_dsl_custom_ops import guard to also trigger when
use_cute_dsl_logits is True, so the custom op is registered even
when use_cute_dsl_topk is False.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…_logits

Align config field name with the op name cute_dsl_fp8_paged_mqa_logits
for clarity.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
The custom op already validates SM 100/103 via is_sm_100f(). Also
replace getattr with direct attribute access to match use_cute_dsl_topk
style.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…ser-facing strings

Remove DeepGEMM and DG-FullK references from module docstring, class
docstring, function docstrings, print statements, and argparse
descriptions. Inline implementation comments retained as design
cross-references.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Replace from_dlpack + mark_compact_shape_dynamic compile pattern with
make_fake_compact_tensor + TVM FFI in CuteDSLPagedMQALogitsRunner.
This eliminates real tensor data at compile time, removes per-call
dlpack wrapping and manual CUDA stream passing at runtime.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…pper

Add _check_fp8_paged_mqa_logits_dtypes helper that validates all tensor
and compute dtypes upfront, collecting all errors into a single ValueError
so callers see every mismatch at once rather than one at a time.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
- Add epi_dtype param to PyTorch ref so epilogue matches kernel precision
- Add use_int_data mode for fp16 tests to isolate bugs from FP16 rounding
- Tighten tolerances: fp32 atol 5e-3→5e-5, fp16 atol 1e-2→1e-3 with rtol
- Use torch.testing.assert_close instead of cosine similarity
- Add varlen benchmark support for mixed-length serving workloads
- Clean up comments and remove stale parametrize line

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…e map

Remove broken run_test/parse_args/__main__ block and its helper functions
(make_fused_kv, fused_kv_views, _prepare_inputs, _make_dynamic_dlpacks,
_get_or_compile_kernel, dsl_fp8_paged_mqa_logits_dg_fullk) from
fp8_paged_mqa_logits.py — all had no callers after the missing
paged_mqa_logits_helpers module made them non-functional. Accuracy
validation is covered by the unit tests.

Reuse existing module-level _TORCH_TO_CUTLASS_DTYPE in
CuteDSLPagedMQALogitsRunner instead of duplicating it.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…ubtiles

Move the FMA unroll granularity check outside the num_epi_subtiles > 1
guard so it also applies when num_epi_subtiles == 1.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Remove is_umma_warp, num_q_stages (local copy), cons_group,
num_tma_prod, num_mcast_a/b, and tCtAcc_fake (in __call__) which
were assigned but never referenced. Remove unnecessary noqa F401
on pipeline_init_arrive/pipeline_init_wait which are actively used.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…nfig

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…arbitrary next_n

Fix cudaErrorMisalignedAddress when num_heads=32 with fp16 output by
padding sW stage stride to 128-byte boundary for TMA bulk copy. Update
qw_per_stage SMEM budget to use padded stride and actual epi_bytes.

In dsa.py, skip MTP batch expansion for DSL kernel (supports arbitrary
next_n natively) and use correct scheduler_metadata buffer selection.

Expand test coverage with num_heads and fix_length parametrization,
pure PyTorch reference, and DG expanded benchmark for next_n=3.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Clamp start_q before reading mContextLens to prevent out-of-bounds
access when the scheduler assigns batch_size as a sentinel for CTAs
with no work. Move torch.manual_seed into _generate_test_data for
reproducibility.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…aged MQA logits

Decouple compute_block_kv (always 128) from phys_block_kv (physical page
size). When phys_block_kv < 128, the kernel issues num_blocks_per_mma
TMA copies per compute tile. Removes the tokens_per_block=128 constraint
for the DSL kernel path.
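The decoupling above can be sketched in a few lines. Names here are illustrative, not the kernel's actual identifiers; the only assumption taken from the commit message is that the compute tile is fixed at 128 tokens and smaller physical pages are covered by multiple TMA copies per tile:

```python
# Fixed compute tile size in KV tokens, per the commit message.
COMPUTE_BLOCK_KV = 128

def tma_copies_per_tile(phys_block_kv: int) -> int:
    # When the physical page is smaller than the compute tile, the kernel
    # issues multiple TMA copies to fill one compute tile.
    assert COMPUTE_BLOCK_KV % phys_block_kv == 0
    return COMPUTE_BLOCK_KV // phys_block_kv

for page in (32, 64, 128):
    print(page, tma_copies_per_tile(page))  # 32 -> 4, 64 -> 2, 128 -> 1
```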

Signed-off-by: Mindy Li <lmin@nvidia.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…eduling

Hoist shuffle_sync calls before producer_acquire in both TMA warpgroups.
The previous ordering placed shuffle after barrier acquire, causing a
dependency chain that regressed next_n=3 by 5-7%. SASS now matches
baseline exactly.

Signed-off-by: Mindy Li <lmin@nvidia.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <lmin@nvidia.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…and test

Drop dead code flagged in review:
- compute_schedule_metadata + cdiv (and now-unused torch import) from the
  kernel module — never called.
- skip_if_unsupported decorator + has_deep_gemm helper + IS_CUTLASS_DSL_AVAILABLE
  import from the test — defined but never applied. SM100/103 gate already
  ensures DeepGEMM and CuTe DSL availability.

Also refresh the test docstring: the reference is pure PyTorch, not DeepGEMM.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
The DSL variant of test_indexer_decode_with_paged_kv_cache calls
torch.ops.trtllm.cute_dsl_fp8_paged_mqa_logits, which raises ValueError
on SM != 100/103. Mark only the "dsl" parameter with a skipif so the
"deepgemm" backend still runs on Hopper.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…xer_gemm

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>

# Conflicts:
#	tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
…xer_gemm

Resolve conflicts in DSA paged-MQA-logits dispatch and tests after
DeepGEMM submodule bump (4ff3f54d -> c491439e via PR NVIDIA#13340 / DG NVIDIA#304):

- dsa.py: take upstream's scheduler_metadata_buffer / _full_next_n
  selection (mtp3 buffer removed); add DSL early-branch using the
  existing scheduler_metadata_buffer (built with (num_gen, 1) shape,
  num_atoms=1, matching DSL's 1-atom-per-q design) and the 1D
  kv_lens_cuda_runtime slice for context_lens.
- dsa.py: introduce module-level _DG_SCHEDULE_BLOCK_KV = 64, used by
  all 6 get_paged_mqa_logits_metadata calls (3 in on_update_kv_lens(),
  3 in Indexer.prepare()) instead of cache tokens_per_block.
  Decouples schedule SPLIT_KV from cache page size and side-steps a
  SM100 + block_kv=32 latent regression in DG commit 7f2a703 (NVIDIA#304).
- test_dsa_indexer.py: take upstream's scheduler buffer selection;
  DSL test branch reads scheduler_metadata_buffer + 1D kv_lens.
- test_cute_dsl_fp8_paged_mqa_logits.py: 4 metadata calls now pass
  2D context_lens via .unsqueeze(-1) and DG_METADATA_BLOCK_KV=64;
  DG bench drops cluster(2,1,1) for next_n=4 (SM100 always uses
  num_kv_multicast=1) and passes 2D context_lens to fp8_paged_mqa_logits.
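The 2D `context_lens` shape change mentioned above is a one-line view operation. A minimal illustration (tensor values are arbitrary):

```python
import torch

# A 1D per-sequence length vector becomes the 2D shape the newer metadata
# calls expect via unsqueeze; no data is copied.
context_lens = torch.tensor([256, 4096, 32768], dtype=torch.int32)  # shape [B]
context_lens_2d = context_lens.unsqueeze(-1)                        # shape [B, 1]

print(context_lens_2d.shape)  # torch.Size([3, 1])
```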

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
… MQA next_n coverage

test_dsa_indexer: separate scheduler_metadata_buffer for DSL backend (kNumNextNAtoms=1)
to avoid the next_n>1 alias used by DeepGEMM. Fixes test_indexer_decode_with_paged_kv_cache
across {deepgemm, dsl} x {(4,1),(2,2),(4,3),(4,4)}.

test_cute_dsl_fp8_paged_mqa_logits: extend multi_block next_n coverage to {1,2,3,4};
drop now-obsolete next_n==3 expansion in bench (newer DeepGEMM supports it natively);
update --block_kv default from 128 to 64 to match the new {32,64} assertion in
fp8_paged_mqa_logits.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…ress)

Snapshot of FP4 MXFP4 block-scaled paged_mqa_logits kernel based on
the FP8 kernel: 5 TMA atoms (KV + B + W + SF_KV + SF_Q), per-WG SFA
TMEM region, in-place SMEM transpose for UTCCP, per-k_block SFA/SFB
slice for cute.gemm.

Stage 0 still hangs at tmem.wait_for_alloc bar.sync 1, 320; committing
state for further debugging.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
limin2021 added 8 commits May 8, 2026 07:59
…ble next_n=3

Three independent fixes to the CuTe DSL FP4 paged MQA logits kernel and its
host wrapper:

1. SF Q TMA barrier tx_count was undersized.
   The TMA descriptor for SF Q is built with tile = N_padded (from
   sf_q_smem_layout_staged), and the host wrapper pads GMEM to N_padded
   bytes. So TMA actually fetches N_padded * 4 bytes, not N * 4. Setting
   the barrier transaction count to N * 4 caused the consumer (UMMA warp)
   to start the in-place SMEM transpose + UTCCP before TMA finished
   writing the trailing pad bytes — a race that produced ~40% flaky hangs
   for next_n=1 (where N=64 < N_padded=128). Fix: tx_count = N_padded * 4.
   Verified with 10/10 hang-repro runs all PASS in ~3.7s each.

2. Enable next_n=3 by skipping the per-fragment power-of-2 round.
   utils.get_num_tmem_alloc_cols(..., rounding=True) rounded acc cols
   192 -> 256 for next_n=3, pushing raw_total to 544 (>512 cap). The
   helper-level round is for bookkeeping only; the actual hardware
   alloc happens at the kernel-level tmem.allocate(num_tmem_alloc_cols_total)
   which has its own power-of-2 round on the total. Per-fragment counts
   only need 32-byte alignment (192 = 32*6 OK). Pass rounding=False.

3. Drop redundant .contiguous() in host wrapper to fix B>1 stride mismatch.
   The compile-time fake tensors use stride_order=(1, 0, 2) (Q) and
   (0, 1) (sf_q, weights), making the innermost stride 1 and outer
   strides B-independent. The runtime wrapper's .permute(1, 2, 0).contiguous()
   repacked memory into a contiguous layout where B becomes innermost,
   producing strides like (half_D*B, B, 1) — mismatching the fake's
   (half_D, 1, N*half_D). Removing .contiguous() keeps the permute as
   a view with the expected B-independent strides. Same change for
   sf_q and weights. SF Q padding restructured to pad in the (B, N) layout
   first (preserves transpose-then-view stride pattern). Mirrors the FP8
   wrapper which already does it correctly.

Result: full sweep over (batch_size, next_n, ctx) goes from
7 PASS / 20 FAIL to 17 PASS / 10 FAIL; all remaining failures are
numerical mismatches at large ctx + B>1 + next_n>=2 (separate
pre-existing bugs, to be investigated separately).
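The stride mismatch in item 3 is easy to reproduce in plain PyTorch. Shapes below are small illustrative stand-ins for the kernel's actual dimensions: `permute` alone keeps the B-independent strides the fake tensors expect, while `.contiguous()` repacks memory so B becomes innermost:

```python
import torch

B, N, half_D = 4, 64, 8
q = torch.randn(B, N, half_D)  # strides (N*half_D, half_D, 1)

view = q.permute(1, 2, 0)      # shape [N, half_D, B], no copy
packed = view.contiguous()     # repacked: B becomes the innermost stride

print(view.stride())    # (half_D, 1, N * half_D) -- B-independent outer strides
print(packed.stride())  # (half_D * B, B, 1)      -- B-dependent, mismatches the fake
```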

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
The DSL FP4 kernel splits DeepGEMM's single UMMA warp into two
(umma_warp_0 and umma_warp_1). Both warps' MMAs read the same SFB
(Q SF) TMEM region, but only warp 0 writes it on q_idx transitions.
The original code relied on q_pipeline.consumer_wait for cross-warp
visibility, which only orders TMA->SMEM and not warp-0->warp-1 TMEM
writes. Two distinct hazards followed:

1. Forward visibility. umma_warp_1 could fire its MMA before
   umma_warp_0's async s2t to TMEM SFB landed, reading stale SFB
   from the prior batch. Manifested as deterministic numerical
   mismatches in B>1 cases at large ctx (e.g. ctx=32768 + B=4
   always failed).

2. Backward overwrite. SFB TMEM is a single region (no staging).
   Once the forward fix synced warp 1 against the new SFB, warp 0
   could still race ahead to the next q transition and overwrite
   SFB while warp 1's previous-batch MMA was still reading it.
   Manifested as flaky B=16 failures in the extended sweep
   (5/2/2 across three runs).

Fix: add NamedBarrier(id=2, 64 threads) used at two points per
umma warp loop iteration:
- After umma_warp_0's SFB s2t (with fence_view_async_tmem_store)
  and before umma_warp_1's MMA on q transition: forward visibility.
- At the end of every iteration in both umma warps after
  producer_commit: lock-steps the two warps so warp 0's next-iter
  SFB write cannot precede warp 1's current-iter MMA commit.

Also expand the test sweep to exercise paths previously gated off:
- avg_ctx adds 8192 and 16384 (was {256, 4096, 32768})
- num_epi_subtiles enables 2 and 4 (was 1 only)
- fix_length covers both True and False (was True only)

Result: 270/270 pass x 3 consecutive runs in the main test;
90/90 pass in the multi-block test (phys_block_kv in {32, 64}).
Per-tile barrier overhead is negligible vs MMA + TMEM ops
(~12s -> ~12s in the 90-case sweep, ~35s for the 270-case sweep).

DeepGEMM avoids this entirely by issuing both groups' UMMAs from
a single UMMA warp; sm100_fp4_paged_mqa_logits.cuh:337-365 has the
'for i in kNumMathWarpGroups' inner loop in one warp.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Add phys_block_kv parametrize ([32, 64, 128]) to the main test so the
paged multi-block TMA path (phys_block_kv < compute tile = 128,
NUM_BLOCKS_PER_MMA > 1) is exercised alongside the single-block path
in one place. Remove the now-redundant
test_cute_dsl_fp4_paged_mqa_logits_multi_block function — its only
unique coverage was batch_size=32, which is scheduler stress rather
than a new code path; the main test's [1, 4, 16] already exercises
the same kernel logic.

Also narrow the active epi_dtype/output_dtype set to only (fp32, bf16)
for now to keep the sweep size manageable; the other four dtype
combinations are list-commented and trivially re-enabled.

Tested: 810/810 PASS in ~102s (5 ctx x 3 phys_block_kv x 1 dtype x
3 subtile x 3 next_n x 3 batch_size x 2 fix_length).

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Match DeepGEMM by passing real N (= next_n * num_heads) as the SF Q
TMA descriptor tile shape instead of N_padded, so the host wrapper no
longer needs to zero-pad sf_q from N to N_padded on every kernel call.
This eliminates a per-launch torch.zeros + slice copy and aligns the
DSL kernel's SF Q layout with DeepGEMM's reference
(sm100_fp4_paged_mqa_logits.cuh:202, tma::copy<kRealNumSFQAtom, 1, 0>).

The DSL TMA helper (cpasync.make_tiled_tma_atom) requires the
smem_layout passed in to be cosize-symmetric with cta_tiler -- it
rejects "SMEM tile > TMA tile" outright. Trick: build the TMA atom
with a smaller (N, num_q_stages) smem layout that matches cta_tiler
= (N,), but allocate the actual SMEM at (N_padded, num_q_stages) for
UTCCP's 128-aligned atom requirement. The two layouts share the same
SMEM iterator via a logical view (sSF_Q_for_tma) used only at
tma_partition; UTCCP and downstream paths still see the full
N_padded SMEM region. Positions [N, N_padded) in each stage are
left as garbage post-launch -- never read because UMMA_N = N and
the math epilogue only writes acc cols [0, N). Same invariant as
DeepGEMM's "kNumSFQAtom-padded SMEM but kRealNumSFQAtom-fetching
TMA" pattern.

Also barrier tx_count drops from N_padded * 4 bytes to N * 4 bytes
to match the actual TMA transfer.

Test sweep updated: enable (fp32, bf16) Stage 2 cast path; the other
four dtype combos remain list-commented. Stage 1 packed FMA paths
(bf16/bf16, fp16/fp16) have a separate pre-existing accuracy issue
unrelated to SF Q layout (looks like ELEM_TOL atol set tighter than
the actual 64-head packed-FMA accumulation precision); to be
investigated separately.

Tested: 810/810 PASS in ~106s
(5 ctx x 3 phys_block_kv x 3 num_epi_subtiles x 3 next_n x
3 batch_size x 2 fix_length x 1 dtype).

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Two small cleanups in the kernel constructor:

1. Replace the manual `(32, 64, 128, 256, 512)` tuple lookup for
   power-of-two TMEM rounding with the same one-line formula used
   inside utils.get_num_tmem_alloc_cols(rounding=True):

       self.num_tmem_alloc_cols_total =
           max(1 << math.ceil(math.log2(raw_total)), 32)

   Functionally identical (raw_total <= 512 always for next_n <= 3),
   but reads as the actual rounding semantics and stops needing
   maintenance if the upper bound ever changes. The DSL helper
   itself can't be called here because we don't yet have a tmem
   tensor handle at this point; we only have the raw column count.

2. Replace the now-stale TMA SF Q comment that still talked about
   "host pad while host pad still exists" (host pad was removed in
   the previous commit) with a concise statement of the invariant:
   SMEM positions [N, N_padded) are garbage, propagate through UTCCP
   into TMEM SFB cols >= N, but are never read by MMA (UMMA_N = N)
   nor written by the epilogue (acc cols [0, N) only).

No functional change. Smoke test on next_n in {1, 3} passes.
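The one-line rounding formula from item 1 can be checked standalone:

```python
import math

def round_tmem_cols(raw_total: int) -> int:
    # Round the raw TMEM column count up to the next power of two,
    # with a floor of 32 columns (same formula the commit message cites).
    return max(1 << math.ceil(math.log2(raw_total)), 32)

print(round_tmem_cols(192))  # 256 (the next_n=3 case discussed above)
print(round_tmem_cols(512))  # 512 (already a power of two)
print(round_tmem_cols(1))    # 32  (floor applies)
```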

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Previous bound `64 if next_n == 1 else 52` was overly conservative:
SASS spill check (cuobjdump --dump-sass | grep LDL/STL) under the 240-reg
math warpgroup shows next_n=1,2 fit 64 weights/slot with zero spill;
only next_n=3 spills above 56. Update to `56 if next_n == 3 else 64`.

next_n=2 now caches 64/64 heads in registers (was 52/64), eliminating
the SMEM fallback loop entirely and reducing LDS.128 traffic in the
epilogue. next_n=3 increases from 52 to 56 (max safe value, still a
multiple of 4 for the packed-FMA h_g..h_g+3 layout).

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
context_lens must be 2D and block_kv must be 64 (SPLIT_KV alignment with
the compute kernel's hardcoded 256, see DSA_DG_C491439E_MIGRATION_NOTES.md).

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
@limin2021 limin2021 requested review from a team as code owners May 9, 2026 04:26
…xer_gemm_fp4_post_merge

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>

# Conflicts:
#	tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
#	tensorrt_llm/_torch/cute_dsl_kernels/blackwell/paged_mqa_logits/__init__.py
@limin2021 limin2021 changed the title Add dsl indexer gemm fp4 post merge [TRTLLM-34871][feat] Add cute dsl FP4 paged MQA logits decode kernel May 9, 2026
@limin2021 limin2021 changed the title [TRTLLM-34871][feat] Add cute dsl FP4 paged MQA logits decode kernel [TRTLLM-35237][feat] Add cute dsl FP4 paged MQA logits decode kernel May 9, 2026

coderabbitai Bot commented May 9, 2026


📝 Walkthrough

This PR adds two PyTorch custom op implementations of paged MQA logits computation for Blackwell SM100: one for FP8-quantized queries and one for FP4-packed queries with per-token scaling. Each variant includes dtype validation, TVM-FFI kernel compilation with caching, and fake tensor support. A comprehensive test module validates the FP4 variant against a pure PyTorch reference with quantization utilities and numerical tolerances.

Changes

Custom Ops & Kernel Module

Layer / File(s) Summary
Dtype Validation
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
FP8 variant validates q as FP8 E4M3FN, kv_fused as uint8, weights as float32, metadata as int32, and constrains epi_dtype, acc_dtype, output_dtype to float16/float32. FP4 variant validates q as uint8 (FP4-packed), sf_q as int32, kv_fused as uint8, weights as float32, and constrains epi_dtype, output_dtype to float16/bfloat16/float32.
FP8 Kernel Runner
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
CuteDSLPagedMQALogitsRunner compiles and caches FP8 kernel via TVM-FFI with fake tensors, reshapes q to [N, D, B] layout, reshapes weights to [N, B], flattens kv_fused, allocates aligned output buffer, and executes compiled kernel. Custom op wrapper validates SM100 support and dispatches to runner.
FP4 Kernel Runner
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
CuteDSLFP4PagedMQALogitsRunner compiles and caches FP4 kernel via TVM-FFI, reshapes q to [N, D/2, B] without forcing contiguity, reshapes sf_q to [N, B], reshapes/casts weights based on epi_dtype, flattens kv_fused, allocates aligned output, and executes kernel with quantized inputs. Custom op wrapper validates SM100 support.
Fake Registrations
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
Both FP8 and FP4 variants register fake tensor handlers returning [B*next_n, max_context_len] tensors with output_dtype.
Module Exports
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/paged_mqa_logits/__init__.py
Imports FP4MQALogitsKernel and FP8MQALogitsKernel, exports both via __all__.

FP4 Test Suite

Layer / File(s) Summary
Test Infrastructure & Quantization
tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py
SM100/103-restricted test module with alignment helpers, UE8M0 scale-factor packing/unpacking, FP4 (e2m1) quantization/dequantization lookup, per-token FP4 casting with group-wise scaling and nibble packing, and FP4 unpacking with per-group scale application. kv_cache_cast_to_fp4 converts KV cache to fused FP4 layout.
Reference Implementation & Metrics
tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py
Pure PyTorch _ref_paged_mqa_logits computes paged-MQA logits with causal/context masking, ReLU activation, and per-head weighting. calc_diff provides cosine-style error metric. ELEM_TOL maps (epi_dtype, output_dtype) pairs to (atol, rtol) thresholds.
Parametrized Test & Validation
tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py
Test sweeps batch sizes, next_n, num_heads=64, context lengths, physical KV block sizes, epilogue dtype combinations, and epi subtile counts. Setup quantizes Q/KV, builds block tables and metadata, runs reference and kernel, masks invalid positions, and asserts closeness via torch.testing.assert_close and cosine-diff threshold (0.02).
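The e2m1 (FP4) quantization lookup the test helpers implement can be sketched as below. This is a hedged illustration, not the actual test code: the real helpers also handle UE8M0 group scale factors and nibble packing, which are omitted here. The eight non-negative magnitudes are the standard e2m1 representable values:

```python
import torch

# The 8 non-negative magnitudes representable in e2m1 (1 sign, 2 exp, 1 mantissa bit).
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_e2m1(x: torch.Tensor) -> torch.Tensor:
    # Nearest-value lookup over the e2m1 table; sign is handled separately.
    sign = torch.sign(x)
    idx = (x.abs().unsqueeze(-1) - E2M1_VALUES).abs().argmin(dim=-1)
    return sign * E2M1_VALUES[idx]

print(quantize_e2m1(torch.tensor([0.6, 2.4, -7.0])))  # tensor([ 0.5,  2.0, -6.0])
```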

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 23.81% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Description check ⚠️ Warning PR description is missing critical sections. Only test commands and PR checklist template are present, without actual descriptions of changes, rationale, or test coverage details. Add a clear Description section explaining what changes were made and why, a Test Coverage section listing relevant tests that safeguard the changes, and complete the PR checklist items. See template for required sections.
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The PR title clearly and specifically describes the main change: adding a CuTe DSL FP4 paged MQA logits decode kernel, with proper JIRA ticket reference and feature type tag.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.





@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py (2)

441-479: ⚡ Quick win

Gate verbose accuracy probe logs to avoid CI log bloat.

These prints run on most passing cases and can heavily spam logs with this parameter sweep. Consider printing only when an error threshold is exceeded.

Proposed guard
-    if elem_abs.numel() > 0:
+    if elem_abs.numel() > 0 and elem_abs.max().item() > max(5 * atol, 1.0):
         kernel_valid = logits_clean[valid]
         ref_valid = ref_clean[valid]
         print(
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py`
around lines 441 - 479, The probe printouts (those printing
logits_clean/ref_clean, max/mean stats, and the large_err buckets) should be
gated so they only run when the error threshold is exceeded to avoid CI log
spam; change the logic around diff_abs, large_err and the subsequent prints so
that you compute diff_abs and large_err as now but only emit all the detailed
prints (the avg_ctx/epi/out prints, kernel/ref snippets, and bucket breakdowns
using mod32_buckets/mod8_buckets) when len(large_err) > 0 or len(large_err)
exceeds a configurable threshold (e.g., >0 or >N), or when a test-verbose
flag/env var is set; keep the minimal failing assert messages outside the gate
and reference the existing identifiers logits_clean, ref_clean, diff_abs,
large_err, mod32_buckets, and mod8_buckets when moving the prints into that
conditional.

54-67: ⚡ Quick win

Add explicit return type annotations to helper and test functions.

A few defs still omit return types (ceil_to_ue8m0, pack_ue8m0_to_int, kv_cache_cast_to_fp4, calc_diff, _ref_paged_mqa_logits, and the test signature). Please annotate them for consistency with repo typing rules.

Proposed typing cleanup
-def ceil_to_ue8m0(x: torch.Tensor):
+def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
@@
-def pack_ue8m0_to_int(x: torch.Tensor):
+def pack_ue8m0_to_int(x: torch.Tensor) -> torch.Tensor:
@@
-def kv_cache_cast_to_fp4(x: torch.Tensor):
+def kv_cache_cast_to_fp4(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
@@
-def calc_diff(x: torch.Tensor, y: torch.Tensor):
+def calc_diff(x: torch.Tensor, y: torch.Tensor) -> float:
@@
-def _ref_paged_mqa_logits(
+def _ref_paged_mqa_logits(
@@
-):
+) -> torch.Tensor:

As per coding guidelines, "Python code should use type annotations for all function arguments and return types; return type None if function does not return".

Also applies to: 142-194, 280-290

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py`
around lines 54 - 67, Add explicit return type annotations to the helper and
test functions mentioned: annotate ceil_to_ue8m0, unpack_ue8m0_from_int,
pack_ue8m0_to_int, kv_cache_cast_to_fp4, calc_diff, and _ref_paged_mqa_logits
with appropriate torch.Tensor return types (e.g., -> torch.Tensor) and annotate
the test function signature to return None (-> None) to satisfy repository
typing rules; update any imports if needed to reference torch in type hints and
keep signatures consistent with existing parameter types.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py`:
- Around line 6984-7016: In cute_dsl_fp4_paged_mqa_logits validate the
num_epi_subtiles parameter at the API boundary (inside the function before
calling CuteDSLFP4PagedMQALogitsRunner.forward) and raise a clear ValueError if
it is not one of the allowed values {1, 2, 4}; include the offending value in
the error message to aid debugging and prevent invalid integers from being
forwarded into kernel compilation/cache lookup.
- Around line 6751-6780: The helper _check_fp4_paged_mqa_logits_dtypes currently
only validates dtypes; add explicit shape/rank checks to fail fast before the
later reshapes in CuteDSLFP4PagedMQALogitsRunner.forward: verify q.ndim == 4,
sf_q.ndim == 3 and sf_q.shape == q.shape[:3], kv_fused.ndim >= 1 and
kv_fused.shape[-1] == q.shape[-1] + 4, weights.ndim == 2 and weights.shape ==
(q.shape[0] * q.shape[1], q.shape[2]) or the intended (B * next_n, H) pattern,
context_lens.shape == (q.shape[0],) and context_lens.dim() == 1, and
block_table.shape[0] == q.shape[0]; append clear error messages to errs and
raise as done for dtype issues so mis-shaped but same-numel tensors fail early.

In
`@tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py`:
- Around lines 203-204: in _ref_paged_mqa_logits, the tuple unpacking binds
num_heads and num_block but they are unused; change those bindings to
underscore-prefixed names (e.g., _num_heads and _num_block) or use plain
underscores to avoid unused-variable lint warnings; update the unpacking lines
that deconstruct q.size() and kv_cache.size() accordingly so only needed names
(batch_size, next_n, dim, block_size, _) remain as meaningful identifiers.

---

Nitpick comments:
In
`@tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py`:
- Around line 441-479: The probe printouts (those printing
logits_clean/ref_clean, max/mean stats, and the large_err buckets) should be
gated so they only run when the error threshold is exceeded to avoid CI log
spam; change the logic around diff_abs, large_err and the subsequent prints so
that you compute diff_abs and large_err as now but only emit all the detailed
prints (the avg_ctx/epi/out prints, kernel/ref snippets, and bucket breakdowns
using mod32_buckets/mod8_buckets) when len(large_err) > 0 or len(large_err)
exceeds a configurable threshold (e.g., >0 or >N), or when a test-verbose
flag/env var is set; keep the minimal failing assert messages outside the gate
and reference the existing identifiers logits_clean, ref_clean, diff_abs,
large_err, mod32_buckets, and mod8_buckets when moving the prints into that
conditional.
- Around lines 54-67: add explicit return type annotations to the helper and test
functions mentioned: annotate ceil_to_ue8m0, unpack_ue8m0_from_int,
pack_ue8m0_to_int, kv_cache_cast_to_fp4, calc_diff, and _ref_paged_mqa_logits
with appropriate torch.Tensor return types (e.g., -> torch.Tensor) and annotate
the test function signature to return None (-> None) to satisfy repository
typing rules; update any imports if needed to reference torch in type hints and
keep signatures consistent with existing parameter types.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 2fbf6591-7936-4c31-b269-c0e62650d129

📥 Commits

Reviewing files that changed from the base of the PR and between c0d86d0 and 8a42a96.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/paged_mqa_logits/__init__.py
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/paged_mqa_logits/fp4_paged_mqa_logits.py
  • tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py

Comment on lines +6751 to +6780
def _check_fp4_paged_mqa_logits_dtypes(q, sf_q, kv_fused, weights,
                                       context_lens, block_table,
                                       schedule_meta, epi_dtype,
                                       output_dtype):
    errs = []
    if q.dtype != torch.uint8:
        errs.append(f"q must be uint8 (FP4 packed), got {q.dtype}")
    if sf_q.dtype != torch.int32:
        errs.append(f"sf_q must be int32, got {sf_q.dtype}")
    if kv_fused.dtype != torch.uint8:
        errs.append(f"kv_fused must be uint8, got {kv_fused.dtype}")
    if weights.dtype != torch.float32:
        errs.append(f"weights must be float32, got {weights.dtype}")
    if context_lens.dim() != 1:
        errs.append(f"context_lens must be 1D, got {context_lens.dim()}D")
    if context_lens.dtype != torch.int32:
        errs.append(f"context_lens must be int32, got {context_lens.dtype}")
    if block_table.dtype != torch.int32:
        errs.append(f"block_table must be int32, got {block_table.dtype}")
    if schedule_meta.dtype != torch.int32:
        errs.append(
            f"schedule_meta must be int32, got {schedule_meta.dtype}")
    for name, dt in [("epi_dtype", epi_dtype),
                     ("output_dtype", output_dtype)]:
        if dt not in (torch.float16, torch.bfloat16, torch.float32):
            errs.append(
                f"{name} must be float16, bfloat16, or float32, got {dt}")
    if errs:
        raise ValueError("FP4 Paged MQA Logits dtype errors:\n " +
                         "\n ".join(errs))

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate FP4 paged-MQA shapes before the reshapes.

CuteDSLFP4PagedMQALogitsRunner.forward() later reinterprets q, sf_q, and weights as [B, N, ...] purely via reshape(). Right now this helper only checks dtypes, so a same-numel but differently ordered tensor can pass validation and silently scramble the token/head mapping instead of failing fast. Please validate the expected ranks and cross-tensor shapes here (q.ndim == 4, sf_q.shape == q.shape[:3], weights.shape == (B * next_n, H), context_lens.shape == (B,), block_table.shape[0] == B, kv_fused.shape[-1] == q.shape[-1] + 4, etc.).

Suggested guard pattern
 def _check_fp4_paged_mqa_logits_dtypes(q, sf_q, kv_fused, weights,
                                        context_lens, block_table,
                                        schedule_meta, epi_dtype,
                                        output_dtype):
     errs = []
@@
     if schedule_meta.dtype != torch.int32:
         errs.append(
             f"schedule_meta must be int32, got {schedule_meta.dtype}")
+    if q.dim() != 4:
+        errs.append(f"q must be 4D [B, next_n, H, D//2], got {q.dim()}D")
+    else:
+        batch_size, next_n, num_heads, half_head_dim = q.shape
+        if sf_q.shape != (batch_size, next_n, num_heads):
+            errs.append(
+                f"sf_q must have shape {(batch_size, next_n, num_heads)}, got {tuple(sf_q.shape)}"
+            )
+        if weights.shape != (batch_size * next_n, num_heads):
+            errs.append(
+                f"weights must have shape {(batch_size * next_n, num_heads)}, got {tuple(weights.shape)}"
+            )
+        if context_lens.shape != (batch_size,):
+            errs.append(
+                f"context_lens must have shape {(batch_size,)}, got {tuple(context_lens.shape)}"
+            )
+        if block_table.dim() != 2 or block_table.shape[0] != batch_size:
+            errs.append(
+                f"block_table must be 2D with batch dimension {batch_size}, got {tuple(block_table.shape)}"
+            )
+        if kv_fused.dim() != 4 or kv_fused.shape[2] != 1 or kv_fused.shape[3] != half_head_dim + 4:
+            errs.append(
+                "kv_fused must have shape [num_blocks, phys_block_kv, 1, D//2 + 4]"
+            )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py` around lines
6751-6780, the helper _check_fp4_paged_mqa_logits_dtypes currently only validates
dtypes; add explicit shape/rank checks to fail fast before the later reshapes in
CuteDSLFP4PagedMQALogitsRunner.forward: verify q.ndim == 4, sf_q.ndim == 3 and
sf_q.shape == q.shape[:3], kv_fused.ndim >= 1 and kv_fused.shape[-1] ==
q.shape[-1] + 4, weights.ndim == 2 and weights.shape == (q.shape[0] *
q.shape[1], q.shape[2]) or the intended (B * next_n, H) pattern,
context_lens.shape == (q.shape[0],) and context_lens.dim() == 1, and
block_table.shape[0] == q.shape[0]; append clear error messages to errs and
raise as done for dtype issues so mis-shaped but same-numel tensors fail early.

Comment on lines +6984 to +7016
def cute_dsl_fp4_paged_mqa_logits(
    q: torch.Tensor,
    sf_q: torch.Tensor,
    kv_fused: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_table: torch.Tensor,
    schedule_meta: torch.Tensor,
    max_context_len: int,
    num_epi_subtiles: int = 1,
    epi_dtype: torch.dtype = torch.float32,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    if not is_sm_100f():
        raise ValueError(
            f"CuteDSL: SM version {get_sm_version()} is not supported. "
            f"CuteDSL FP4 Paged MQA Logits only supports SM 100 family.")
    _check_fp4_paged_mqa_logits_dtypes(q, sf_q, kv_fused, weights,
                                       context_lens, block_table,
                                       schedule_meta, epi_dtype,
                                       output_dtype)
    return CuteDSLFP4PagedMQALogitsRunner.forward(
        q,
        sf_q,
        kv_fused,
        weights,
        context_lens,
        block_table,
        schedule_meta,
        max_context_len,
        num_epi_subtiles=num_epi_subtiles,
        epi_dtype=epi_dtype,
        output_dtype=output_dtype)

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Reject unsupported num_epi_subtiles values at the API boundary.

The docstring constrains num_epi_subtiles to 1, 2, or 4, but the custom op currently accepts any integer and forwards it straight into kernel compilation/cache lookup. Invalid values will surface as a much harder-to-diagnose JIT/kernel failure.

Suggested guard
     if not is_sm_100f():
         raise ValueError(
             f"CuteDSL: SM version {get_sm_version()} is not supported. "
             f"CuteDSL FP4 Paged MQA Logits only supports SM 100 family.")
+    if num_epi_subtiles not in (1, 2, 4):
+        raise ValueError(
+            f"num_epi_subtiles must be one of (1, 2, 4), got {num_epi_subtiles}"
+        )
     _check_fp4_paged_mqa_logits_dtypes(q, sf_q, kv_fused, weights,
                                        context_lens, block_table,
                                        schedule_meta, epi_dtype,
                                        output_dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (full function with the new guard):

def cute_dsl_fp4_paged_mqa_logits(
    q: torch.Tensor,
    sf_q: torch.Tensor,
    kv_fused: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_table: torch.Tensor,
    schedule_meta: torch.Tensor,
    max_context_len: int,
    num_epi_subtiles: int = 1,
    epi_dtype: torch.dtype = torch.float32,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    if not is_sm_100f():
        raise ValueError(
            f"CuteDSL: SM version {get_sm_version()} is not supported. "
            f"CuteDSL FP4 Paged MQA Logits only supports SM 100 family.")
    if num_epi_subtiles not in (1, 2, 4):
        raise ValueError(
            f"num_epi_subtiles must be one of (1, 2, 4), got {num_epi_subtiles}"
        )
    _check_fp4_paged_mqa_logits_dtypes(q, sf_q, kv_fused, weights,
                                       context_lens, block_table,
                                       schedule_meta, epi_dtype,
                                       output_dtype)
    return CuteDSLFP4PagedMQALogitsRunner.forward(
        q,
        sf_q,
        kv_fused,
        weights,
        context_lens,
        block_table,
        schedule_meta,
        max_context_len,
        num_epi_subtiles=num_epi_subtiles,
        epi_dtype=epi_dtype,
        output_dtype=output_dtype)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py` around lines
6984-7016, in cute_dsl_fp4_paged_mqa_logits, validate the num_epi_subtiles parameter
at the API boundary (inside the function before calling
CuteDSLFP4PagedMQALogitsRunner.forward) and raise a clear ValueError if it is
not one of the allowed values {1, 2, 4}; include the offending value in the
error message to aid debugging and prevent invalid integers from being forwarded
into kernel compilation/cache lookup.

Comment on lines +203 to +204
batch_size, next_n, num_heads, dim = q.size()
num_block, block_size, _, dim = kv_cache.size()

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Remove unused unpacked variables in _ref_paged_mqa_logits.

num_heads and num_block are unpacked but never used (RUF059). Rename to underscored bindings (or drop them) to keep lint clean.

Proposed fix
-    batch_size, next_n, num_heads, dim = q.size()
-    num_block, block_size, _, dim = kv_cache.size()
+    batch_size, next_n, _num_heads, dim = q.size()
+    _num_block, block_size, _, dim = kv_cache.size()
🧰 Tools
🪛 Ruff (0.15.12)

[warning] 203-203: Unpacked variable num_heads is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


[warning] 204-204: Unpacked variable num_block is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unittest/_torch/attention/sparse/test_cute_dsl_fp4_paged_mqa_logits.py`
around lines 203-204, in _ref_paged_mqa_logits, the tuple unpacking binds
num_heads and num_block but they are unused; change those bindings to
underscore-prefixed names (e.g., _num_heads and _num_block) or use plain
underscores to avoid unused-variable lint warnings; update the unpacking lines
that deconstruct q.size() and kv_cache.size() accordingly so only needed names
(batch_size, next_n, dim, block_size, _) remain as meaningful identifiers.

limin2021 added 3 commits May 9, 2026 04:45
…/scripts/

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…n correctness check

Move the remaining standalone-driver pieces from fp4_paged_mqa_logits.py
to run_fp4.py (FP4 quant helpers, _compute_schedule_metadata, the
compile cache + fake-tensor _compile_fp4_kernel, host wrapper
fp4_paged_mqa_logits, and the pure-torch _ref_paged_mqa_logits).
The kernel file now contains only the FP4MQALogitsKernel class.

Replace the runner's cosine-only PASS criterion with the element-wise
atol+rtol check used by the unit test (ELEM_TOL table). Cosine diff is
kept as a supplementary global metric. Output now reports max_abs,
mean_abs, atol, rtol alongside the diff.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
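The element-wise atol+rtol PASS criterion plus the supplementary cosine metric described in the commit message can be sketched in plain Python. The tolerance values mirror the ones quoted in this PR (fp32: 5e-5/1e-5, fp16: 1e-3/1e-3); the function names and list-based math are illustrative, since the real runner operates on torch tensors:

```python
import math

# Per-output-dtype (atol, rtol) pairs, as quoted in the commit message.
ELEM_TOL = {"float32": (5e-5, 1e-5), "float16": (1e-3, 1e-3)}

def elementwise_close(out, ref, dtype="float32"):
    # PASS criterion: every element satisfies |o - r| <= atol + rtol * |r|.
    atol, rtol = ELEM_TOL[dtype]
    return all(abs(o - r) <= atol + rtol * abs(r) for o, r in zip(out, ref))

def cosine_diff(out, ref):
    # Supplementary global metric: 1 - cosine similarity of the flattened
    # outputs; small values mean the two vectors point the same way.
    dot = sum(o * r for o, r in zip(out, ref))
    norm = math.sqrt(sum(o * o for o in out)) * math.sqrt(sum(r * r for r in ref))
    return 1.0 - dot / norm

out = [1.0, 2.0, 3.0]
ref = [1.0, 2.0, 3.00002]
print(elementwise_close(out, ref))               # -> True
print(elementwise_close(out, [1.0, 2.0, 3.1]))   # -> False
```

The element-wise check catches localized errors that a global cosine metric can average away, which is why the commit demotes cosine diff to a supplementary statistic.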
…scripts/

Mirror run_fp4.py with FP8-specific data prep / reference inlined from
the unit test (test_cute_dsl_fp8_paged_mqa_logits.py) and the compile
+ dispatch path inlined from cute_dsl_custom_ops.py
(CuteDSLPagedMQALogitsRunner). Schedule metadata reuses the same pure
Python implementation as run_fp4.py — both kernels share
compute_block_kv=128 + NUM_MATH_WG=2 → SPLIT_KV=256.

Correctness check matches the unit test: element-wise atol+rtol from
output_dtype (fp32: 5e-5/1e-5, fp16: 1e-3/1e-3) plus cosine diff as a
supplementary global metric. The runner prints diagnostic stats
(max_abs, mean_abs, atol, rtol, diff) and PASS/FAIL without raising,
matching the FP4 runner's "show all results" semantics.

Not wired into CI — same convention as the other launch scripts under
tests/scripts/cute_dsl_kernels/.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
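The schedule constants quoted in the commit message above can be restated as a small arithmetic sketch; the constant names are illustrative, and only the values 128, 2, and 256 come from the text:

```python
# Both the FP4 and FP8 runners split the KV dimension identically:
# each math warpgroup covers compute_block_kv KV positions, and
# NUM_MATH_WG warpgroups cooperate on one KV split.
COMPUTE_BLOCK_KV = 128
NUM_MATH_WG = 2
SPLIT_KV = COMPUTE_BLOCK_KV * NUM_MATH_WG
print(SPLIT_KV)  # -> 256
```

Sharing this schedule-metadata computation between run_fp4.py and the FP8 runner is what lets both scripts reuse the same pure-Python implementation.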