Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,21 @@ HuggingFace Model → LLM API → Executor (PyTorch/AutoDeploy/TensorRT)
| **Distributed execution** | Tensor/pipeline parallelism via `Mapping` class, multiple backends (MPI, Ray, RPC) |
| **Auto-discovery** | Models self-register via `automodel.py`, resolved by HF config `architectures` field |

## VisualGen

VisualGen is a vertical alongside LLM for Diffusion-Transformer (DiT)-based image/video generation
(text-to-image, text-to-video, image-to-video). It is **not** an LLM backend — it has
its own engine, args, params, and outputs — but shares ops and kernels with the
PyTorch backend where it makes sense (attention, quantization, parallelism).

Key entry points:
- Public Python API: `from tensorrt_llm import VisualGen, VisualGenArgs, VisualGenParams`.
- Serving CLI: `trtllm-serve --model <HF id> --visual_gen_args <YAML path>`.

Key files:
- `tensorrt_llm/visual_gen/`: VisualGen public Python API. **User-facing surface — before modifying anything here, pause and confirm with the user that a public API change is actually intended; do not infer it from the surrounding task.**
- `tensorrt_llm/_torch/visual_gen/`: VisualGen internal implementation. All non-user-facing code belongs here.

## Anti-Patterns / Gotchas

- **Pre-commit modifies files in-place** — if hooks fail, files are already modified. Re-stage (`git add`) and commit again.
Expand Down
16 changes: 10 additions & 6 deletions cpp/tensorrt_llm/nanobind/thop/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,20 @@ void initBindings(nb::module_& m)
nb::arg("num_heads"), nb::arg("num_kv_heads"), nb::arg("head_size"), nb::arg("tokens_per_block").none(),
nb::arg("max_num_requests"), nb::arg("max_context_length"), nb::arg("attention_window_size"),
nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"), nb::arg("q_scaling"),
nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"), nb::arg("rotary_embedding_base"),
nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"),
nb::arg("rotary_embedding_max_position_info"), nb::arg("use_paged_context_fmha"),
nb::arg("position_embedding_type"), nb::arg("rope_dim"), nb::arg("rope_base"), nb::arg("rope_scale_type"),
nb::arg("rope_scale"), nb::arg("rope_short_m_scale"), nb::arg("rope_long_m_scale"),
nb::arg("rope_max_positions"), nb::arg("rope_original_max_positions"), nb::arg("use_paged_context_fmha"),
nb::arg("attention_input_type").none(), nb::arg("is_mla_enable"),
nb::arg("chunked_prefill_buffer_batch_size").none(), nb::arg("q_lora_rank").none(),
nb::arg("kv_lora_rank").none(), nb::arg("qk_nope_head_dim").none(), nb::arg("qk_rope_head_dim").none(),
nb::arg("v_head_dim").none(), nb::arg("rope_append").none(), nb::arg("mrope_rotary_cos_sin").none(),
nb::arg("mrope_position_deltas").none(), nb::arg("helix_tensor_params"), nb::arg("attention_chunk_size").none(),
nb::arg("softmax_stats_tensor").none(), nb::arg("spec_decoding_bool_params"),
nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_kv_indices").none(),
nb::arg("mrope_position_deltas").none(), nb::arg("helix_position_offsets").none(),
nb::arg("helix_is_inactive_rank").none(), nb::arg("attention_chunk_size").none(),
nb::arg("softmax_stats_tensor").none(), nb::arg("is_spec_decoding_enabled"), nb::arg("use_spec_decoding"),
nb::arg("is_spec_dec_tree"), nb::arg("spec_decoding_generation_lengths").none(),
nb::arg("spec_decoding_position_offsets_for_cpp").none(), nb::arg("spec_decoding_packed_mask").none(),
nb::arg("spec_decoding_bl_tree_mask_offset").none(), nb::arg("spec_decoding_bl_tree_mask").none(),
nb::arg("spec_bl_tree_first_sparse_mask_offset_kv").none(), nb::arg("sparse_kv_indices").none(),
nb::arg("sparse_kv_offsets").none(), nb::arg("sparse_attn_indices").none(),
nb::arg("sparse_attn_offsets").none(), nb::arg("sparse_attn_indices_block_size"),
nb::arg("num_sparse_topk") = std::nullopt, nb::arg("sparse_mla_topk_lens") = std::nullopt,
Expand Down
179 changes: 82 additions & 97 deletions cpp/tensorrt_llm/thop/attentionOp.cpp

Large diffs are not rendered by default.

