Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import math
from enum import Enum
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
Expand All @@ -7,7 +10,7 @@
from tensorrt_llm._torch.attention_backend.interface import MLAParams, PositionalEmbeddingParams
from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttention, TrtllmAttentionMetadata
from tensorrt_llm._torch.modules.linear import Linear # noqa: E402 (avoid cycle)
from tensorrt_llm._torch.modules.multi_stream_utils import maybe_execute_in_parallel
from tensorrt_llm._torch.modules.multi_stream_utils import do_multi_stream
from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
from tensorrt_llm._torch.utils import maybe_compile
from tensorrt_llm._utils import prefer_pinned
Expand Down Expand Up @@ -951,6 +954,9 @@ def __init__(
is_indexer=True,
rotate_activation=HAS_FAST_HADAMARD,
)
self.indexer_start_event = torch.cuda.Event()
self.weights_proj_event = torch.cuda.Event()
self.k_cache_update_event = torch.cuda.Event()

def post_load_weights(self):
# V4 does not use the V3 fused fp32 wk+weights_proj GEMM, and the
Expand Down Expand Up @@ -979,6 +985,101 @@ def _qk_projection_and_rope(self, qr: torch.Tensor, position_ids: torch.Tensor):
)
return q

def _quantize_q(self, q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Rotate + quantize (layout matches compressor K: [nope|pe]).
q = rotate_activation(q)
q = q.view(-1, self.head_dim)
q_fp8, q_scale = fp8_utils.fp8_quantize_1x128_sf_transpose(
q, use_ue8m0=self.scale_fmt == "ue8m0"
)
q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
q_scale = q_scale.view(-1, self.n_heads, 1)
return q_fp8, q_scale

def _update_k_cache_if_needed(
self,
k_fp8: Optional[torch.Tensor],
k_scale: Optional[torch.Tensor],
metadata: DeepseekV4TrtllmAttentionMetadata,
) -> None:
if k_fp8 is None:
return

assert k_scale is not None, "FP8 blockwise indexer cache update requires scale tensor"
self._update_k_cache(k_fp8, k_scale, metadata)

def _run_overlapped_indexer_prepare(
self,
qr: torch.Tensor,
hidden_states: torch.Tensor,
metadata: DeepseekV4TrtllmAttentionMetadata,
position_ids: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor]:
"""Prepare indexer inputs by splitting independent work across two streams.

The current stream owns the Q path. The auxiliary stream starts from
the recorded launch point and owns weights projection, compressor, and
the K-cache update.

Timeline:
current stream:
record indexer_start_event
q_proj + RoPE -> quant_q
wait weights_proj_event -> weight_scale
wait k_cache_update_event -> return

aux_stream:
wait indexer_start_event
weights_proj -> record weights_proj_event
compressor -> update_k_cache -> record k_cache_update_event

Dependency graph:
q_proj + RoPE -> quant_q -- q_scale --.
v
weights_proj --------------------> weight_scale
compressor -> update_k_cache ----> final wait
"""
self.indexer_start_event.record()

q = self._qk_projection_and_rope(qr, position_ids)

with torch.cuda.stream(self.aux_stream):
self.indexer_start_event.wait()

weights = self.weights_proj(hidden_states)
self.weights_proj_event.record()

k_fp8, k_scale = self.compressor(hidden_states, metadata)
self._update_k_cache_if_needed(k_fp8, k_scale, metadata)
self.k_cache_update_event.record()

q_fp8, q_scale = self._quantize_q(q)

self.weights_proj_event.wait()
weights = self._weight_scale(weights, q_scale)

self.k_cache_update_event.wait()
return q_fp8, k_fp8, k_scale, weights

def _run_serial_indexer_prepare(
self,
qr: torch.Tensor,
hidden_states: torch.Tensor,
metadata: DeepseekV4TrtllmAttentionMetadata,
position_ids: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor]:
q = self._qk_projection_and_rope(qr, position_ids)

q_fp8, q_scale = self._quantize_q(q)

weights = self.weights_proj(hidden_states)

weights = self._weight_scale(weights, q_scale)

k_fp8, k_scale = self.compressor(hidden_states, metadata)
self._update_k_cache_if_needed(k_fp8, k_scale, metadata)
return q_fp8, k_fp8, k_scale, weights

def forward(
self,
qr: torch.Tensor,
Expand All @@ -993,36 +1094,27 @@ def forward(
"cache layout, cache update, and BMM paths before end-to-end indexer "
"execution can use it."
)
# compress k
k_fp8, k_scale = self.compressor(hidden_states, metadata)

# multi-stream q proj/rope and weights proj
q, weights = maybe_execute_in_parallel(
lambda: self._qk_projection_and_rope(qr, position_ids),
lambda: self.weights_proj(hidden_states),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)

# Rotate + quantize (layout matches compressor K: [nope|pe])
q = rotate_activation(q)
q = q.view(-1, self.head_dim)
q_fp8, q_scale = fp8_utils.fp8_quantize_1x128_sf_transpose(
q, use_ue8m0=self.scale_fmt == "ue8m0"
)
q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
q_scale = q_scale.view(-1, self.n_heads, 1)

# weights scale
weights = self._weight_scale(weights, q_scale)
if do_multi_stream() and self.aux_stream is not None:
q_fp8, k_fp8, k_scale, weights = self._run_overlapped_indexer_prepare(
qr, hidden_states, metadata, position_ids
)
else:
q_fp8, k_fp8, k_scale, weights = self._run_serial_indexer_prepare(
qr, hidden_states, metadata, position_ids
)

# If there are no compressed tokens, return an topk indices buffer with all -1s in the tensor.
if k_fp8 is None:
topk_indices = metadata.empty_topk_indices_buffer[: hidden_states.shape[0]]
else:
topk_indices = self.sparse_attn_indexer(
metadata, hidden_states, q_fp8, k_fp8, k_scale, weights
metadata,
hidden_states,
q_fp8,
k_fp8,
k_scale,
weights,
update_k_cache=False,
)
return topk_indices

Expand Down
6 changes: 5 additions & 1 deletion tensorrt_llm/_torch/attention_backend/sparse/dsa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Dense Sparse Attention (DSA) backend for TRT-LLM with indexer-based TopK selection."""
import math
from dataclasses import dataclass
Expand Down Expand Up @@ -1817,13 +1819,15 @@ def sparse_attn_indexer(
k_scale: torch.Tensor,
weights: torch.Tensor,
use_custom_topk: bool = True,
update_k_cache: bool = True,
) -> torch.Tensor:
"""Run the indexer TopK kernel for both prefill and decode phases."""
assert metadata.kv_cache_manager is None or \
metadata.kv_cache_manager.quant_block_size == 128, \
"Only support quant_block_size = 128 for now"
# Update the indexer k cache before prefill chunks gather from it.
self._update_k_cache(k_fp8, k_scale, metadata)
if update_k_cache:
self._update_k_cache(k_fp8, k_scale, metadata)

num_contexts = metadata.num_contexts
num_generations = metadata.num_generations
Expand Down
113 changes: 72 additions & 41 deletions tensorrt_llm/_torch/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
is_torch_compiling, maybe_compiled_cat,
maybe_compiled_copy_)
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
from .multi_stream_utils import maybe_execute_in_parallel
from .multi_stream_utils import do_multi_stream, maybe_execute_in_parallel
from .rms_norm import RMSNorm
from .rotary_embedding import MRotaryEmbedding, RotaryEmbedding

Expand Down Expand Up @@ -1427,6 +1427,19 @@ def yarn_get_mscale(scale=1, mscale=1):
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
q_scaling = 1.0 / (mscale * mscale)

self.has_dsv4_indexer = (
self.is_deepseek_v4 and layer_idx is not None
and config.sparse_attention_config is not None
and config.sparse_attention_config.compress_ratios[layer_idx] == 4)
self.indexer_stream = None
self.indexer_aux_stream = None
if self.has_dsv4_indexer and aux_stream is not None:
self.indexer_stream = torch.cuda.Stream(device=aux_stream.device)
self.indexer_aux_stream = torch.cuda.Stream(
device=aux_stream.device)
mqa_aux_stream = (self.indexer_aux_stream if self.indexer_aux_stream
is not None else aux_stream)

self.mqa = create_attention(
config.attn_backend,
self.layer_idx,
Expand All @@ -1448,14 +1461,17 @@ def yarn_get_mscale(scale=1, mscale=1):
skip_create_weights_in_init=config.skip_create_weights_in_init,
sparse_attention_config=config.sparse_attention_config,
dtype=dtype,
aux_stream=aux_stream,
aux_stream=mqa_aux_stream,
rope_append=not self.is_deepseek_v4,
)

self.softmax_scale = 1.0 / (math.sqrt(self.qk_head_dim) * q_scaling)

self.aux_stream = aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
self.dsv4_overlap_start_event = torch.cuda.Event()
self.dsv4_compressor_event = torch.cuda.Event()
self.dsv4_indexer_event = torch.cuda.Event()

self.rope_fusion = self.mqa.support_fused_rope()
self.rotary_emb = None
Expand Down Expand Up @@ -2113,57 +2129,72 @@ def forward_impl_with_deepseek_v4(self,
latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)

# TRTLLM_MLA_EXTRA_OVERLAP=1 reorders the V4 attention prologue so the
# outer compressor (writes its own KV-cache slot, only reads
# hidden_states) executes on the auxiliary CUDA stream concurrently
# with q_b_proj + q_b_layernorm on the default stream. Both branches
# are entirely independent (no shared inputs and disjoint writes),
# so this is a pure dependency-aware reorder. Falls back to the
# serial schedule when the env-var is unset, when the aux stream is
# unavailable, or when multi-stream mode is off.
# outer compressor and the ratio-4 indexer can execute concurrently
# with q_b_proj + q_b_layernorm. The indexer is launched on a
# dedicated stream and still uses a different aux stream for its
# internal q-proj/weights-proj split.
_v4_extra_overlap = (os.environ.get("TRTLLM_MLA_EXTRA_OVERLAP", "0")
== "1" and self.compressor is not None
and self.aux_stream is not None)

if _v4_extra_overlap:

def _q_branch():
q_proj = self.q_b_proj(q)
# Per-head RMS: view as [N*n_heads, head_dim] so RMSNorm
# reduces per-head.
return self.q_b_layernorm(q_proj.view(
-1, self.qk_head_dim)).view_as(q_proj)
def _q_branch():
q_proj = self.q_b_proj(q)
# Per-head RMS: view as [N*n_heads, head_dim] so RMSNorm
# reduces per-head.
return self.q_b_layernorm(q_proj.view(
-1, self.qk_head_dim)).view_as(q_proj)

def _compressor_branch():
self.compressor(hidden_states, attn_metadata)
return None
def _compressor_branch():
self.compressor(hidden_states, attn_metadata)
return None

q, _ = maybe_execute_in_parallel(
_q_branch,
_compressor_branch,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
def _indexer_branch():
return self.indexer(
qr,
hidden_states,
attn_metadata,
position_ids,
)

topk_indices = None
indexer_ran = False
if _v4_extra_overlap:
use_indexer_overlap = (do_multi_stream()
and self.indexer is not None
and self.indexer_stream is not None)
if use_indexer_overlap:
self.dsv4_overlap_start_event.record()

with torch.cuda.stream(self.aux_stream):
self.dsv4_overlap_start_event.wait()
_compressor_branch()
self.dsv4_compressor_event.record()

with torch.cuda.stream(self.indexer_stream):
self.dsv4_overlap_start_event.wait()
topk_indices = _indexer_branch()
indexer_ran = True
self.dsv4_indexer_event.record()

q = _q_branch()
self.dsv4_compressor_event.wait()
self.dsv4_indexer_event.wait()
else:
q, _ = maybe_execute_in_parallel(
_q_branch,
_compressor_branch,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
else:
q = self.q_b_proj(q)
# Per-head RMS: view as [N*n_heads, head_dim] so RMSNorm reduces per-head.
q = self.q_b_layernorm(q.view(-1, self.qk_head_dim)).view_as(q)
Comment thread
liji-nv marked this conversation as resolved.
q = _q_branch()
if self.compressor is not None:
self.compressor(hidden_states, attn_metadata)

# Indexer is independent of both q_b_proj and the compressor's KV-cache
# write, so it runs after either schedule. Kept serial because it
# internally reuses self.aux_stream for its own multi-stream q-proj ||
# weights-proj split; running it concurrently with q_b_proj would
# create a stream-aliasing hazard.
topk_indices = None
if self.indexer is not None:
topk_indices = self.indexer(
qr,
hidden_states,
attn_metadata,
position_ids,
)
if not indexer_ran:
topk_indices = _indexer_branch()

assert q.shape[
0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"
Expand Down
14 changes: 13 additions & 1 deletion tests/unittest/_torch/modeling/test_modeling_deepseekv4.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import ast
import inspect
import json
import struct
import textwrap
import weakref
from copy import deepcopy

Expand Down Expand Up @@ -91,6 +93,14 @@
}


def _source_calls(source):
return {
ast.unparse(node)
for node in ast.walk(ast.parse(textwrap.dedent(source)))
if isinstance(node, ast.Call)
}


def _write_safetensors_header(path, tensor_name, dtype, shape):
header = {
tensor_name: {
Expand Down Expand Up @@ -326,7 +336,9 @@ def test_deepseek_v4_mla_q_b_layernorm_init_and_forward_shape():
assert "kv_a_layernorm_hidden_size = (" in init_src
assert "self.kv_lora_rank + self.qk_rope_head_dim" in init_src
assert "self.kv_a_layernorm = RMSNorm(hidden_size=kv_a_layernorm_hidden_size" in init_src
assert "self.q_b_layernorm(q.view(-1, self.qk_head_dim)).view_as(q)" in forward_src
assert "self.q_b_layernorm(q_proj.view(-1, self.qk_head_dim)).view_as(q_proj)" in _source_calls(
forward_src
)


def test_deepseek_v4_compressor_rotate_and_indexer_rope_contracts():
Expand Down
Loading