58 changes: 58 additions & 0 deletions examples/dynamo/compile_with_dynamic_inputs.py
@@ -0,0 +1,58 @@
import logging

import torch
import torch.nn as nn
import torch_tensorrt

logging.basicConfig(level=logging.DEBUG)

torch.manual_seed(0)


class ExpandReshapeModel(nn.Module):
def __init__(self, embed_dim: int):
super().__init__()
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
self.embed_dim = embed_dim
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)

def forward(self, x: torch.Tensor):
batch_size = x.shape[0]
cls_token = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_token, x], dim=1)
x = self.qkv_proj(x)
reshaped_qkv = x.reshape(batch_size, x.size(1), 3, 12, -1)
return reshaped_qkv


model = ExpandReshapeModel(embed_dim=768).cuda().eval()
x = torch.randn(4, 196, 768).cuda()

# 1. JIT: torch.compile
x1 = x.clone()
torch._dynamo.mark_dynamic(x1, index=0, min=2, max=32)
trt_module = torch.compile(model, backend="tensorrt")
out1 = trt_module(x1)

# 2. AOT: torch_tensorrt.compile
x2 = x.clone()
example_input = torch_tensorrt.Input(
min_shape=[1, 196, 768],
opt_shape=[4, 196, 768],
max_shape=[32, 196, 768],
dtype=torch.float32,
)
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=example_input)
out2 = trt_module(x2)

# 3. AOT: torch.export + Dynamo compile
x3 = x.clone()
bs = torch.export.Dim("bs", min=1, max=32)
dynamic_shapes = {"x": {0: bs}}
exp_program = torch.export.export(model, (x3,), dynamic_shapes=dynamic_shapes)
trt_module = torch_tensorrt.dynamo.compile(exp_program, (x3,))
out3 = trt_module(x3)

assert torch.allclose(out1, out2)
assert torch.allclose(out1, out3)
assert torch.allclose(out2, out3)
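# Editor's note, not part of this PR: all three paths above compile the model with a
# dynamic batch dimension, so the compiled module can be re-run at other batch sizes
# inside the declared range. A hedged usage sketch, assuming the last trt_module (from
# the torch.export path, with bs in [1, 32]) is still in scope:
x_new = torch.randn(16, 196, 768).cuda()
out_new = trt_module(x_new)
# qkv_proj maps 768 -> 2304, which the model reshapes to (batch, 197, 3, 12, 64)
assert out_new.shape == (16, 197, 3, 12, 64)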
28 changes: 28 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -219,6 +219,34 @@ def aten_ops_native_group_norm(
)


@dynamo_tensorrt_converter(
torch.ops.aten._fused_rms_norm.default,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_fused_rms_norm(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.fused_rms_norm(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
normalized_shape=args[1],
weight=args_bounds_check(args, 2),
eps=args_bounds_check(args, 3),
)


def parse_cat_args(
args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Tuple[List[Any], int]:
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -516,7 +516,8 @@ def get_trt_tensor(
# If the input is 64-bit, cast it to 32-bit for TRT freezing
if isinstance(input_val, torch.Tensor) and ctx.compilation_settings.truncate_double:
if input_val.dtype == torch.float64:
input_val = input_val.to(torch.float32)
with unset_fake_temporarily():
input_val = input_val.to(torch.float32)
elif isinstance(input_val, np.ndarray) and ctx.compilation_settings.truncate_double:
if input_val.dtype == np.float64:
input_val = input_val.astype(np.float32)
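Editor's note on the change above: get_trt_tensor can be reached while Dynamo's FakeTensorMode is active, where a plain .to(torch.float32) yields a FakeTensor with no real data to freeze into the engine; temporarily leaving fake mode produces a real tensor. A minimal standalone sketch of the pattern, assuming unset_fake_temporarily is importable from torch._subclasses.fake_tensor (the toy tensors are illustrative only):

# Hedged sketch: materialize a real fp32 tensor while a FakeTensorMode is active.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

with FakeTensorMode():
    fake = torch.zeros(2, dtype=torch.float64)  # a FakeTensor: shape/dtype only, no data
    with unset_fake_temporarily():
        real = torch.zeros(2, dtype=torch.float64).to(torch.float32)  # a real tensor
assert real.dtype == torch.float32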
86 changes: 86 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -854,3 +854,89 @@ def cdist_forward(
return_indices=False,
)
return dist


def fused_rms_norm(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: trt.ITensor,
normalized_shape: List[int],
weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
eps: Optional[float],
) -> Tuple[trt.ITensor, Optional[torch.Tensor]]:
"""
RMS Normalization: output = input / sqrt(mean(input^2) + eps) * weight

Args:
ctx: ConversionContext containing the TensorRT network
target: Target of calling node
source_ir: SourceIR of calling converter
name: Name of the calling layer
input: Input tensor to normalize
normalized_shape: Shape over which to normalize (list of ints)
weight: Optional weight/scale parameter
eps: Epsilon for numerical stability (default: 1e-5)

Returns:
Tuple of (normalized_output, rstd_placeholder)
        Note: rstd (reciprocal standard deviation) is returned as a None placeholder
"""
if eps is None:
eps = 1e-5

# Calculate dimensions to normalize over (similar to layer_norm)
# normalized_shape specifies the last N dimensions
dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
axes = get_axes_for_reduce_op(dims)

# Square the input
input_squared = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_input_squared", input, input
)

# Compute mean of squared values
mean_squared = impl.reduce.mean(
ctx, target, source_ir, f"{name}_mean_squared", input_squared, dim=dims, keepdim=True
)

# Add epsilon for numerical stability
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", input.dtype)
mean_squared_eps = impl.elementwise.add(
ctx, target, source_ir, f"{name}_mean_squared_eps", mean_squared, eps_tensor
)

# Compute RMS = sqrt(mean(input^2) + eps)
rms = impl.unary.sqrt(ctx, target, source_ir, f"{name}_rms", mean_squared_eps)

# Normalize: input / rms
normalized = impl.elementwise.div(
ctx, target, source_ir, f"{name}_normalized", input, rms
)

# Apply weight (scale) if provided
if weight is not None:
weight_trt = get_trt_tensor(ctx, weight, f"{name}_weight")

# Cast weight to match input dtype
weight_trt = cast_trt_tensor(
ctx, weight_trt, input.dtype, f"{name}_weight_cast", target, source_ir
)

# Expand weight to match input shape if needed
if tuple(input.shape) != tuple(weight_trt.shape):
weight_trt = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_weight", weight_trt, input.shape
)

# Multiply normalized output by weight
output = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_output", normalized, weight_trt
)
else:
output = normalized

# Return (output, rstd_placeholder)
# PyTorch returns (output, rstd) but we return None for rstd as it's typically not used
return output, None
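Editor's note, not part of this PR: the converter above restates the standard RMSNorm formula, output = input / sqrt(mean(input^2) + eps) * weight, reducing over the trailing len(normalized_shape) dimensions. A hypothetical eager-PyTorch reference helper can be used to sanity-check engine outputs (within float tolerances):

# Hedged eager-mode reference for RMS normalization, mirroring the converter above.
import torch

def rms_norm_reference(x, normalized_shape, weight=None, eps=1e-5):
    # Reduce over the last len(normalized_shape) dims, keeping them for broadcasting.
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    rms = torch.sqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
    out = x / rms
    return out * weight if weight is not None else out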
55 changes: 55 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py
@@ -15,6 +15,8 @@ def remove_sym_nodes(
"""Remove sym_int placeholders which get inserted due to torch.compile's
dynamic=True behavior
"""
gm = replace_symint_with_sym_size(gm)

# Extract SymInt placeholder Tensors
placeholder_idx_sym_ints = [
(idx, node)
@@ -36,3 +38,56 @@ def remove_sym_nodes(
logger.debug(f"Removed SymInt placeholders:\n{gm.graph}")

return gm


def replace_symint_with_sym_size(
gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
"""Replace SymInt placeholders with sym_size nodes"""
# Find all SymInt placeholders and their args
symint_node_arg_dict = {}
for node in gm.graph.nodes:
if (
node.op == "placeholder"
and isinstance(node.type, type)
and issubclass(node.type, torch.SymInt)
):
ga = node.meta.get("grapharg", None)
if ga is not None:
src = ga.source # TensorPropertySource
symint_node_arg_dict[node] = (src.base.local_name, src.idx)

# Replace SymInt placeholders with sym_size nodes
for node in gm.graph.nodes:
if (
node.op == "placeholder"
and isinstance(node.type, type)
and issubclass(node.type, torch.Tensor)
):
ga = node.meta.get("grapharg", None)
if ga is not None:
src = ga.source
if hasattr(src, "local_name") and getattr(src, "is_input", False):
node_local_name = src.local_name
for symint_node, (
symint_local_name,
idx,
) in symint_node_arg_dict.items():
if node_local_name == symint_local_name:
with gm.graph.inserting_after(node):
size_node = gm.graph.call_function(
torch.ops.aten.sym_size, args=(node, idx)
)
symint_node.replace_all_uses_with(size_node)
logger.debug(
f"The SymInt node {symint_node} is replaced with the sym_size node {size_node}"
)
                            # The symint_node is no longer used, but it cannot be erased here
                            # directly because doing so would change the number of positional
                            # arguments and cause a mismatch error; it is removed later,
                            # outside this function.

gm.graph.lint()
gm.recompile()
logger.debug(f"Added sym_size nodes for SymInt placeholders:\n{gm.graph}")

return gm
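Editor's note, not part of this PR: the core mechanism of replace_symint_with_sym_size is ordinary FX graph surgery, inserting a sym_size call after the tensor placeholder and redirecting every use of the SymInt placeholder to it. A self-contained toy sketch of that pattern (hypothetical function f and placeholder names; it does not use Dynamo's grapharg metadata):

# Hedged toy example of the rewrite pattern used in the pass above.
import torch
import torch.fx as fx

def f(x, n):
    return x.reshape(n, -1)

gm = fx.symbolic_trace(f)
graph = gm.graph
x_node = next(nd for nd in graph.nodes if nd.op == "placeholder" and nd.name == "x")
n_node = next(nd for nd in graph.nodes if nd.op == "placeholder" and nd.name == "n")

# Insert a sym_size node reading dim 0 of x, then redirect all uses of n to it.
with graph.inserting_after(x_node):
    size_node = graph.call_function(torch.ops.aten.sym_size.int, args=(x_node, 0))
n_node.replace_all_uses_with(size_node)
# As in the pass above, the now-unused placeholder is kept so the positional
# argument count still matches; a later cleanup can erase it.
graph.lint()
gm.recompile()

out = gm(torch.randn(4, 3), 999)  # the value of n is ignored; the shape comes from x
assert out.shape == (4, 3)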