[TRTLLM-10407][perf] Add cute dsl single pass multi cta cluster topk #12339
limin2021 wants to merge 348 commits into NVIDIA:main
Conversation
…mand help to work. (NVIDIA#11722) Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
…ing hang with asymmetric PP/TP (NVIDIA#11789) Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
…1779) Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
full:DGX_H100/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py::test_sharding[1-1] SKIP (https://nvbugs/5936322) is duplicated by unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py::test_sharding[1-1] SKIP (https://nvbugs/5875203). Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
…p-k varlen

Add dynamic multi-CTA scheduling mode that assigns CTAs proportionally to each row's actual sequence length, avoiding wasted work on short rows.

Key changes:
- ComputeDynamicCTAOffsets: CuTE DSL kernel computing row_cta_offsets and row_output_offsets via parallel CTA counting + sequential prefix sum
- FilteredTopKKernelVarlenDecode: 1D grid with binary search task mapping when enable_dynamic_multi_cta=True; varlen merge input support
- CuteDSLTopKDecodeMultiCTARunner: dynamic=True/False mode with upper-bound grid allocation to avoid DtoH sync
- Unit test updated for dynamic mode coverage

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…VIDIA#12002) Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
…c multi-CTA top-k

Eliminate the separate ComputeDynamicCTAOffsets kernel by computing the prefix sum of per-row CTA counts directly in shared memory within the main filtered_topk_kernel. This reduces the dynamic multi-CTA path from 3 CUDA kernel launches to 2.

Key changes:
- filtered_topk_kernel: add in-kernel block_prefix_sum to build s_row_cta_offsets in shared memory, replacing global memory lookups
- Remove row_cta_offsets and row_output_offsets from kernel signatures
- Remove ComputeDynamicCTAOffsets kernel compilation and invocation
- Add dynamic parameter to cute_dsl_indexer_topk_decode for easy static/dynamic comparison

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…id for dynamic multi-CTA top-k Switch dynamic multi-CTA from 1D grid + shared-memory prefix sum + binary search to a simple 2D grid (num_rows, num_ctas_per_row) with per-CTA early exit. This eliminates ~80 lines of prefix-sum/binary-search logic, removes the 512-row limit, reduces shared memory usage, and simplifies the host side by reusing the merge buffer. Also updates auto-select heuristics (bf16/fp16 threshold to 131072, SM utilization check to <25%) and makes dynamic=True the default. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…lds (NVIDIA#11954) Signed-off-by: Abby Wei <mengzew@nvidia.com>
…with first-task direct assignment Improve the persistent dynamic scheduling by assigning the first task directly via block index (avoiding an unnecessary atomic), and expose load_balance parameter through the custom op API with test coverage. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…#11957) Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
… mPrevBlock/mNextBlocks with lookup-node pointers. (NVIDIA#11919) Signed-off-by: SimengLiu-nv <simengl@nvidia.com>
… add overview doc (NVIDIA#11291) Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
…lclose_to_hf failure (NVIDIA#10191) Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
…k kernel Add FlashInfer-style fused multi-CTA distributed radix top-k kernel that cooperatively finds the global pivot via multi-round radix select with global histogram merging. Single kernel launch, no intermediate buffer, no merge kernel. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
… on multinode failure (NVIDIA#11905) Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
…#11907) Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
…efault t… (NVIDIA#12007) Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>
📝 Walkthrough

This PR introduces a cluster-accelerated single-pass multi-CTA radix top-k kernel for Blackwell GPUs, adding a new kernel implementation with a corresponding custom ops runner class and test coverage. A minor fix updates post-quant dispatch logic in DeepEP to check scale factor availability.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host/Launcher
    participant CTA0 as CTA 0<br/>(Group Lead)
    participant CTAi as CTAs 1..N<br/>(In Cluster)
    participant GlobalMem as Global Memory
    participant SharedMem as Shared Memory
    participant DSMem as DSMEM<br/>(Cluster-wide)

    Host->>GlobalMem: Load input_values, seqlen
    Host->>Host: Compute num_groups, total_ctas
    Host->>CTA0: Launch cluster kernel<br/>(ctas_per_group > 1)

    rect rgba(0, 100, 200, 0.5)
        Note over CTA0,CTAi: Per-Row Radix Round (2-4 rounds)
        CTA0->>SharedMem: Build local 256-bin<br/>histogram
        CTAi->>SharedMem: Build local 256-bin<br/>histogram
        CTA0->>CTA0: Arrive at cluster<br/>barrier
        CTAi->>CTAi: Arrive at cluster<br/>barrier
        CTA0->>DSMem: Merge CTAs' histograms<br/>via DSMEM reads
        DSMem->>SharedMem: Remote SMEM access
        CTA0->>SharedMem: Compute prefix sum,<br/>pivot bucket
        CTA0->>CTA0: Wait at cluster<br/>barrier
        CTAi->>CTAi: Wait at cluster<br/>barrier
    end

    rect rgba(100, 150, 0, 0.5)
        Note over CTA0,CTAi: Element Collection<br/>(Pass 1 & 2)
        CTA0->>GlobalMem: Batch atomicAdd for<br/>elements > pivot
        CTA0->>CTA0: Arrive at cluster<br/>barrier
        CTAi->>GlobalMem: atomicAdd for<br/>elements == pivot<br/>(pos < top_k guard)
        CTAi->>CTAi: Arrive at cluster<br/>barrier
    end

    CTA0->>GlobalMem: Release store<br/>output counter = 0
    CTA0->>Host: Kernel complete
    Host->>GlobalMem: Read output_indices,<br/>output_values
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Pre-merge checks: ✅ 3 passed
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py`:
- Around line 169-171: The current branch in deep_ep.py that does a pre-quant
dispatch when supports_post_quant_dispatch() is True but hidden_states_sf is
None can silently mis-handle already-quantized inputs; change this to fail-fast:
in the method containing the check (the block with "if not
self.supports_post_quant_dispatch() or hidden_states_sf is None") raise an
explicit error when self.supports_post_quant_dispatch() is True but
hidden_states_sf is None, with a clear message indicating post-quant dispatch
was expected but scale factors are missing; alternatively (if you prefer
caller-side guards) update configurable_moe.py to assert x_sf is not None before
invoking post-quant paths (references: supports_post_quant_dispatch(),
hidden_states_sf, quantize_input()/x_sf in configurable_moe.py) so we do not
silently fall back to pre-quant kernels.
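A minimal sketch of the fail-fast variant described above (the enclosing method shape and the two dispatch helpers are assumptions standing in for the real paths in deep_ep.py, not the actual diff):

```python
def _dispatch(self, hidden_states, hidden_states_sf):
    # Sketch only: _dispatch_post_quant / _dispatch_pre_quant are hypothetical
    # helper names for the real dispatch paths in deep_ep.py.
    if self.supports_post_quant_dispatch():
        if hidden_states_sf is None:
            # Fail fast instead of silently falling back to pre-quant kernels.
            raise ValueError(
                "Post-quant dispatch expected scale factors, but "
                "hidden_states_sf is None.")
        return self._dispatch_post_quant(hidden_states, hidden_states_sf)
    return self._dispatch_pre_quant(hidden_states)
```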
📒 Files selected for processing (4)
- tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
- tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/single_pass_multi_cta_radix_topk_cluster.py
- tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py
- tests/unittest/_torch/thop/parallel/test_indexer_topk.py
👮 Files not reviewed due to content moderation or server errors (3)
- tests/unittest/_torch/thop/parallel/test_indexer_topk.py
- tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
- tensorrt_llm/_torch/cute_dsl_kernels/blackwell/top_k/single_pass_multi_cta_radix_topk_cluster.py
```python
    return output_indices_torch, output_values


class CuteDSLTopKDecodeSinglePassMultiCTAClusterRunner:
```
CuteDSLTopKDecodeSinglePassMultiCTAClusterRunner and CuteDSLTopKDecodeSinglePassMultiCTARunner are nearly identical. Can we extract a get_topk_state_size to distinguish between DISTRIBUTED_TOPK_STATE_SIZE and CLUSTER_TOPK_STATE_SIZE and reuse all the rest?
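A minimal sketch of the suggested parameterization (state-size values are placeholders; per a later commit, the actual refactor parameterizes the base class via cls._kernel_class and cls._state_size):

```python
# Hypothetical sketch: hoist the shared launch/forward logic into the base
# runner and let subclasses override only the scratch-state size.
DISTRIBUTED_TOPK_STATE_SIZE = 16  # placeholder value, assumption
CLUSTER_TOPK_STATE_SIZE = 4       # placeholder value, assumption


class CuteDSLTopKDecodeSinglePassMultiCTARunner:
    _state_size = DISTRIBUTED_TOPK_STATE_SIZE

    @classmethod
    def get_topk_state_size(cls) -> int:
        # The only point of divergence between the two runners.
        return cls._state_size

    # ... shared launch/forward logic would live here ...


class CuteDSLTopKDecodeSinglePassMultiCTAClusterRunner(
        CuteDSLTopKDecodeSinglePassMultiCTARunner):
    _state_size = CLUSTER_TOPK_STATE_SIZE
```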
Can we reuse code between this file and single_pass_multi_cta_radix_topk.py? fence_acq_rel_gpu/_st_release_gpu/st_release_gpu are identical, SinglePassMultiCTARadixTopKKernel's __init__/to_ordered/from_ordered/load_chunk_to_smem/_radix_round_single_cta/compute_local_gt_count/collect_output_single_cta/prefix_sum_and_find_threshold are almost identical.
Good catch. done.
```python
cute.arch.cluster_arrive_relaxed()
cute.arch.cluster_wait()

# 3. Merge histograms via DSMEM (local_histogram -> prefix_buf)
self.merge_histogram_dsmem(local_histogram, prefix_buf, tidx)

# 4. Prefix sum and find threshold bucket on merged histogram
self.prefix_sum_and_find_threshold(
    prefix_buf, prefix_buf, s_scalars, remaining_k, s_warp_sums, tidx
)

# Update prefix and remaining_k
found_bucket = s_scalars[0]
found_remaining_k = s_scalars[1]

if cutlass.const_expr(self.dtype == cutlass.Float32):
    prefix = prefix | cutlass.Uint32(cutlass.Uint32(found_bucket) << cutlass.Uint32(shift))
else:
    prefix = prefix | cutlass.Uint16(cutlass.Uint16(found_bucket) << cutlass.Uint16(shift))
remaining_k = found_remaining_k

return prefix, remaining_k
```
Missing cluster barrier between consecutive radix rounds — DSMEM read/write race

Within _radix_round_cluster, the sequence is:
1. Build local histogram in SMEM
2. cluster_arrive_relaxed + cluster_wait (line 627-628)
3. Merge all peers' histograms via DSMEM reads into prefix_buf
4. Prefix sum on prefix_buf
5. Return

There is no cluster barrier between step 3 of round N and step 1 of round N+1. A fast CTA can clear its local_histogram for the new round while a slow CTA in the same cluster is still reading the fast CTA's local_histogram via DSMEM in step 3 — producing corrupted histogram merges and silently wrong top-k results.

Suggested fix: Add a cluster barrier before the return:

```python
# At the end of _radix_round_cluster, before return:
cute.arch.cluster_arrive_relaxed()
cute.arch.cluster_wait()
return prefix, remaining_k
```
Good catch. done.
```python
cs = 8192
while cs <= max_chunk:
```
ctas_per_group can exceed hardware max cluster size — data silently not processed

_get_chunk_config returns ctas_per_group = ceil(num_cols / chunk_size), which can exceed the hardware max cluster size (e.g., 16 on B200). The kernel constructor silently clamps: self.ctas_per_group = min(ctas_per_group, hw_max). But chunk_size was computed assuming the unclamped CTA count. With clamped CTAs, ctas_per_group * chunk_size < num_cols, so chunks beyond the cluster boundary are never loaded — those elements are silently excluded from the top-k.

For example: 256K vocab with chunk_size=8192 and hw_max=16 → only 16×8192 = 128K elements processed.

Suggested fix: Clamp ctas_per_group in _get_chunk_config before returning, then recompute chunk_size = ceil(num_cols / clamped_ctas):

```python
max_cluster = _query_max_cluster_size()
ctas_per_group = min(ctas_per_group, max_cluster)
chunk_size = math.ceil(num_cols / ctas_per_group)
# re-align chunk_size to vec_size
chunk_size = ((chunk_size + vec_size - 1) // vec_size) * vec_size
```

Or just add an assertion that ctas_per_group <= hw_max.
This is already handled in the cluster runner's _get_chunk_config override (line 3689-3719). When ctas_per_group > hw_max_cluster, we recompute chunk_size = ceil(num_cols / hw_max_cluster), re-align to vec_size, and recalculate ctas_per_group. If the resulting chunk_size exceeds max shared memory capacity, we return None to fall back to the non-cluster runner. There's also an early guard in forward() (line 3743-3752) that falls back when num_cols > max_chunk * hw_max_cluster.
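A standalone sketch of that clamp-and-recompute logic as described in the reply (the function shape and parameter names are assumptions, not the actual override):

```python
import math


def _get_chunk_config_sketch(num_cols: int, vec_size: int,
                             hw_max_cluster: int, max_chunk: int):
    """Hypothetical sketch of the clamping described in the reply above."""
    chunk_size = 8192
    ctas_per_group = math.ceil(num_cols / chunk_size)
    if ctas_per_group > hw_max_cluster:
        # Re-derive chunk size so the clamped cluster still covers all columns.
        chunk_size = math.ceil(num_cols / hw_max_cluster)
        chunk_size = ((chunk_size + vec_size - 1) // vec_size) * vec_size
        ctas_per_group = math.ceil(num_cols / chunk_size)
    if chunk_size > max_chunk:
        # Exceeds shared memory capacity: fall back to the non-cluster runner.
        return None
    return chunk_size, ctas_per_group
```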
```python
from .filtered_top_k_varlen_util import float_as_uint32, half_as_ushort


def _query_max_cluster_size() -> int:
```
_query_max_cluster_size() called on every kernel instantiation — should be cached
This involves CUDA driver API calls (cuOccupancyMaxPotentialClusterSize) on every cache miss. The result is device-invariant during the process lifetime.
Suggested fix: Cache at module level with @functools.lru_cache(maxsize=1).
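A minimal sketch of the suggested caching (the body is a hypothetical stand-in for the driver query):

```python
import functools


@functools.lru_cache(maxsize=1)
def _query_max_cluster_size() -> int:
    # The underlying driver query (cuOccupancyMaxPotentialClusterSize) is
    # device-invariant for the process lifetime, so one cached call suffices.
    # Placeholder return value; the real implementation queries the driver.
    return 16  # e.g., B200 reports a max cluster size of 16
```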
…p-k kernel Remove the per-row entry cluster barrier (cluster_arrive + cluster_wait) in the multi-CTA cluster radix top-k kernel. The previous row's exit barrier already synchronizes all CTAs, and the round-0 cluster barrier guarantees visibility. Move output_counter reset to after the exit barrier and add defensive init before the persistent loop. Also clamp cluster size to hardware max and handle graceful fallback when problem size exceeds cluster kernel capacity. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…ulti-CTA runner ClusterRunner now inherits from SinglePassMultiCTARunner, overriding only _get_chunk_config (hw cluster size clamping) and forward (fallback for unsupported problem sizes). Base class parameterized via cls._kernel_class and cls._state_size. Also fixes _compute_max_chunk to use get_smem_capacity_in_bytes() instead of hardcoded 227KB. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…matting Rename local_histogram -> histogram_input in prefix_sum_and_find_threshold to reflect that cluster mode passes prefix_buf while single-CTA passes local_histogram. Update cluster file docstring for inheritance design. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
…und race A fast CTA could clear its local_histogram for the next radix round while a slow CTA was still reading it via DSMEM in merge_histogram_dsmem, producing corrupted histogram merges. Add a cluster barrier at the end of _radix_round_cluster to ensure all DSMEM reads complete before any CTA proceeds to the next round. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
The CUDA driver call is device-invariant during process lifetime but was invoked on every kernel instantiation. Cache the result with lru_cache. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Force-pushed from 7d23b85 to fb0a63c
Summary by CodeRabbit

New Features
- Cluster-accelerated single-pass multi-CTA radix top-k kernel for Blackwell GPUs

Bug Fixes
- Fix the following bugs when running trtllm-eval with TRTLLM_FORCE_COMM_METHOD=DEEPEP
Description
Add a cluster-accelerated single-pass multi-CTA radix top-k kernel for Blackwell (SM100+).

Replaces global memory atomics + arrival counter polling with:

- Cluster barriers (cluster_arrive_relaxed + cluster_wait) for inter-CTA synchronization

This eliminates the triple-buffered global histogram, the arrival counter, and all GPU-scope acquire/release PTX for barriers. Only the output counter (1 int32 in GMEM) is retained for atomicAdd during output collection.

Key changes

- New kernel file: single_pass_multi_cta_radix_topk_cluster.py
- New runner: CuteDSLTopKDecodeSinglePassMultiCTAClusterRunner with SM-aware heuristics and hardware max cluster size clamping

Perf
Comparison of different multi-CTA kernels on fixed-length input (dtype=float32, top_k=2048, B200).

Kernels compared: the cluster kernel (this PR) and the existing distributed kernel; cl/dist < 1.0 means cluster is faster than distributed.

Per-batch timing tables: batch = 1, 4, 8, 16, 32, 64, 128, 256.
Key observations
Test Coverage
pytest tests/unittest/_torch/thop/parallel/test_indexer_topk.py -k "test_cute_dsl_topk_decode_single_pass_multi_cta_cluster" — 180 tests passed (3 dtypes x 3 vocab sizes x 4 batch/next_n combos x 5 ctas_per_group configs)
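A hedged sketch of how a 180-case grid like that could be parametrized (all parameter values below are illustrative assumptions, not the actual test matrix in test_indexer_topk.py):

```python
import pytest

# Hypothetical values: 3 dtypes x 3 vocab sizes x 4 batch/next_n combos
# x 5 ctas_per_group configs = 180 cases, matching the reported count.
@pytest.mark.parametrize("dtype", ["float32", "float16", "bfloat16"])
@pytest.mark.parametrize("vocab_size", [32768, 65536, 131072])
@pytest.mark.parametrize("batch,next_n", [(1, 1), (1, 2), (16, 1), (16, 2)])
@pytest.mark.parametrize("ctas_per_group", [2, 4, 8, 12, 16])
def test_cute_dsl_topk_decode_single_pass_multi_cta_cluster(
        dtype, vocab_size, batch, next_n, ctas_per_group):
    ...  # kernel invocation and reference top-k comparison would go here
```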
PR Checklist

Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.