[TRTLLM-10407][perf] Add cute dsl single pass multi cta cluster topk#12339

Closed
limin2021 wants to merge 348 commits into NVIDIA:main from limin2021:add_cute_dsl_single_pass_multi_cta_cluster_topk

Conversation

limin2021 (Collaborator) commented Mar 19, 2026

Summary by CodeRabbit

  • New Features

    • Added cluster-accelerated top-k decoding optimization for Blackwell hardware, improving performance for token selection in decoding operations.
  • Bug Fixes

    • Improved dispatch logic in mixture-of-experts communication to correctly handle quantization configurations.
      Fixes the following error when running trtllm-eval with TRTLLM_FORCE_COMM_METHOD=DEEPEP:

```
  File "/home/lmin/scratch/trtlm-study/tensorrt-llm-new/tensorrt_llm/_torch/modules/fused_moe/interface.py", line 770, in forward
    return self.forward_impl(
           ^^^^^^^^^^^^^^^^^^
  File "/home/lmin/scratch/trtlm-study/tensorrt-llm-new/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py", line 466, in forward_impl
    outputs = self._forward_single_chunk(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lmin/scratch/trtlm-study/tensorrt-llm-new/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py", line 551, in _forward_single_chunk
    outputs = self._forward_chunk_impl(
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lmin/scratch/trtlm-study/tensorrt-llm-new/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py", line 734, in _forward_chunk_impl
    x, x_sf, token_selected_slots, token_final_scales = self.comm.dispatch(
                                                        ^^^^^^^^^^^^^^^^^^^
  File "/home/lmin/scratch/trtlm-study/tensorrt-llm-new/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py", line 208, in dispatch
    (hidden_states, hidden_states_sf),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: too many values to unpack (expected 2)
```
  • Tests
    • Added test coverage for cluster-based top-k decoding functionality.

Description

Add a cluster-accelerated single-pass multi-CTA radix top-k kernel for Blackwell (SM100+).

Replaces global memory atomics + arrival counter polling with:

  • Cluster barriers (cluster_arrive_relaxed + cluster_wait) for inter-CTA synchronization
  • DSMEM (distributed shared memory) for histogram merging across CTAs in the same cluster

This eliminates the triple-buffered global histogram, the arrival counter, and all GPU-scope acquire/release PTX for barriers. Only the output counter (1 int32 in GMEM) is retained for atomicAdd during output collection.
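
For intuition, here is a pure-Python model of the per-round radix-select logic. This is an illustration only, not the kernel: the real implementation synchronizes with cute.arch.cluster_arrive_relaxed()/cluster_wait(), merges histograms by reading peer SMEM over DSMEM, and scans buckets with a parallel prefix sum; the helper name find_topk_threshold and the serial loops are ours.

```python
# Pure-Python model of the per-row radix-select rounds; illustration only.
# Keys are assumed already mapped to order-preserving unsigned ints (the
# kernel's to_ordered). "CTAs" are just list slices here.
RADIX_BITS = 8
NUM_BINS = 1 << RADIX_BITS  # 256 bins per round, as in the kernel


def find_topk_threshold(keys_per_cta, top_k, key_bits=32):
    """Return (pivot, remaining_k): every key > pivot is in the top-k, and
    remaining_k ties at the pivot are admitted during output collection."""
    prefix, remaining_k = 0, top_k
    for shift in range(key_bits - RADIX_BITS, -1, -RADIX_BITS):
        # 1. Each CTA histograms the keys in its chunk that match the prefix.
        histograms = []
        for chunk in keys_per_cta:
            hist = [0] * NUM_BINS
            for key in chunk:
                if key >> (shift + RADIX_BITS) == prefix:
                    hist[(key >> shift) & (NUM_BINS - 1)] += 1
            histograms.append(hist)
        # 2. Merge histograms across CTAs -- the kernel does this by reading
        #    peers' SMEM over DSMEM, bracketed by cluster barriers.
        merged = [sum(col) for col in zip(*histograms)]
        # 3. Descending scan to find the bucket containing the k-th key.
        for bucket in range(NUM_BINS - 1, -1, -1):
            if merged[bucket] >= remaining_k:
                prefix = (prefix << RADIX_BITS) | bucket
                break
            remaining_k -= merged[bucket]
    return prefix, remaining_k


# Two "CTAs", top_k=3: pivot is 99, with 1 tie still to collect.
assert find_topk_threshold([[10, 250, 7], [99, 250, 3]], top_k=3) == (99, 1)
```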

Key changes

  • New kernel: single_pass_multi_cta_radix_topk_cluster.py
  • New runner: CuteDSLTopKDecodeSinglePassMultiCTAClusterRunner with SM-aware heuristics and hardware max cluster size clamping
  • Test coverage for the cluster kernel path
  • Minor fix to DeepEP post-quantization dispatch logic

Perf

Comparison of different multi-CTA kernels on fixed-length input (dtype=float32, top_k=2048, B200).

Kernels compared:

  • dist: distributed (global memory atomics)
  • cluster: cluster-accelerated (DSMEM + cluster barriers) — this PR
  • mc_dyn / mc_sta: multi-CTA dynamic/static scheduling
  • fi: FlashInfer reference

cl/dist < 1.0 means cluster is faster than distributed.

batch=1

| num_tokens | dist (us) | cluster (us) | mc_dyn (us) | mc_sta (us) | fi (us) | cl/dist | mcd/dist | mcs/dist | fi/dist |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 8192 | 7.96 | 8.10 | 12.67 | 12.24 | 11.99 | 1.02x | 1.59x | 1.54x | 1.51x |
| 16384 | 12.36 | 12.49 | 15.93 | 15.47 | 16.55 | 1.01x | 1.29x | 1.25x | 1.34x |
| 32768 | 19.26 | 12.58 | 21.07 | 20.71 | 25.23 | 0.65x | 1.09x | 1.07x | 1.31x |
| 65536 | 19.29 | 13.06 | 22.89 | 22.45 | 32.88 | 0.68x | 1.19x | 1.16x | 1.70x |
| 131072 | 19.08 | 14.44 | 27.14 | 26.59 | 37.01 | 0.76x | 1.42x | 1.39x | 1.94x |
| 262144 | 19.01 | 17.74 | 36.17 | 35.73 | 39.92 | 0.93x | 1.90x | 1.88x | 2.10x |

batch=4

| num_tokens | dist (us) | cluster (us) | mc_dyn (us) | mc_sta (us) | fi (us) | cl/dist | mcd/dist | mcs/dist | fi/dist |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 8192 | 8.06 | 8.14 | 12.78 | 12.32 | 12.02 | 1.01x | 1.59x | 1.53x | 1.49x |
| 16384 | 12.48 | 12.55 | 16.13 | 15.75 | 16.60 | 1.01x | 1.29x | 1.26x | 1.33x |
| 32768 | 20.00 | 12.60 | 21.81 | 21.48 | 25.33 | 0.63x | 1.09x | 1.07x | 1.27x |
| 65536 | 19.97 | 13.07 | 23.51 | 23.00 | 32.99 | 0.65x | 1.18x | 1.15x | 1.65x |
| 131072 | 19.46 | 15.01 | 28.09 | 27.52 | 37.01 | 0.77x | 1.44x | 1.41x | 1.90x |
| 262144 | 19.82 | 18.33 | 37.15 | 36.55 | 40.18 | 0.92x | 1.87x | 1.84x | 2.03x |

batch=8

| num_tokens | dist (us) | cluster (us) | mc_dyn (us) | mc_sta (us) | fi (us) | cl/dist | mcd/dist | mcs/dist | fi/dist |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 8192 | 8.07 | 8.18 | 12.85 | 12.45 | 12.06 | 1.01x | 1.59x | 1.54x | 1.49x |
| 16384 | 12.49 | 12.57 | 16.84 | 16.40 | 16.64 | 1.01x | 1.35x | 1.31x | 1.33x |
| 32768 | 20.36 | 12.63 | 21.77 | 21.46 | 25.34 | 0.62x | 1.07x | 1.05x | 1.25x |
| 65536 | 20.39 | 13.06 | 23.63 | 23.08 | 33.44 | 0.64x | 1.16x | 1.13x | 1.64x |
| 131072 | 20.18 | 17.29 | 28.38 | 27.80 | 37.08 | 0.86x | 1.41x | 1.38x | 1.84x |
| 262144 | 23.36 | 22.13 | 37.37 | 36.72 | 40.51 | 0.95x | 1.60x | 1.57x | 1.73x |

batch=16

| num_tokens | dist (us) | cluster (us) | mc_dyn (us) | mc_sta (us) | fi (us) | cl/dist | mcd/dist | mcs/dist | fi/dist |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 8192 | 8.09 | 8.16 | 12.89 | 12.47 | 12.07 | 1.01x | 1.59x | 1.54x | 1.49x |
| 16384 | 12.50 | 12.57 | 16.83 | 16.44 | 16.65 | 1.01x | 1.35x | 1.32x | 1.33x |
| 32768 | 20.55 | 12.69 | 22.43 | 21.97 | 25.38 | 0.62x | 1.09x | 1.07x | 1.23x |
| 65536 | 20.51 | 17.27 | 23.85 | 23.34 | 33.62 | 0.84x | 1.16x | 1.14x | 1.64x |
| 131072 | 23.81 | 21.78 | 28.47 | 27.97 | 37.46 | 0.91x | 1.20x | 1.18x | 1.57x |
| 262144 | 30.02 | 43.38 | 49.12 | 48.43 | 40.98 | 1.45x | 1.64x | 1.61x | 1.37x |

batch=32

| num_tokens | dist (us) | cluster (us) | mc_dyn (us) | mc_sta (us) | fi (us) | cl/dist | mcd/dist | mcs/dist | fi/dist |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 8192 | 8.12 | 8.19 | 12.94 | 12.53 | 12.10 | 1.01x | 1.59x | 1.54x | 1.49x |
| 16384 | 12.55 | 12.60 | 16.92 | 16.49 | 16.71 | 1.00x | 1.35x | 1.31x | 1.33x |
| 32768 | 20.78 | 12.92 | 22.81 | 22.44 | 25.45 | 0.62x | 1.10x | 1.08x | 1.22x |
| 65536 | 24.33 | 16.69 | 24.83 | 24.30 | 33.65 | 0.69x | 1.02x | 1.00x | 1.38x |
| 131072 | 30.61 | 23.38 | 40.18 | 39.48 | 37.49 | 0.76x | 1.31x | 1.29x | 1.23x |
| 262144 | 57.78 | 64.47 | 72.70 | 71.70 | 77.82 | 1.12x | 1.26x | 1.24x | 1.35x |

batch=64

| num_tokens | dist (us) | cluster (us) | mc_dyn (us) | mc_sta (us) | fi (us) | cl/dist | mcd/dist | mcs/dist | fi/dist |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 8192 | 8.16 | 8.24 | 13.07 | 12.65 | 12.18 | 1.01x | 1.60x | 1.55x | 1.49x |
| 16384 | 12.59 | 12.68 | 17.16 | 16.80 | 16.78 | 1.01x | 1.36x | 1.33x | 1.33x |
| 32768 | 20.08 | 20.21 | 23.08 | 22.57 | 25.55 | 1.01x | 1.15x | 1.12x | 1.27x |
| 65536 | 31.63 | 24.23 | 33.11 | 32.58 | 34.22 | 0.77x | 1.05x | 1.03x | 1.08x |
| 131072 | 58.55 | 44.90 | 63.65 | 62.64 | 72.43 | 0.77x | 1.09x | 1.07x | 1.24x |
| 262144 | 112.85 | 127.51 | 110.14 | 108.41 | 116.47 | 1.13x | 0.98x | 0.96x | 1.03x |

batch=128

| num_tokens | dist (us) | cluster (us) | mc_dyn (us) | mc_sta (us) | fi (us) | cl/dist | mcd/dist | mcs/dist | fi/dist |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 8192 | 8.25 | 8.38 | 13.91 | 13.49 | 12.29 | 1.02x | 1.69x | 1.63x | 1.49x |
| 16384 | 12.73 | 12.88 | 17.38 | 16.95 | 16.97 | 1.01x | 1.37x | 1.33x | 1.33x |
| 32768 | 20.22 | 20.52 | 31.44 | 30.97 | 25.78 | 1.01x | 1.55x | 1.53x | 1.27x |
| 65536 | 60.68 | 46.40 | 51.48 | 50.86 | 65.29 | 0.76x | 0.85x | 0.84x | 1.08x |
| 131072 | 103.06 | 107.67 | 100.83 | 99.05 | 108.45 | 1.04x | 0.98x | 0.96x | 1.05x |
| 262144 | 195.09 | 252.28 | 202.74 | 199.97 | 208.27 | 1.29x | 1.04x | 1.03x | 1.07x |

batch=256

| num_tokens | dist (us) | cluster (us) | mc_dyn (us) | mc_sta (us) | fi (us) | cl/dist | mcd/dist | mcs/dist | fi/dist |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 8192 | 14.96 | 15.21 | 18.30 | 17.83 | 22.71 | 1.02x | 1.22x | 1.19x | 1.52x |
| 16384 | 23.81 | 24.32 | 25.86 | 25.42 | 31.90 | 1.02x | 1.09x | 1.07x | 1.34x |
| 32768 | 38.65 | 39.58 | 50.89 | 50.44 | 49.36 | 1.02x | 1.32x | 1.31x | 1.28x |
| 65536 | 118.77 | 90.64 | 79.43 | 78.90 | 127.80 | 0.76x | 0.67x | 0.66x | 1.08x |
| 131072 | 208.70 | 274.94 | 148.59 | 148.32 | 225.11 | 1.32x | 0.71x | 0.71x | 1.08x |
| 262144 | 349.11 | 531.65 | 286.31 | 285.79 | 373.13 | 1.52x | 0.82x | 0.82x | 1.07x |

Key observations

  • Small batch (1-32) + large vocab (32K-256K): cluster runs at 0.62x-0.76x of distributed's time (up to 38% faster), thanks to the DSMEM communication advantage
  • Large batch (128-256) + large vocab: cluster is slower (1.29x-1.52x) because of co-scheduling constraints: cluster CTAs must reside on the same GPC, which limits SM utilization under high occupancy
  • Cluster is the fastest kernel for the typical decode scenario (small batch + large vocab)

Test Coverage

  • `pytest tests/unittest/_torch/thop/parallel/test_indexer_topk.py -k "test_cute_dsl_topk_decode_single_pass_multi_cta_cluster"` passed 180 tests (3 dtypes x 3 vocab sizes x 4 batch/next_n combos x 5 ctas_per_group configs)

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

yingguo-trt and others added 30 commits March 7, 2026 18:30
Signed-off-by: yingguo-trt <244492186+yingguo-trt@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
…mand help to work. (NVIDIA#11722)

Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
…ing hang with asymmetric PP/TP (NVIDIA#11789)

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
…1779)

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Signed-off-by: yingguo-trt <244492186+yingguo-trt@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
full:DGX_H100/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py::test_sharding[1-1] SKIP (https://nvbugs/5936322) is duplicated by unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py::test_sharding[1-1] SKIP (https://nvbugs/5875203).

Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
…p-k varlen

Add dynamic multi-CTA scheduling mode that assigns CTAs proportionally
to each row's actual sequence length, avoiding wasted work on short rows.

Key changes:
- ComputeDynamicCTAOffsets: CuTE DSL kernel computing row_cta_offsets and
  row_output_offsets via parallel CTA counting + sequential prefix sum
- FilteredTopKKernelVarlenDecode: 1D grid with binary search task mapping
  when enable_dynamic_multi_cta=True; varlen merge input support
- CuteDSLTopKDecodeMultiCTARunner: dynamic=True/False mode with upper-bound
  grid allocation to avoid DtoH sync
- Unit test updated for dynamic mode coverage

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
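
(A serial Python stand-in for the offset computation this commit describes; the real ComputeDynamicCTAOffsets kernel does parallel CTA counting plus a sequential prefix sum, and the function name here is ours.)

```python
import math


def compute_row_cta_offsets(seqlens, chunk_size):
    """CTAs per row proportional to that row's actual length;
    row_cta_offsets is the exclusive prefix sum of per-row CTA counts."""
    row_cta_offsets, total_ctas = [], 0
    for seqlen in seqlens:
        row_cta_offsets.append(total_ctas)
        total_ctas += max(1, math.ceil(seqlen / chunk_size))
    return row_cta_offsets, total_ctas  # total_ctas sizes the 1D grid
```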
…c multi-CTA top-k

Eliminate the separate ComputeDynamicCTAOffsets kernel by computing the
prefix sum of per-row CTA counts directly in shared memory within the
main filtered_topk_kernel. This reduces the dynamic multi-CTA path from
3 CUDA kernel launches to 2.

Key changes:
- filtered_topk_kernel: add in-kernel block_prefix_sum to build
  s_row_cta_offsets in shared memory, replacing global memory lookups
- Remove row_cta_offsets and row_output_offsets from kernel signatures
- Remove ComputeDynamicCTAOffsets kernel compilation and invocation
- Add dynamic parameter to cute_dsl_indexer_topk_decode for easy
  static/dynamic comparison

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…id for dynamic multi-CTA top-k

Switch dynamic multi-CTA from 1D grid + shared-memory prefix sum + binary
search to a simple 2D grid (num_rows, num_ctas_per_row) with per-CTA early
exit. This eliminates ~80 lines of prefix-sum/binary-search logic, removes
the 512-row limit, reduces shared memory usage, and simplifies the host side
by reusing the merge buffer. Also updates auto-select heuristics (bf16/fp16
threshold to 131072, SM utilization check to <25%) and makes dynamic=True
the default.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
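
(A sketch of the 2D-grid early exit this commit switches to, under assumed names; each CTA at (row, cta_in_row) computes its chunk bounds or exits when the row needs fewer CTAs than the upper-bound grid provides.)

```python
import math


def cta_chunk_bounds(row, cta_in_row, seqlens, chunk_size):
    """Return this CTA's [start, end) slice of the row, or None to early-exit."""
    needed = math.ceil(seqlens[row] / chunk_size)  # CTAs this row actually needs
    if cta_in_row >= needed:
        return None  # grid is an upper bound; surplus CTAs exit immediately
    start = cta_in_row * chunk_size
    return start, min(start + chunk_size, seqlens[row])
```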
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
…with first-task direct assignment

Improve the persistent dynamic scheduling by assigning the first task
directly via block index (avoiding an unnecessary atomic), and expose
load_balance parameter through the custom op API with test coverage.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…#11957)

Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
… mPrevBlock/mNextBlocks with lookup-node pointers. (NVIDIA#11919)

Signed-off-by: SimengLiu-nv <simengl@nvidia.com>
… add overview doc (NVIDIA#11291)

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
…lclose_to_hf failure (NVIDIA#10191)

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
…k kernel

Add FlashInfer-style fused multi-CTA distributed radix top-k kernel that
cooperatively finds the global pivot via multi-round radix select with
global histogram merging. Single kernel launch, no intermediate buffer,
no merge kernel.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: leslie-fang25 <leslief@nvidia.com>
Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com>
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
… on multinode failure (NVIDIA#11905)

Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
…#11907)

Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
…efault t… (NVIDIA#12007)

Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>
coderabbitai (bot) commented Mar 19, 2026

📝 Walkthrough

This PR introduces a cluster-accelerated single-pass multi-CTA radix top-k kernel for Blackwell GPUs, adding a new kernel implementation with a corresponding custom-ops runner class and test coverage. A minor fix updates post-quant dispatch logic in DeepEP to check scale-factor availability.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Cluster Top-K Kernel & Runner<br>`tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/single_pass_multi_cta_radix_topk_cluster.py`, `tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py` | New cluster-synchronized radix top-k kernel for multi-CTA execution with DSMEM cluster synchronization and merged histogram computation. New CuteDSLTopKDecodeSinglePassMultiCTAClusterRunner class with compilation caching, SM-aware heuristics, and GPU buffer management. |
| Cluster Top-K Test<br>`tests/unittest/_torch/thop/parallel/test_indexer_topk.py` | Added test function test_cute_dsl_topk_decode_single_pass_multi_cta_cluster with parameterization for cluster runner path validation. |
| DeepEP Post-Quant Fix<br>`tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py` | Updated dispatch control flow to require hidden_states_sf non-None before selecting the post-quant path; moved scale factor conversion logic inside the guarded post-quant branch (see the sketch after this table). |
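
A minimal sketch of the guarded dispatch selection described in that row (supports_post_quant_dispatch and hidden_states_sf appear in the review below; the class and return values here are hypothetical stand-ins, not the real DeepEP code):

```python
class _DeepEPDispatchSketch:
    """Illustration of the control flow only; not the real DeepEP class."""

    def supports_post_quant_dispatch(self) -> bool:
        return True  # backend/quant-config dependent in the real class

    def dispatch(self, hidden_states, hidden_states_sf):
        # Pre-quant path when post-quant dispatch is unsupported OR no scale
        # factors were produced -- the missing None check this PR adds.
        if not self.supports_post_quant_dispatch() or hidden_states_sf is None:
            return ("pre_quant", hidden_states)
        # Scale-factor conversion now lives inside the guarded branch.
        return ("post_quant", hidden_states, hidden_states_sf)
```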

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host/Launcher
    participant CTA0 as CTA 0<br/>(Group Lead)
    participant CTAi as CTAs 1..N<br/>(In Cluster)
    participant GlobalMem as Global Memory
    participant SharedMem as Shared Memory
    participant DSMem as DSMEM<br/>(Cluster-wide)

    Host->>GlobalMem: Load input_values, seqlen
    Host->>Host: Compute num_groups, total_ctas
    Host->>CTA0: Launch cluster kernel<br/>(ctas_per_group > 1)

    rect rgba(0, 100, 200, 0.5)
    note over CTA0,CTAi: Per-Row Radix Round (2-4 rounds)
    CTA0->>SharedMem: Build local 256-bin<br/>histogram
    CTAi->>SharedMem: Build local 256-bin<br/>histogram
    CTA0->>CTA0: Arrive at cluster<br/>barrier
    CTAi->>CTAi: Arrive at cluster<br/>barrier
    CTA0->>DSMem: Merge CTAs' histograms<br/>via DSMEM reads
    DSMem->>SharedMem: Remote SMEM access
    CTA0->>SharedMem: Compute prefix sum,<br/>pivot bucket
    CTA0->>CTA0: Wait at cluster<br/>barrier
    CTAi->>CTAi: Wait at cluster<br/>barrier
    end

    rect rgba(100, 150, 0, 0.5)
    note over CTA0,CTAi: Element Collection<br/>(Pass 1 & 2)
    CTA0->>GlobalMem: Batch atomicAdd for<br/>elements > pivot
    CTA0->>CTA0: Arrive at cluster<br/>barrier
    CTAi->>GlobalMem: atomicAdd for<br/>elements == pivot<br/>(pos < top_k guard)
    CTAi->>CTAi: Arrive at cluster<br/>barrier
    end

    CTA0->>GlobalMem: Release store<br/>output counter = 0
    CTA0->>Host: Kernel complete
    Host->>GlobalMem: Read output_indices,<br/>output_values
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 3 passed

| Check name | Status | Explanation |
|---|---|---|
| Docstring Coverage | ✅ Passed | Docstring coverage is 87.10%, which is sufficient (required threshold: 80.00%). |
| Title check | ✅ Passed | The title directly describes the main change: adding a cluster-accelerated variant of the single-pass multi-CTA radix top-k kernel, which is the core new functionality across the file additions. |
| Description check | ✅ Passed | PR description is comprehensive, with clear technical details, performance benchmarks, test coverage, and checklist completion. |


Comment @coderabbitai help to get the list of available commands and usage tips.

@limin2021 limin2021 changed the title Add cute dsl single pass multi cta cluster topk [TRTLLM-10407][perf] Add cute dsl single pass multi cta cluster topk Mar 19, 2026
coderabbitai (bot) left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py`:
- Around line 169-171: The current branch in deep_ep.py that does a pre-quant
dispatch when supports_post_quant_dispatch() is True but hidden_states_sf is
None can silently mis-handle already-quantized inputs; change this to fail-fast:
in the method containing the check (the block with "if not
self.supports_post_quant_dispatch() or hidden_states_sf is None") raise an
explicit error when self.supports_post_quant_dispatch() is True but
hidden_states_sf is None, with a clear message indicating post-quant dispatch
was expected but scale factors are missing; alternatively (if you prefer
caller-side guards) update configurable_moe.py to assert x_sf is not None before
invoking post-quant paths (references: supports_post_quant_dispatch(),
hidden_states_sf, quantize_input()/x_sf in configurable_moe.py) so we do not
silently fall back to pre-quant kernels.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 2050160e-b4ba-4265-86da-700553e9bbb1

📥 Commits

Reviewing files that changed from the base of the PR and between 2fb5e80 and 4453fc7.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/single_pass_multi_cta_radix_topk_cluster.py
  • tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py
  • tests/unittest/_torch/thop/parallel/test_indexer_topk.py
👮 Files not reviewed due to content moderation or server errors (3)
  • tests/unittest/_torch/thop/parallel/test_indexer_topk.py
  • tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/single_pass_multi_cta_radix_topk_cluster.py

Comment thread tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py

```python
    return output_indices_torch, output_values


class CuteDSLTopKDecodeSinglePassMultiCTAClusterRunner:
```
Collaborator:

CuteDSLTopKDecodeSinglePassMultiCTAClusterRunner and CuteDSLTopKDecodeSinglePassMultiCTARunner are nearly identical, can we extract a get_topk_state_size to distinguish between DISTRIBUTED_TOPK_STATE_SIZE and CLUSTER_TOPK_STATE_SIZE and reuse all the rest?
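
A sketch of what that extraction could look like (the constant values below are placeholders; per the commit notes further down, the merged version parameterizes the base class via cls._kernel_class/cls._state_size):

```python
DISTRIBUTED_TOPK_STATE_SIZE = 16  # placeholder value for illustration
CLUSTER_TOPK_STATE_SIZE = 4       # placeholder: cluster path keeps only the output counter


class CuteDSLTopKDecodeSinglePassMultiCTARunner:
    _state_size = DISTRIBUTED_TOPK_STATE_SIZE

    @classmethod
    def get_topk_state_size(cls) -> int:
        # Shared buffer-sizing logic reads the state size through this hook.
        return cls._state_size


class CuteDSLTopKDecodeSinglePassMultiCTAClusterRunner(
        CuteDSLTopKDecodeSinglePassMultiCTARunner):
    _state_size = CLUSTER_TOPK_STATE_SIZE
```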

Collaborator (Author):

Great. Done.

Collaborator:

Can we reuse code between this file and single_pass_multi_cta_radix_topk.py? fence_acq_rel_gpu/_st_release_gpu/st_release_gpu are identical, SinglePassMultiCTARadixTopKKernel's __init__/to_ordered/from_ordered/load_chunk_to_smem/_radix_round_single_cta/compute_local_gt_count/collect_output_single_cta/prefix_sum_and_find_threshold are almost identical.

Collaborator (Author):

Good catch. done.

Comment on lines +627 to +648

```python
cute.arch.cluster_arrive_relaxed()
cute.arch.cluster_wait()

# 3. Merge histograms via DSMEM (local_histogram -> prefix_buf)
self.merge_histogram_dsmem(local_histogram, prefix_buf, tidx)

# 4. Prefix sum and find threshold bucket on merged histogram
self.prefix_sum_and_find_threshold(
    prefix_buf, prefix_buf, s_scalars, remaining_k, s_warp_sums, tidx
)

# Update prefix and remaining_k
found_bucket = s_scalars[0]
found_remaining_k = s_scalars[1]

if cutlass.const_expr(self.dtype == cutlass.Float32):
    prefix = prefix | cutlass.Uint32(cutlass.Uint32(found_bucket) << cutlass.Uint32(shift))
else:
    prefix = prefix | cutlass.Uint16(cutlass.Uint16(found_bucket) << cutlass.Uint16(shift))
remaining_k = found_remaining_k

return prefix, remaining_k
```
Collaborator:

Missing cluster barrier between consecutive radix rounds — DSMEM read/write race

Within _radix_round_cluster, the sequence is:

  1. Build local histogram in SMEM
  2. cluster_arrive_relaxed + cluster_wait (line 627-628)
  3. Merge all peers' histograms via DSMEM reads into prefix_buf
  4. Prefix sum on prefix_buf
  5. Return

There is no cluster barrier between step 3 of round N and step 1 of round N+1. A fast CTA can clear its local_histogram for the new round while a slow CTA in the same cluster is still reading the fast CTA's local_histogram via DSMEM in step 3 — producing corrupted histogram merges and silently wrong top-k results.

Suggested fix: add a cluster barrier before the return:

```python
# At the end of _radix_round_cluster, before return:
cute.arch.cluster_arrive_relaxed()
cute.arch.cluster_wait()
return prefix, remaining_k
```

Collaborator (Author):

Good catch. done.

Comment on lines +3819 to +3820

```python
cs = 8192
while cs <= max_chunk:
```

yuxianq (Collaborator) commented Mar 19, 2026

ctas_per_group can exceed hardware max cluster size — data silently not processed

_get_chunk_config returns ctas_per_group = ceil(num_cols / chunk_size), which can exceed the hardware max cluster size (e.g., 16 on B200). The kernel constructor silently clamps: self.ctas_per_group = min(ctas_per_group, hw_max). But chunk_size was computed assuming the unclamped CTA count. With clamped CTAs, ctas_per_group * chunk_size < num_cols, so chunks beyond the cluster boundary are never loaded — those elements are silently excluded from the top-k.

For example: 256K vocab with chunk_size=8192 and hw_max=16 → only 16×8192 = 128K elements processed.

Suggested fix: clamp ctas_per_group in _get_chunk_config before returning, then recompute chunk_size = ceil(num_cols / clamped_ctas):

```python
max_cluster = _query_max_cluster_size()
ctas_per_group = min(ctas_per_group, max_cluster)
chunk_size = math.ceil(num_cols / ctas_per_group)
# re-align chunk_size to vec_size
chunk_size = ((chunk_size + vec_size - 1) // vec_size) * vec_size
```

Or simply assert that ctas_per_group <= hw_max.

Collaborator (Author):

This is already handled in the cluster runner's _get_chunk_config override (line 3689-3719). When ctas_per_group > hw_max_cluster, we recompute chunk_size = ceil(num_cols / hw_max_cluster),
re-align to vec_size, and recalculate ctas_per_group. If the resulting chunk_size exceeds max shared memory capacity, we return None to fall back to the non-cluster runner. There's also an early
guard in forward() (line 3743-3752) that falls back when num_cols > max_chunk * hw_max_cluster.
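
(A standalone model of that clamping path, following the names in this thread; max_chunk stands for the SMEM-derived chunk bound, and the function name is ours.)

```python
import math


def clamped_chunk_config(num_cols, chunk_size, vec_size, hw_max_cluster, max_chunk):
    """Return (chunk_size, ctas_per_group), or None to fall back to the
    non-cluster runner when the row cannot fit in one cluster's SMEM."""
    ctas_per_group = math.ceil(num_cols / chunk_size)
    if ctas_per_group > hw_max_cluster:
        # Recompute the chunk so hw_max_cluster CTAs cover the whole row.
        chunk_size = math.ceil(num_cols / hw_max_cluster)
        chunk_size = ((chunk_size + vec_size - 1) // vec_size) * vec_size
        ctas_per_group = math.ceil(num_cols / chunk_size)
    if chunk_size > max_chunk:
        return None  # exceeds shared memory capacity
    return chunk_size, ctas_per_group


# e.g. 256K vocab, 8192-wide chunks, hw max 16: re-chunks to 16384 x 16 CTAs.
assert clamped_chunk_config(262144, 8192, 8, 16, 32768) == (16384, 16)
```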

```python
from .filtered_top_k_varlen_util import float_as_uint32, half_as_ushort


def _query_max_cluster_size() -> int:
```
Collaborator:

_query_max_cluster_size() called on every kernel instantiation — should be cached

This involves CUDA driver API calls (cuOccupancyMaxPotentialClusterSize) on every cache miss. The result is device-invariant during the process lifetime.

Suggested fix: Cache at module level with @functools.lru_cache(maxsize=1).
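
(A minimal sketch of that caching, with a stand-in body; the real function queries the CUDA driver.)

```python
import functools


@functools.lru_cache(maxsize=1)
def _query_max_cluster_size() -> int:
    # Stand-in body: the real function calls cuOccupancyMaxPotentialClusterSize.
    # Cached because the result is device-invariant for the process lifetime.
    return 16  # e.g. B200, per the example above
```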

Collaborator (Author):

done.

…p-k kernel

Remove the per-row entry cluster barrier (cluster_arrive + cluster_wait)
in the multi-CTA cluster radix top-k kernel. The previous row's exit
barrier already synchronizes all CTAs, and the round-0 cluster barrier
guarantees visibility. Move output_counter reset to after the exit
barrier and add defensive init before the persistent loop.

Also clamp cluster size to hardware max and handle graceful fallback
when problem size exceeds cluster kernel capacity.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…ulti-CTA runner

ClusterRunner now inherits from SinglePassMultiCTARunner, overriding
only _get_chunk_config (hw cluster size clamping) and forward (fallback
for unsupported problem sizes). Base class parameterized via
cls._kernel_class and cls._state_size.

Also fixes _compute_max_chunk to use get_smem_capacity_in_bytes()
instead of hardcoded 227KB.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…matting

Rename local_histogram -> histogram_input in prefix_sum_and_find_threshold
to reflect that cluster mode passes prefix_buf while single-CTA passes
local_histogram. Update cluster file docstring for inheritance design.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…und race

A fast CTA could clear its local_histogram for the next radix round
while a slow CTA was still reading it via DSMEM in merge_histogram_dsmem,
producing corrupted histogram merges. Add a cluster barrier at the end
of _radix_round_cluster to ensure all DSMEM reads complete before any
CTA proceeds to the next round.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
The CUDA driver call is device-invariant during process lifetime but was
invoked on every kernel instantiation. Cache the result with lru_cache.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>