20 changes: 13 additions & 7 deletions cpp/tensorrt_llm/thop/attentionOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,23 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode, double const q_scaling,
int64_t const position_embedding_type, int64_t const rotary_embedding_dim, double const rotary_embedding_base,
int64_t const rotary_embedding_scale_type, std::vector<double> rotary_embedding_scales,
std::vector<int64_t> rotary_embedding_max_position_info, bool const use_paged_context_fmha,
std::optional<int64_t> attention_input_type, bool is_mla_enable,
int64_t const position_embedding_type, int64_t const rope_dim, double const rope_base,
int64_t const rope_scale_type, double const rope_scale, double const rope_short_m_scale,
double const rope_long_m_scale, int64_t const rope_max_positions, int64_t const rope_original_max_positions,
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,
std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim, std::optional<bool> rope_append,
std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
std::vector<std::optional<torch::Tensor>> helix_tensor_params, std::optional<int64_t> attention_chunk_size,
std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
std::optional<torch::Tensor> helix_position_offsets, std::optional<torch::Tensor> helix_is_inactive_rank,
std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
bool const is_spec_decoding_enabled, bool const use_spec_decoding, bool const is_spec_dec_tree,
std::optional<torch::Tensor> spec_decoding_generation_lengths,
std::optional<torch::Tensor> spec_decoding_position_offsets_for_cpp,
std::optional<torch::Tensor> spec_decoding_packed_mask,
std::optional<torch::Tensor> spec_decoding_bl_tree_mask_offset,
std::optional<torch::Tensor> spec_decoding_bl_tree_mask,
std::optional<torch::Tensor> spec_bl_tree_first_sparse_mask_offset_kv,
std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
int64_t const sparse_attn_indices_block_size, std::optional<int64_t> num_sparse_topk,
Expand Down
28 changes: 21 additions & 7 deletions examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
compile_backend: torch-cudagraph
enable_chunked_prefill: true
model_factory: AutoModelForCausalLM
runtime: trtllm
kv_cache_config:
dtype: fp8
free_gpu_memory_fraction: 0.9
attn_backend: trtllm
max_batch_size: 256
max_seq_len: 16384
cuda_graph_config:
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 192, 256]
kv_cache_config:
dtype: fp8
transforms:
detect_sharding:
allreduce_strategy: SYMM_MEM
sharding_source: ['manual']
manual_config:
head_dim: 128
tp_plan:
"q_proj": "colwise"
"k_proj": "colwise"
"v_proj": "colwise"
"o_proj": "rowwise"
"gate_proj": "colwise"
"up_proj": "colwise"
"down_proj": "rowwise"
fuse_trtllm_attn_quant_fp8:
enabled: true
fuse_fp8_gemms:
Expand All @@ -20,5 +31,8 @@ transforms:
fuse_rope_into_trtllm_attention:
enabled: true
fuse_silu_mul:
enabled: true
backend: trtllm
compile_model:
piecewise_enabled: true
mlir_elementwise_fusion:
enabled: true
12 changes: 4 additions & 8 deletions examples/auto_deploy/model_registry/configs/nano_v3.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
runtime: trtllm
compile_backend: torch-cudagraph
max_batch_size: 384
max_seq_len: 65536 # tunable
enable_chunked_prefill: true
attn_backend: trtllm
model_factory: AutoModelForCausalLM
skip_loading_weights: false
cuda_graph_config:
batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
kv_cache_config:
free_gpu_memory_fraction: 0.88
# tunable mamba cache dtype
# --> use float32 for accuracy and default (auto) for speed
mamba_ssm_cache_dtype: auto
transforms:
detect_sharding:
allreduce_strategy: SYMM_MEM
Expand Down Expand Up @@ -42,7 +35,6 @@ transforms:
"fc1_latent_proj": "gather"
"fc2_latent_proj": "gather"
multi_stream_moe:
stage: compile
enabled: true
gather_logits_before_lm_head:
# TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
Expand All @@ -54,3 +46,7 @@ transforms:
backend: flashinfer_ssm
compile_model:
piecewise_enabled: true
fuse_nvfp4_moe:
backend: trtllm_gen
mlir_elementwise_fusion:
enabled: true
102 changes: 42 additions & 60 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,22 +157,6 @@ def effective_workspace(self) -> Optional[torch.Tensor]:
"""Attention-kernel workspace, switching to the CUDA-graph copy under capture."""
return self.cuda_graph_workspace if self.is_cuda_graph else self.workspace

@property
def helix_tensor_params(self) -> List[Optional[torch.Tensor]]:
"""``[helix_position_offsets, helix_is_inactive_rank]`` — the positional
helix tensor list expected by the C++ attention op."""
return [self.helix_position_offsets, self.helix_is_inactive_rank]

@property
def spec_decoding_bool_params(self) -> List[bool]:
"""``[is_spec_decoding_enabled, use_spec_decoding, is_spec_dec_tree]`` —
the positional bool list expected by the C++ attention op."""
return [
self.is_spec_decoding_enabled,
self.use_spec_decoding,
self.is_spec_dec_tree,
]

