756 changes: 756 additions & 0 deletions examples/dynamo/attention_plugin_example.py

Large diffs are not rendered by default.

404 changes: 404 additions & 0 deletions examples/dynamo/end_to_end_llm_generation_example.py

Large diffs are not rendered by default.

559 changes: 559 additions & 0 deletions examples/dynamo/end_to_end_vit_attention_plugin_example.py

Large diffs are not rendered by default.

494 changes: 494 additions & 0 deletions examples/dynamo/vit_attention_plugin_example.py

Large diffs are not rendered by default.

96 changes: 96 additions & 0 deletions tools/llm/plugin_converter.py
@@ -0,0 +1,96 @@
"""
TensorRT converter for Edge-LLM attention plugin ops.

This module contains the TensorRT converter for the tensorrt_edge_llm::xqa_attn
custom op. It is kept in a separate file from plugin_utils.py for maintainability.
"""

import numpy as np
import tensorrt as trt
from plugin_utils import get_plugin_config, register_plugin_op
from torch_tensorrt.dynamo.conversion import (
ConversionContext,
dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor

# Ensure the custom op is registered before the converter decorator runs
register_plugin_op()

import torch # noqa: E402 (must be after register_plugin_op so the op exists)

@dynamo_tensorrt_converter(
torch.ops.tensorrt_edge_llm.xqa_attn.default, supports_dynamic_shapes=True
)
def convert_attn(ctx: ConversionContext, target, args, kwargs, name):
"""
Convert tensorrt_edge_llm::xqa_attn op to TensorRT AttentionPlugin.

    The TensorRT-Edge-LLM (0.4.0) AttentionPlugin requires five inputs:
- qkv, kv, ctx_len, rope, kv_cache_start_idx

Plugin fields:
- num_q_heads, num_kv_heads, head_size, enable_tree_attention, enable_delta_kv_output
"""
# args: qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d
qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d = args[:8]

creator = trt.get_plugin_registry().get_plugin_creator("AttentionPlugin", "1", "")
if creator is None:
raise RuntimeError("AttentionPlugin not found in TensorRT plugin registry!")

# Get config from global settings
config = get_plugin_config()
if config:
nq_val = config["num_attention_heads"]
nkv_val = config["num_key_value_heads"]
d_val = config["head_dim"]
else:
        # Fallback to the nq/nkv/d args when they are concrete Python ints;
        # the hardcoded defaults below only match one specific model config.
        nq_val = nq if isinstance(nq, int) else 14
        nkv_val = nkv if isinstance(nkv, int) else 2
        d_val = d if isinstance(d, int) else 64

# Plugin fields for TensorRT-Edge-LLM AttentionPlugin
# Required: num_q_heads, num_kv_heads, head_size, enable_tree_attention
# enable_delta_kv_output=1 enables delta KV output for Python/torch_tensorrt compatibility
field_list = [
trt.PluginField(
field_name, np.array([field_val], dtype=np.int32), trt.PluginFieldType.INT32
)
for field_name, field_val in [
("num_q_heads", nq_val),
("num_kv_heads", nkv_val),
("head_size", d_val),
("enable_tree_attention", 0),
("enable_delta_kv_output", 1),
]
]

fields = trt.PluginFieldCollection(field_list)
plugin = creator.create_plugin(name, fields)

# 5 inputs for release version: qkv, kv, ctx_len, rope, kv_cache_start_idx
inputs = [
(
get_trt_tensor(ctx, i, f"{name}_i{idx}")
if not isinstance(i, trt.ITensor)
else i
)
for idx, i in enumerate([qkv, kv, ctx_len, rope, kv_cache_start_idx])
]

# Handle ctx_len shape if needed (squeeze if [B, 1] -> [B])
if len(inputs[2].shape) == 2 and inputs[2].shape[1] == 1:
shuffle_layer = ctx.net.add_shuffle(inputs[2])
shuffle_layer.reshape_dims = (inputs[2].shape[0],)
inputs[2] = shuffle_layer.get_output(0)

# Handle kv_cache_start_idx shape if needed (squeeze if [B, 1] -> [B])
if len(inputs[4].shape) == 2 and inputs[4].shape[1] == 1:
shuffle_layer = ctx.net.add_shuffle(inputs[4])
shuffle_layer.reshape_dims = (inputs[4].shape[0],)
inputs[4] = shuffle_layer.get_output(0)

layer = ctx.net.add_plugin_v2(inputs, plugin)
return layer.get_output(0), layer.get_output(1)
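
A minimal usage sketch for the converter above (an assumption for illustration, not code from this PR): it presumes register_plugin_op() defines torch.ops.tensorrt_edge_llm.xqa_attn with the eight-argument schema unpacked in convert_attn, and that the op returns the attention output plus a delta-KV tensor, matching the converter's two plugin outputs. ToyAttn, the 14/2/64 head configuration, and example_inputs are hypothetical placeholders.

import torch
import torch_tensorrt

import plugin_converter  # noqa: F401  (side effect: registers the op and converter)


class ToyAttn(torch.nn.Module):
    def forward(self, qkv, kv, ctx_len, rope, kv_cache_start_idx):
        # 14 query heads, 2 KV heads, head size 64 (illustrative values only)
        out, _delta_kv = torch.ops.tensorrt_edge_llm.xqa_attn(
            qkv, kv, ctx_len, rope, kv_cache_start_idx, 14, 2, 64
        )
        return out


# example_inputs is a hypothetical tuple of (qkv, kv, ctx_len, rope,
# kv_cache_start_idx) tensors shaped per the Edge-LLM plugin contract.
trt_mod = torch_tensorrt.compile(
    ToyAttn().cuda().eval(), ir="dynamo", inputs=list(example_inputs)
)
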
86 changes: 86 additions & 0 deletions tools/llm/plugin_converter_vit.py
@@ -0,0 +1,86 @@
"""
TensorRT converter for ViT attention plugin ops.

This module contains the TensorRT converter for the tensorrt_vit::attention
custom op. It is kept in a separate file from plugin_utils_vit.py for maintainability.
"""

import numpy as np
import tensorrt as trt

from plugin_utils_vit import get_vit_plugin_config, register_vit_plugin_op
from torch_tensorrt.dynamo.conversion import (
ConversionContext,
dynamo_tensorrt_converter,
)

from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor

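# Ensure the custom op is registered before the converter decorator runs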
register_vit_plugin_op()

import torch # noqa: E402 (must be after register_vit_plugin_op so the op exists)

@dynamo_tensorrt_converter(
torch.ops.tensorrt_vit.attention.default, supports_dynamic_shapes=True
)
def convert_vit_attention(ctx: ConversionContext, target, args, kwargs, name):
"""Convert tensorrt_vit::attention to TensorRT ViTAttentionPlugin."""
qkv, cos, sin, attention_mask, num_heads, head_dim = args[:6]
qkv_fused = args[6] if len(args) > 6 else kwargs.get("qkv_fused", 1)

creator = trt.get_plugin_registry().get_plugin_creator(
"ViTAttentionPlugin", "1", ""
)
if creator is None:
raise RuntimeError(
"ViTAttentionPlugin not found in TensorRT plugin registry!"
)

    # Get config from global settings; fall back to the traced args if unset
    config = get_vit_plugin_config() or {}
    num_heads_val = config.get("num_attention_heads", num_heads)
    head_dim_val = config.get("head_dim", head_dim)
    qkv_fused_val = qkv_fused if isinstance(qkv_fused, int) else 1

field_list = [
trt.PluginField(
"num_heads",
np.array([num_heads_val], dtype=np.int32),
trt.PluginFieldType.INT32,
),
trt.PluginField(
"head_size",
np.array([head_dim_val], dtype=np.int32),
trt.PluginFieldType.INT32,
),
trt.PluginField(
"qkv_fused",
np.array([qkv_fused_val], dtype=np.int32),
trt.PluginFieldType.INT32,
),
]
plugin = creator.create_plugin(name, trt.PluginFieldCollection(field_list))
if plugin is None:
raise RuntimeError("Failed to create ViTAttentionPlugin")

qkv_tensor = (
get_trt_tensor(ctx, qkv, f"{name}_qkv")
if not isinstance(qkv, trt.ITensor)
else qkv
)
cos_tensor = (
get_trt_tensor(ctx, cos, f"{name}_cos")
if not isinstance(cos, trt.ITensor)
else cos
)
sin_tensor = (
get_trt_tensor(ctx, sin, f"{name}_sin")
if not isinstance(sin, trt.ITensor)
else sin
)
mask_tensor = (
get_trt_tensor(ctx, attention_mask, f"{name}_mask")
if not isinstance(attention_mask, trt.ITensor)
else attention_mask
)
    layer = ctx.net.add_plugin_v2(
        [qkv_tensor, cos_tensor, sin_tensor, mask_tensor], plugin
    )
return layer.get_output(0)
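
A matching sketch for the ViT converter (again an assumption, not code from this PR): it presumes register_vit_plugin_op() defines torch.ops.tensorrt_vit.attention with the (qkv, cos, sin, attention_mask, num_heads, head_dim, qkv_fused) schema used above and a single tensor output. ToyViTAttn and the 12/64 head configuration are hypothetical placeholders; such a module compiles the same way as the Edge-LLM sketch above.

import torch

import plugin_converter_vit  # noqa: F401  (side effect: registers the op and converter)


class ToyViTAttn(torch.nn.Module):
    def forward(self, qkv, cos, sin, attention_mask):
        # 12 heads of size 64 with a fused QKV layout (illustrative values;
        # at conversion time the real ones come from get_vit_plugin_config())
        return torch.ops.tensorrt_vit.attention(
            qkv, cos, sin, attention_mask, 12, 64, 1
        )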