[None][perf] Speed up mHC fused_hc with 4 small kernel-side optimizations #13892

Draft

mingyangHao wants to merge 4 commits into NVIDIA:feat/deepseek_v4 from mingyangHao:mingyangh/mhc-v4pro-perf

Conversation

mingyangHao (Collaborator) commented on May 8, 2026

@coderabbitai summary

Description

Four small kernel-side performance changes layered on top of the existing tcgen05 fused_hc infrastructure (#13771 added hidden=7168 support; this PR builds on that). Each change is a separate commit and is gated by if constexpr / static_assert, so behavior at hidden=4096 and at previously supported KS values is unchanged.

P0 — Path D KS=112 enable at hidden=7168 (commit 1)

  • Phase 4 layer_input loop required HIDDEN % (WARPS_PER_TOK * 32 * 8) == 0. At KS=112 / HIDDEN=7168 the team stride is 2048 and 7168 % 2048 = 1024 → static_assert + isSupportedFhcMmaKS<> rejected the tactic.
  • Add a scalar-vec tail covering [H_VEC_END, HIDDEN) after the vectorized main loop (sketched after this list). Each tail thread still issues one uint4 LDG/STG, so per-thread bandwidth matches; only some lanes/warps in the team idle on the last chunk.
  • Drop the team-stride alignment from isSupportedFhcMmaKS<H, KS> (only Hidden % BF16_VEC_LI and h_tiles % KS remain). Mirror in Python _fused_hc_mma_ks_supported.
  • Add KS ∈ {7, 14, 28, 56, 112} cases to pickFhc / pickFhcAllInOne outer switches. The fhc*InstanceIfSupported<H, KS> wrappers already gate per-Hidden, so new cases compile to TLLM_CHECK_WITH_INFO on hidden=4096.
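A minimal sketch of the resulting loop shape, assuming the constants named above (WARPS_PER_TOK, 32 lanes, BF16_VEC_LI = 8) and a bare uint4 copy in place of the real Phase 4 body; teamOff stands for this thread's element offset within its team:

```cuda
// Sketch only: the loop shape implied by the P0 bullets, not the actual
// fused_allinone_tf32_pmap_gemm_atomic_impl code. Assumes teamOff is a
// multiple of BF16_VEC_LI and Hidden % BF16_VEC_LI == 0 (trait-guaranteed).
#include <cuda_bf16.h>

template <int HIDDEN, int WARPS_PER_TOK, int BF16_VEC_LI = 8>
__device__ void copyLayerInput(__nv_bfloat16* dst, __nv_bfloat16 const* src, int teamOff)
{
    constexpr int H_STRIDE = WARPS_PER_TOK * 32 * BF16_VEC_LI; // 2048 at KS=112
    constexpr int H_VEC_END = (HIDDEN / H_STRIDE) * H_STRIDE;  // 6144 at HIDDEN=7168

    // Vectorized main loop: every lane of the team issues one uint4 per chunk.
    for (int h = teamOff; h < H_VEC_END; h += H_STRIDE)
    {
        *reinterpret_cast<uint4*>(dst + h) = *reinterpret_cast<uint4 const*>(src + h);
    }
    // Scalar-vec tail over [H_VEC_END, HIDDEN): same uint4 width per thread,
    // but lanes whose offset lands past the end simply idle on the last chunk.
    if (int h = H_VEC_END + teamOff; h < HIDDEN)
    {
        *reinterpret_cast<uint4*>(dst + h) = *reinterpret_cast<uint4 const*>(src + h);
    }
}
```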

P1 — TMA descriptor cache (commit 2)

  • cuTensorMapEncodeTiled is a host-side ~1-2 µs call; each fused_hc launch builds 4 descriptors → 4-8 µs CPU overhead, ~25-50% of total wall time at small M.
  • Add a thread-local unordered_map<TmaDescKey, CUtensorMap> keyed on (base ptr, gmem dims, strides, swizzle, dtype, device_id); a sketch follows this list. Both tcgen05 launchers use getCachedTma2D instead of makeTma2D.
  • CUDA-graph-safe: cuTensorMapEncodeTiled records nothing into the stream; pointer stability across calls is already guaranteed by _FusedHcWorkspaceCache in mhc_cuda.py.
  • Return-by-value (CUtensorMap is a 128 B POD); no rehash invalidation. Falls through to makeTma2D on cache miss.
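A sketch of the cache under those constraints; the field list follows the key description above, while the hash, the exact signatures, and makeTma2D taking a TmaDescKey are assumptions:

```cuda
// Illustrative only: hashing/equality are hypothetical, and makeTma2D (which
// wraps cuTensorMapEncodeTiled) is assumed to exist with this signature.
#include <cuda.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>

struct TmaDescKey
{
    void const* base;
    uint64_t dims[2], strides[2];
    int swizzle, dtype, deviceId;

    bool operator==(TmaDescKey const& o) const
    {
        return base == o.base && dims[0] == o.dims[0] && dims[1] == o.dims[1]
            && strides[0] == o.strides[0] && strides[1] == o.strides[1]
            && swizzle == o.swizzle && dtype == o.dtype && deviceId == o.deviceId;
    }
};

struct TmaDescKeyHash
{
    size_t operator()(TmaDescKey const& k) const
    {
        size_t h = std::hash<void const*>{}(k.base);
        auto mix = [&h](uint64_t v)
        { h ^= std::hash<uint64_t>{}(v) + 0x9e3779b97f4a7c15ull + (h << 6) + (h >> 2); };
        mix(k.dims[0]); mix(k.dims[1]); mix(k.strides[0]); mix(k.strides[1]);
        mix(k.swizzle); mix(k.dtype); mix(k.deviceId);
        return h;
    }
};

CUtensorMap makeTma2D(TmaDescKey const& key); // existing builder, declaration only

CUtensorMap getCachedTma2D(TmaDescKey const& key)
{
    // One cache per host thread: no locking, and deviceId in the key keeps a
    // thread that switches CUDA devices from reusing a stale descriptor.
    thread_local std::unordered_map<TmaDescKey, CUtensorMap, TmaDescKeyHash> cache;
    if (auto it = cache.find(key); it != cache.end())
    {
        return it->second;                  // by value: a 128 B POD copy
    }
    CUtensorMap desc = makeTma2D(key);      // miss: host-side encode, ~1-2 µs
    cache.emplace(key, desc);
    return desc;                            // copy-out survives later rehashes
}
```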

P2 — KS=1 direct store + skip workspace zero (commit 3)

  • At KS=1 each (m_block, n) is owned by exactly one CTA. The GEMM epilogue can write D / sqr_sum directly instead of via atomicAdd, guarded by if constexpr (kNumSplits == 1); see the sketch after this list.
  • Launchers skip the fhcZeroWorkspaces kernel at KS=1 (y_acc / r_acc / done_counter are unused under direct-store; Phase 3 at KS=1 is just __threadfence_block + __syncthreads).
  • Effect at KS=1: -1 kernel launch + -25 atomicAdd ops/token. KS>1 path unchanged.
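A minimal sketch of the epilogue split; only the if constexpr (kNumSplits == 1) guard is from the description, the rest is illustrative:

```cuda
// Sketch only: the direct-store vs. atomicAdd branch described in P2.
template <int kNumSplits>
__device__ void storeOrAccumulate(float* dPtr, float acc)
{
    if constexpr (kNumSplits == 1)
    {
        // KS=1: this CTA is the sole owner of (m_block, n), so a plain store
        // is safe and the launcher can skip pre-zeroing the workspace.
        *dPtr = acc;
    }
    else
    {
        // KS>1: multiple CTAs fold partial sums into the same slot, so the
        // workspace must start at zero and the write stays an atomicAdd.
        atomicAdd(dPtr, acc);
    }
}
```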

P3 — Sinkhorn reciprocal multiply (commit 4)

  • Sinkhorn row-normalize did 4 fp32 fdivs per iteration, all sharing the same divisor rs. Replace them with one explicit reciprocal (inv_rs = 1.0f / rs) followed by 4 fmuls, as sketched below. NVCC does not always CSE fdiv in non-fast-math mode.
  • B200 fp32 fdiv ~10 cycle latency vs fmul ~4 cycle → saves ~24 cycles per row-norm × 20 sinkhorn iters at V4-Pro = ~480 cycles/token.
  • Applied in mhcKernels.cu::mhcBigFuseKernel (Path B/E shared bigfuse), fused_tf32_pmap_gemm.cuh Path D Phase 4, and mhc_fused_fma.cuh Path F Phase 4 (same sinkhorn block in three places).
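A sketch of the rewritten row-normalize, using the cm / rs / eps names from the commit message later in this page:

```cuda
// Sketch of the P3 rewrite; cycle counts are the B200 figures quoted above.
__device__ void sinkhornRowNormalize(float (&cm)[4], float eps)
{
    float rs = cm[0] + cm[1] + cm[2] + cm[3];
    // Was: cm[k] = cm[k] / rs + eps;   // 4 fdivs, ~10 cycles each
    float inv_rs = 1.0f / rs;           // 1 fdiv
#pragma unroll
    for (int k = 0; k < 4; ++k)
    {
        cm[k] = cm[k] * inv_rs + eps;   // 4 fmuls, ~4 cycles each
    }
}
```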

Numerical results

Verified at hidden=7168, V4-Pro params (n=4, sinkhorn_iters=20):

  • residual_cur: bit-identical to baseline
  • post_mix_cur max diff: 1.2e-7 (fp32 epsilon)
  • comb_mix_cur max diff: 6e-8 (fp32 epsilon)
  • layer_input max diff: 1.5e-5 (bf16 round-off)

Measured speedup

Idle B300 (sm_103), V4-Pro shape (n=4, hidden=7168, sinkhorn=20), 300-iteration average via cudaEvent timing:

| Tactic forced | Baseline (gpu_us) | Optimized (gpu_us) | Δ |
| --- | --- | --- | --- |
| half_mma KS=2 M=4096 (autotuner winner) | 171.10 | 165.75 | -3.1% |
| half_mma KS=1 M=2048 | 216.20 | 212.19 | -1.9% |
| half_mma KS=1 M=512 | 188.35 | 185.90 | -1.3% |
| half_mma KS=56 M=64 | 24.52 | 23.87 | -2.7% |
| all_mma KS=112 M=64 (P0) | FAIL (static_assert) | 21.92 | newly enabled |

Autotuner-driven winners at M ≤ 384 are unchanged (still half_mma KS=112). The visible production gain is at M ≥ 1024, where KS=2/4 wins; it is driven primarily by P1+P3.

Test Coverage

  • tests/unittest/_torch/modules/test_mhc.py::test_mhc_fused_hc_backends already parameterizes over hidden_size ∈ {4096, 7168} and all 4 backends; covers the existing tactic table unchanged.
  • Locally verified (fused_all_mma, 0, 112, 0, 1) at hidden=7168, M=64 produces bit-identical residual_cur vs (fused_all_mma, 0, 56, 0, 1) reference; other outputs within fp32 / bf16 round-off.
  • TODO before un-drafting: add an explicit fixture entry forcing KS=112 in _BACKEND_TACTICS_BY_M so CI exercises the newly enabled tactic on every PR.
  • TODO: extend test_mhc_fused_hc_cuda_graph to also force a KS=1 tactic (covers P2's direct-store path under graph capture).

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in the PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

…lar tail

Phase 4 layer_input loop in fused_allinone_tf32_pmap_gemm_atomic_impl
required HIDDEN to be a multiple of WARPS_PER_TOK * 32 * BF16_VEC_LI.
At KS=112 / HIDDEN=7168 the team stride is 8*32*8=2048 and 7168 % 2048
= 1024 → static_assert blocked the instantiation, and the C++ trait
isSupportedFhcMmaKS<Hidden, KS> rejected the (hidden=7168, ks=112) tuple.

Replace the assert with a vectorized main loop ending at
H_VEC_END = floor(HIDDEN/H_STRIDE)*H_STRIDE, plus a scalar-vec tail
covering [H_VEC_END, HIDDEN). Each tail thread still issues one uint4
LDG/STG (BF16_VEC_LI=8 elements), so per-thread bandwidth matches the
main loop; only the team's lane/warp coverage is shorter, which is fine
since the tail is at most H_STRIDE - BF16_VEC_LI elements wide.

Drop the (Hidden % team-stride) check from isSupportedFhcMmaKS now that
the tail handles the residue. Mirror the relaxed trait in the Python
filter (_fused_hc_mma_ks_supported in mhc_cuda.py).

Add KS ∈ {7, 14, 28, 56, 112} cases to pickFhc and pickFhcAllInOne outer
switches — these are valid divisors of HIDDEN/BLOCK_K = 112 at
hidden=7168. The fhc*InstanceIfSupported<Hidden, KS> wrappers already
guard per-Hidden, so the new cases compile to a TLLM_CHECK_WITH_INFO
on hidden=4096 (where 112 doesn't divide h_tiles=64).
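The gate pattern described above, sketched with placeholder names; the trait body is inferred from the P0 bullet list with BLOCK_K = 64 and BF16_VEC_LI = 8, not copied from the source:

```cuda
// Sketch only: FhcArgs and launchFhc are placeholders; TLLM_CHECK_WITH_INFO
// is TRT-LLM's existing runtime-check macro.
struct FhcArgs; // opaque launch-argument bundle

template <int Hidden, int KS>
void launchFhc(FhcArgs const&); // real kernel launch, declaration only

// Relaxed trait: only Hidden % BF16_VEC_LI and h_tiles % KS remain
// (h_tiles = Hidden / BLOCK_K, BLOCK_K = 64).
template <int Hidden, int KS>
inline constexpr bool isSupportedFhcMmaKS =
    (Hidden % 8 == 0) && ((Hidden / 64) % KS == 0);

template <int Hidden, int KS>
void fhcInstanceIfSupported(FhcArgs const& args)
{
    if constexpr (isSupportedFhcMmaKS<Hidden, KS>)
    {
        launchFhc<Hidden, KS>(args);
    }
    else
    {
        // e.g. hidden=4096, KS=112: 112 does not divide h_tiles=64, so the
        // case compiles but rejects at runtime rather than via static_assert.
        TLLM_CHECK_WITH_INFO(false, "fused_hc: KS unsupported at this hidden size");
    }
}
```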

Extend _FUSED_HC_HALF_MMA_KS / _FUSED_HC_ALL_MMA_KS in mhc_cuda.py to
include the new factors. _fused_hc_mma_ks_supported filters per
(hidden, ks) tuple, so 7168-only KS values silently drop out at
hidden=4096.

Numerical: Path D KS=112 at hidden=7168 produces bit-identical
residual_cur to Path D KS=56; post_mix_cur / comb_mix_cur / layer_input
within fp32 rounding noise (max 1.5e-5 bf16 round-off) vs reference.

Unblocks Path D winning M=32-64 buckets (currently capped at half_mma
because Path D could not reach KS=112).

Signed-off-by: Mingyang Hao <mingyangh@nvidia.com>
Signed-off-by: mingyangh <mingyangh@nvidia.com>
cuTensorMapEncodeTiled is a host-side ~1-2 µs call. Each fused_hc launch
builds 4 descriptors (residual_in, x_in, W, residual_cur), so the
per-call descriptor build is 4-8 µs — 25-50% of total wall time at
small M (M ≤ 64).

Add a thread-local std::unordered_map<TmaDescKey, CUtensorMap> keyed on
all parameters that determine descriptor content (base ptr, gmem dims,
strides, swizzle, dtype, device_id). Two launchers (mhcFusedHcLaunch,
mhcFusedHcAllInOneLaunch) call into getCachedTma2D() instead of
makeTma2D() — the FMA path doesn't use TMA descriptors so it's
unaffected.

Cache scope: per-host-thread (`thread_local`). Same host thread
launching to multiple CUDA streams shares one cache (descriptor content
depends only on (device, ptr, shape), not on the stream the kernel
runs on). Multi-thread callers each get their own cache (no
synchronization overhead).

Multi-GPU: device_id is part of the key so a host thread that switches
CUDA devices never reuses a descriptor across address spaces.

CUDA-graph capture: cuTensorMapEncodeTiled is a pure host function that
does not record any stream operation, so cache miss inside capture is
safe. The descriptor is passed by value as __grid_constant__; the
recorded graph node holds those bytes and replays correctly under
workspace-stable replay (already enforced by _FusedHcWorkspaceCache in
mhc_cuda.py).
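For reference, the parameter shape that makes this hold; the kernel name and parameter list here are placeholders, but passing a CUtensorMap as a const __grid_constant__ kernel parameter is the standard CUDA pattern:

```cuda
#include <cuda.h>

// The descriptor is a by-value kernel parameter: graph capture snapshots its
// 128 bytes into the graph node, so replay never consults the host-side cache.
__global__ void fhcTcgen05Kernel(const __grid_constant__ CUtensorMap descW)
{
    // TMA (cp.async.bulk.tensor) reads descW out of kernel-parameter space.
}
```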

Return-by-value: CUtensorMap is a 128 B POD; we copy out before any
later cache miss can rehash the underlying unordered_map.

Lifetime: entries are never explicitly evicted. The steady-state working set is O(B-bucket × hidden) per thread, typically ~20 entries (~5 KB), and has not been observed to grow under inference workloads.

Falls through to makeTma2D() on cache miss, so first-call behavior is
identical to before.

Signed-off-by: Mingyang Hao <mingyangh@nvidia.com>
Signed-off-by: mingyangh <mingyangh@nvidia.com>
…pace zero

At KS=1 each (m_block, n) is owned by exactly one CTA, so the GEMM
epilogue can write D directly instead of going through atomicAdd. The
sqr_sum reduction is similar — single-CTA ownership of each token row.
This removes 24 atomicAdds per token (D) plus 1 atomicAdd per token (sqr_sum), each a GMEM round-trip, per launch.

Because the kernel no longer reads the previous value of D / sqr_sum at
KS=1, the launcher can skip the workspace-zero kernel entirely:
  - Path B (mhcFusedHcLaunch): skip y_acc + r_acc zero
  - Path D (mhcFusedHcAllInOneLaunch): skip y_acc + r_acc + done_counter
    zero (Phase 3 at KS=1 uses __threadfence_block + __syncthreads,
    never touches done_counter)

KS>1 unchanged: epilogue still atomicAdd, launcher still pre-zeros.
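Launcher-side shape of the skip, sketched with made-up grid and argument names around the real fhcZeroWorkspaces kernel:

```cuda
// Sketch only: names besides fhcZeroWorkspaces and the KS test are made up.
void launchWithOptionalZero(int ks, cudaStream_t stream /* , ... */)
{
    if (ks > 1)
    {
        // KS>1: the epilogue atomicAdds into y_acc / r_acc, so zero them first.
        fhcZeroWorkspaces<<<zeroGrid, zeroBlock, 0, stream>>>(yAcc, rAcc, doneCounter);
    }
    // KS=1: the epilogue stores directly and Phase 3 never reads
    // y_acc / r_acc / done_counter, so stale workspace contents are harmless.
    fhcMainKernel<<<mainGrid, mainBlock, 0, stream>>>(/* args */);
}
```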

Effect: -1 kernel launch (zero kernel) + -25 atomic ops/token at KS=1.
Most visible at large M where KS=1 is the heuristic choice (M >= 4096).

Verified at hidden=7168, M=2048 / M=4096:
  KS=1 (direct, no pre-zero, garbage workspace) vs KS=2 (atomic):
    residual_cur diff = 0
    post_mix_cur max diff = 2.4e-6
    comb_mix_cur max diff = 5.8e-6
    layer_input max diff = 6.1e-5  (bf16 round-off)

Signed-off-by: Mingyang Hao <mingyangh@nvidia.com>
Signed-off-by: mingyangh <mingyangh@nvidia.com>
…ultiply

Sinkhorn row-normalize used to do 4 fp32 divisions per iteration:
  rs = sum_4(...);
  for k in 0..3:
    cm[k] = cm[k] / rs + eps

The 4 divisions all share the divisor rs, but NVCC does not always CSE
fdiv in non-fast-math mode. Replace with explicit reciprocal:
  inv_rs = 1.0f / rs;
  for k in 0..3:
    cm[k] = cm[k] * inv_rs + eps

fp32 fdiv on B200 has ~10-cycle latency vs ~4 cycles for fmul, so each sinkhorn
iter saves ~24 cycles on the row-normalize. Column-normalize divisions use
per-k divisors (cs depends on k), so the shared reciprocal cannot be hoisted;
they were rewritten in *= 1/(cs+eps) form purely for syntactic consistency,
and the compiler emits the same PTX as for /=.

Applied in three places that run identical sinkhorn:
  - mhcKernels.cu :: mhcBigFuseKernel (Path B / E / F bigfuse)
  - fused_tf32_pmap_gemm.cuh :: fused_allinone_tf32_pmap_gemm Phase 4
  - mhc_fused_fma.cuh :: fused_pmap_gemm_fma_allinone Phase 4

Effect at sinkhorn_repeat=20 (V4-Pro default): ~480 saved cycles per
token (20 iter × 24 cycles), so ~0.25 µs / token at 2 GHz. Most visible
at small M where bigfuse / Phase 4 is a larger fraction of wall time.

Numerical: comb_mix_out max diff vs upstream Path D KS=56 reference is
5.96e-8 (below fp32 epsilon ~1e-7). residual_cur and layer_input
unchanged within bf16 round-off (1e-7 / 2.4e-7).

Signed-off-by: Mingyang Hao <mingyangh@nvidia.com>
Signed-off-by: mingyangh <mingyangh@nvidia.com>