@property
def spec_decoding_position_offsets_for_cpp(self) -> Optional[torch.Tensor]:
"""``spec_decoding_position_offsets`` reshaped to the 2D layout the C++
Expand Down Expand Up @@ -1051,22 +1035,6 @@ def generate_spec_decoding_generation_length(self, runtime_draft_len):
def is_sm_version_trtllm_gen_kernel(self, sm):
return not (sm < 100 or sm in [120, 121])

@property
def spec_decoding_tensor_params(self) -> List[Optional[torch.Tensor]]:
"""Positional spec-decoding tensor list for the C++ attention op.
Includes three Blackwell-tree mask tensors on SM versions that take
the trtllm-gen kernel."""
params = [
self.spec_decoding_generation_lengths,
self.spec_decoding_position_offsets_for_cpp,
self.spec_decoding_packed_mask,
]
if self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()):
params.append(self.spec_decoding_bl_tree_mask_offset)
params.append(self.spec_decoding_bl_tree_mask)
params.append(self.spec_bl_tree_first_sparse_mask_offset_kv)
return params


class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):

Expand Down Expand Up @@ -1332,35 +1300,36 @@ def create_output(self, q, *, is_quantize_output: bool,
]

@property
def rotary_embedding_dim(self) -> int:
def rope_dim(self) -> int:
return self.rope_params.dim

@property
def rotary_embedding_base(self) -> float:
def rope_base(self) -> float:
return self.rope_params.theta

@property
def rotary_embedding_scale_type(self) -> int:
def rope_scale_type(self) -> int:
return int(self.rope_params.scale_type)

@property
def rotary_embedding_scales(self) -> List[float]:
"""``[scale, short_m_scale, long_m_scale]`` — the positional RoPE-scale
list expected by the C++ attention op."""
return [
self.rope_params.scale,
self.rope_params.short_m_scale,
self.rope_params.long_m_scale,
]
def rope_scale(self) -> float:
return self.rope_params.scale

@property
def rotary_embedding_max_position_info(self) -> List[int]:
"""``[max_positions, original_max_positions]`` — the positional
RoPE-positions list expected by the C++ attention op."""
return [
self.rope_params.max_positions,
self.rope_params.original_max_positions,
]
def rope_short_m_scale(self) -> float:
return self.rope_params.short_m_scale

@property
def rope_long_m_scale(self) -> float:
return self.rope_params.long_m_scale

@property
def rope_max_positions(self) -> int:
return self.rope_params.max_positions

@property
def rope_original_max_positions(self) -> int:
return self.rope_params.original_max_positions

@property
def skip_softmax_threshold_scale_factor_prefill(self) -> Optional[float]:
Expand Down Expand Up @@ -1530,10 +1499,21 @@ def _run(
max_num_requests=metadata.max_num_requests,
beam_width=metadata.beam_width,
use_paged_context_fmha=metadata.use_paged_context_fmha,
helix_tensor_params=metadata.helix_tensor_params,
spec_decoding_bool_params=metadata.spec_decoding_bool_params,
spec_decoding_tensor_params=metadata.
spec_decoding_tensor_params,
helix_position_offsets=metadata.helix_position_offsets,
helix_is_inactive_rank=metadata.helix_is_inactive_rank,
is_spec_decoding_enabled=metadata.is_spec_decoding_enabled,
use_spec_decoding=metadata.use_spec_decoding,
is_spec_dec_tree=metadata.is_spec_dec_tree,
spec_decoding_generation_lengths=metadata.
spec_decoding_generation_lengths,
spec_decoding_position_offsets_for_cpp=metadata.
spec_decoding_position_offsets_for_cpp,
spec_decoding_packed_mask=metadata.spec_decoding_packed_mask,
spec_decoding_bl_tree_mask_offset=metadata.
spec_decoding_bl_tree_mask_offset,
spec_decoding_bl_tree_mask=metadata.spec_decoding_bl_tree_mask,
spec_bl_tree_first_sparse_mask_offset_kv=metadata.
spec_bl_tree_first_sparse_mask_offset_kv,
num_sparse_topk=metadata.num_sparse_topk,
flash_mla_tile_scheduler_metadata=metadata.
flash_mla_tile_scheduler_metadata,
Expand Down Expand Up @@ -1584,12 +1564,14 @@ def _run(
quant_mode=self.quant_mode,
q_scaling=self.q_scaling,
position_embedding_type=self.position_embedding_type,
rotary_embedding_dim=self.rotary_embedding_dim,
rotary_embedding_base=self.rotary_embedding_base,
rotary_embedding_scale_type=self.rotary_embedding_scale_type,
rotary_embedding_scales=self.rotary_embedding_scales,
rotary_embedding_max_position_info=self.
rotary_embedding_max_position_info,
rope_dim=self.rope_dim,
rope_base=self.rope_base,
rope_scale_type=self.rope_scale_type,
rope_scale=self.rope_scale,
rope_short_m_scale=self.rope_short_m_scale,
rope_long_m_scale=self.rope_long_m_scale,
rope_max_positions=self.rope_max_positions,
rope_original_max_positions=self.rope_original_max_positions,
is_mla_enable=self.is_mla_enable,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
Expand Down
Loading
Loading