6 changes: 6 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -142,6 +142,9 @@ TRTEngine::TRTEngine(
}
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");

// Pre-allocate placeholder for empty tensors (TensorRT requires non-null addresses)
cudaMalloc(&empty_tensor_placeholder, 1);

runtime_states.old_cudagraphs = CUDAGRAPHS_MODE;
runtime_states.old_pre_allocated_outputs = false;
runtime_states.context_changed = false;
@@ -264,6 +267,9 @@ TRTEngine::~TRTEngine() {
trt_engine_profiler.reset();
exec_ctx.reset();
cuda_engine.reset();
if (empty_tensor_placeholder) {
cudaFree(empty_tensor_placeholder);
}
rt.reset();
}

3 changes: 3 additions & 0 deletions core/runtime/TRTEngine.h
@@ -187,6 +187,9 @@ struct TRTEngine : torch::CustomClassHolder {
bool use_pre_allocated_outputs = false;
std::vector<at::Tensor> pre_allocated_outputs;

// Single placeholder buffer for empty tensor inputs (allocated once, reused)
void* empty_tensor_placeholder = nullptr;

// Output Allocator-Related Functionality
bool requires_output_allocator = false; // engine requires output allocator
bool use_output_allocator_outputs = false; // users specify to use output allocator
20 changes: 14 additions & 6 deletions core/runtime/execute_engine.cpp
@@ -149,18 +149,26 @@ void setup_input_tensors(
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");

at::Tensor final_input;
if (cudagraphs_enabled) {
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), compiled_engine->input_buffers[i].data_ptr()),
"Error while setting the input tensor address for inputs");
final_input = compiled_engine->input_buffers[i];
} else {
// Otherwise use the formatted buffer directly
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), formatted_inputs.back().data_ptr()),
"Error while setting the input tensor address for inputs");
final_input = formatted_inputs.back();
}

// Get tensor address, using placeholder for empty tensors
// TensorRT requires non-null address even if numel() = 0
// empty_tensor_placeholder is pre-allocated in TRTEngine constructor
void* input_addr = (final_input.numel() == 0 || final_input.data_ptr() == nullptr)
? compiled_engine->empty_tensor_placeholder
: final_input.data_ptr();

TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), input_addr),
"Failed to bind tensor address for " << name);
}
}
}
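For context, a minimal Python-level sketch of why the placeholder allocated in the `TRTEngine` constructor is needed: a zero-element tensor typically reports a null `data_ptr()`, and TensorRT's `setTensorAddress` requires a non-null address even when `numel() == 0`. The shapes and device below are illustrative only.

```python
import torch

# A zero-element CUDA tensor usually has no backing storage allocation,
# so its data_ptr() is typically 0 (a null pointer).
x = torch.empty(0, 4, device="cuda")
print(x.numel())     # 0
print(x.data_ptr())  # usually 0 -> cannot be bound directly, hence the 1-byte placeholder
```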
16 changes: 14 additions & 2 deletions py/torch_tensorrt/_Input.py
@@ -1,11 +1,14 @@
from __future__ import annotations

import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
from torch_tensorrt._enums import dtype, memory_format

logger = logging.getLogger(__name__)


class Input(object):
"""
@@ -149,6 +152,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
}
self.shape_mode = Input._ShapeMode.DYNAMIC

# Warn if min_shape has any 0 dimension (empty tensor) - TensorRT doesn't support this
# @apbose: Is this warning necessary?
if any(dim == 0 for dim in self.shape["min_shape"]):
logger.warning(
f"min_shape contains a 0 dimension: {self.shape['min_shape']}. "
"TensorRT does not support dynamic shapes with min dimension of 0 (empty tensors). "
"TensorRT will internally clamp min dimensions to 1, which may cause runtime errors "
"if you try to run inference with empty tensor inputs."
)

else:
raise ValueError(
f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments"
@@ -384,7 +397,7 @@ def example_tensor(
dtype=self.dtype.to(torch.dtype, use_default=True)
)
else:
RuntimeError(
raise RuntimeError(
f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})"
)
else:
@@ -412,4 +425,3 @@ def example_tensor(
raise ValueError(
"Requested an example tensor from a dynamic shaped input but did not specific which profile field to use."
)
raise
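A hedged sketch of when the new `min_shape` warning fires; the shapes and dtype are illustrative:

```python
import torch
from torch_tensorrt import Input

# min_shape contains a 0 dimension, so the constructor above logs a warning:
# TensorRT clamps min dimensions to 1 internally, which can surface as a
# runtime error if an actually-empty tensor is later fed to the engine.
inp = Input(
    min_shape=(0, 4),
    opt_shape=(4, 4),
    max_shape=(8, 4),
    dtype=torch.float32,
)
```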
23 changes: 9 additions & 14 deletions py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -20,7 +20,6 @@

import tensorrt as trt
import torch
import torch_tensorrt
from torch import SymBool, SymFloat, SymInt
from torch._ops import OpOverloadPacket
from torch.fx.node import Argument, Node, Target, _get_qualified_name
@@ -536,7 +535,7 @@ def __contains__(self, key: Target | Node) -> bool:
def get_all_converters_with_target(
self, key: Target, return_registry_info: bool = False
) -> Tuple[
Union[List[Any], Dict[str, int], None]
List[Any], Optional[Dict[str, int]]
]: # TODO: Narrow to ConverterImplSignature this when we can remove FX converters
"""Get all converters across all registries for the target

@@ -547,7 +546,7 @@ def get_all_converters_with_target(

# Store count of number of registered converters per registry
if return_registry_info:
registry_data = {name: 0 for name in self.registry_names}
registry_data = dict.fromkeys(self.registry_names, 0)

for index, registry in enumerate(self.registries):
if key in registry:
@@ -622,22 +621,18 @@ def display_all_available_converters(self) -> str:
return available_converters


# Initialize dynamo converter registry with the FX and Dynamo aten registries
# Note the Dynamo registry is listed first, for precedence
registries = [
DYNAMO_ATEN_CONVERTERS,
# Initialize dynamo converter registry with Dynamo aten converters only
# FX converters are not loaded here - they are legacy and should only be used
# in the FX frontend, not as fallbacks in the dynamo frontend
registries: List[
Dict[Target, Union[Callable[..., Any], Sequence[ConverterSupport]]]
] = [
DYNAMO_ATEN_CONVERTERS, # type: ignore[list-item]
]
registry_names = ["Dynamo ATen Converters Registry"]
registry_calling_conventions = [
CallingConvention.CTX,
]
if torch_tensorrt.ENABLED_FEATURES.fx_frontend:
from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS

registries.append(FX_CONVERTERS)
registry_names.append("FX Legacy ATen Converters Registry")
registry_calling_conventions.append(CallingConvention.LEGACY)


DYNAMO_CONVERTERS: ConverterRegistry = ConverterRegistry(
registries,
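A rough sketch of the effect of dropping the FX fallback (assuming `DYNAMO_CONVERTERS` is imported directly from this module): lookups now resolve only against the Dynamo ATen registry.

```python
import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

# Membership checks and converter lookups consult only the Dynamo ATen registry now.
print(torch.ops.aten.cat.default in DYNAMO_CONVERTERS)

converters, registry_info = DYNAMO_CONVERTERS.get_all_converters_with_target(
    torch.ops.aten.cat.default, return_registry_info=True
)
print(registry_info)  # expected to report counts only for "Dynamo ATen Converters Registry"
```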
92 changes: 92 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -251,8 +251,61 @@ def parse_cat_args(
return input_tensors, dim


def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
"""
Validator for torch.cat operation with empty tensor handling.

PyTorch allows torch.tensor([]) (shape (0,)) to be concatenated with higher-dimensional
tensors, but TensorRT requires all inputs to have the same rank. This validator catches
this specific edge case.

Example valid case: cat([(3, 4), (0, 4)], dim=0) - same rank, properly shaped empty tensor for TRT
Example invalid case: cat([(3, 4), (0,)], dim=0) - torch.tensor([]) with rank mismatch
"""
# Use parse_cat_args to properly extract inputs (handles both args and kwargs patterns)
inputs, _ = parse_cat_args(node.args, node.kwargs)

if len(inputs) < 2:
return True

# Collect metadata for all inputs
input_metas = []
for inp in inputs:
if isinstance(inp, TRTTensor):
# TRTTensor has shape directly
input_metas.append(inp.shape)
else:
# For nodes, get metadata
meta = getattr(inp, "meta", {}).get("tensor_meta")
if meta is None:
# Can't validate without metadata, allow it
return True
shape = tuple(meta.shape)
input_metas.append(shape)

# Check for the specific problematic case:
# 1D empty tensor (0,) being concatenated with higher-dimensional tensors
ranks = [len(shape) for shape in input_metas]
# If all ranks are the same, it's fine (PyTorch and TensorRT both handle this)
if len(set(ranks)) == 1:
return True
# If ranks differ, check if we have a 1D empty tensor (0,) in the mix
# This is the torch.tensor([]) case that PyTorch allows but TensorRT doesn't
for i, shape in enumerate(input_metas):
if shape == (0,) or (len(shape) == 1 and shape[0] == 0):
# Found a 1D empty tensor with rank mismatch
_LOGGER.debug(
f"Concatenation rejected by TRT, torch.tensor([]) or 1D empty tensor at position {i} "
f"PyTorch allows this but TensorRT requires all inputs to have the same rank. "
f"Use torch.empty((0, ...)) with explicit dimensions matching other inputs instead. Falling back to Pytorch"
)
return False
return True


@dynamo_tensorrt_converter(
torch.ops.aten.cat.default,
capability_validator=cat_validator,
supports_dynamic_shapes=True,
)
def aten_ops_cat(
@@ -413,6 +466,27 @@ def aten_ops_relu(
)


@dynamo_tensorrt_converter(
torch.ops.aten.hardtanh.default, supports_dynamic_shapes=True
)
def aten_ops_hardtanh(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.activation.hardtanh(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args_bounds_check(args, 1, -1.0),
args_bounds_check(args, 2, 1.0),
)


@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default, supports_dynamic_shapes=True)
def aten_ops_sigmoid(
ctx: ConversionContext,
@@ -446,6 +520,24 @@ def aten_ops_symsize_int(
return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])


@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
@dynamo_tensorrt_converter(
torch.ops.aten.sym_numel.default, supports_dynamic_shapes=True
)
def aten_ops_sym_numel(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shape.numel(ctx, target, SourceIR.ATEN, name, args[0])


def index_dtype_validator(
node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
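A small illustration (plain PyTorch, shapes are illustrative) of the two cases the `cat_validator` distinguishes:

```python
import torch

a = torch.randn(3, 4)
b_same_rank = torch.empty(0, 4)     # rank matches `a`: validator accepts, TRT can handle it
b_rank_mismatch = torch.tensor([])  # shape (0,): rank differs, validator falls back to PyTorch

print(torch.cat([a, b_same_rank], dim=0).shape)  # torch.Size([3, 4])
# torch.cat([a, b_rank_mismatch], dim=0) is accepted by PyTorch's legacy empty-tensor
# handling, but it has no same-rank TensorRT equivalent, so the converter rejects it.
```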
36 changes: 36 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -477,6 +477,42 @@ def create_constant(
"Currently FP4 is only supported in TensorRT 10.8.0 and above"
)
# Record the weight in ctx for refit and cpu memory reference
# TensorRT's add_constant doesn't support 0-element tensors,
# but TRT does support empty tensors at runtime.
# For empty constants, we create a larger constant and slice it to empty.
if torch_value.numel() == 0:
empty_shape = list(torch_value.shape)

# Create a placeholder shape where each dim is max(1, target_dim)
# This ensures we can slice to the target empty shape
# e.g., target [0, 4] -> placeholder [1, 4] -> slice to [0, 4]
placeholder_shape = [max(1, d) for d in empty_shape]
placeholder_numel = 1
for d in placeholder_shape:
placeholder_numel *= d

# Create placeholder constant with the required number of elements
placeholder_value = torch.zeros(placeholder_numel, dtype=torch_value.dtype)
placeholder_weights = to_trt_weights(
ctx, placeholder_value, f"{name}_placeholder", "CONSTANT", "CONSTANT"
)
placeholder_constant = ctx.net.add_constant(
tuple(placeholder_shape), placeholder_weights
)
placeholder_constant.name = f"{name}_placeholder"

# Slice to get the empty shape (at least one dimension is 0)
start = [0] * len(empty_shape)
stride = [1] * len(empty_shape)
slice_layer = ctx.net.add_slice(
placeholder_constant.get_output(0),
start=start,
shape=empty_shape,
stride=stride,
)
slice_layer.name = f"{name}_empty_slice"

return slice_layer.get_output(0)

# Convert the torch.Tensor to a trt.Weights object
trt_weights = to_trt_weights(ctx, torch_value, name, "CONSTANT", "CONSTANT")
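A torch-level analogue of the shape arithmetic used for empty constants above (the real code builds a TensorRT constant layer followed by a slice layer; this sketch only mirrors the numbers):

```python
import torch

empty_shape = [0, 4]                                  # target constant shape with zero elements
placeholder_shape = [max(1, d) for d in empty_shape]  # [1, 4] -> a shape add_constant accepts
placeholder = torch.zeros(placeholder_shape)

sliced = placeholder[0:0]  # slice back down to shape (0, 4), i.e. zero elements
print(placeholder.shape, sliced.shape, sliced.numel())  # torch.Size([1, 4]) torch.Size([0, 4]) 0
```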
24 changes: 24 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
@@ -350,3 +350,27 @@ def gelu(
operation_type,
input_val,
)


def hardtanh(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
min_val: float = -1.0,
max_val: float = 1.0,
) -> TRTTensor:
# Ported from fx/converters/impl/activation.py
# dyn_range_fn removed as it's not used in dynamo's convert_activation base
operation_type = trt.ActivationType.CLIP
return convert_activation(
ctx,
target,
source_ir,
name,
operation_type,
input_val,
alpha=min_val,
beta=max_val,
)
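For reference, hardtanh is a clamp to `[min_val, max_val]`, which is why the converter maps it to `trt.ActivationType.CLIP` with `alpha`/`beta` as the bounds; a quick PyTorch check:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=7)
print(F.hardtanh(x, min_val=-1.0, max_val=1.0))
print(torch.clamp(x, min=-1.0, max=1.0))  # identical: hardtanh clips to [min_val, max_val]
```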