35 changes: 29 additions & 6 deletions py/torch_tensorrt/_compile.py
@@ -6,7 +6,18 @@
import platform
import warnings
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)

import torch
from torch_tensorrt._enums import dtype
@@ -183,6 +194,7 @@ def compile(
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
kwarg_inputs: Optional[Dict[str, Any]] = None,
enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
dynamic_shapes: Optional[Any] = None,
**kwargs: Any,
) -> (
torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any]
@@ -218,6 +230,14 @@
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
dynamic_shapes (Any): Optional ``dynamic_shapes`` dict (or list / nested
structure) forwarded to ``torch.export.export``. Supply this to share a
``Dim`` across multiple inputs (e.g. when ``input_ids`` and ``attention_mask``
must have the same batch size at runtime). When omitted, dynamic shapes are
auto-inferred from per-input ``min_shape``/``max_shape`` and **each input gets
its own independent symbol** -- which fails ``torch.export``'s constraint
check for models that broadcast across these axes. Only consulted when
``module`` is an ``nn.Module`` (ignored for ``ExportedProgram``).
**kwargs: Additional settings for the specific requested strategy (See submodules for more info)

Returns:
@@ -288,7 +308,7 @@ def _fx_input_interface(
return compiled_fx_module
elif target_ir == _IRType.dynamo:
# Prepare torch and torchtrt inputs
if arg_inputs is None and inputs is None:
if arg_inputs is None and inputs is None and not kwarg_inputs:
raise AssertionError("'arg_inputs', 'inputs', and 'kwarg_inputs' should not all be empty.")

elif arg_inputs is not None and inputs is not None:
@@ -303,8 +323,10 @@

from torch_tensorrt.dynamo.utils import prepare_inputs

if not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs] # type: ignore
if arg_inputs is None:
arg_inputs = []
elif not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs]

torchtrt_arg_inputs = prepare_inputs(arg_inputs)
torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
@@ -316,6 +338,7 @@
module,
torchtrt_arg_inputs,
kwarg_inputs=torchtrt_kwarg_inputs,
dynamic_shapes=dynamic_shapes,
**kwargs,
)
trt_graph_module = dynamo_compile(
@@ -793,8 +816,8 @@ def _all_are_input_objects(obj: Any) -> bool:
f"Inferred dynamic_shapes from torch_tensorrt.Input objects with min/opt/max specifications: {dynamic_shapes}"
)

arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) # type: ignore
kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) # type: ignore
arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device()))
kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device())

else:
# Mixed case: some inputs are Tensors, some are Input objects
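For context, a minimal usage sketch of the ``dynamic_shapes=`` passthrough documented above, assuming a hypothetical two-input encoder whose forward broadcasts ``input_ids`` against ``attention_mask`` (the module and shapes are illustrative, not part of this diff):

    import torch
    import torch_tensorrt as torchtrt
    from torch.export import Dim

    # Hypothetical module: forward(input_ids, attention_mask), both inputs shaped [B, S].
    model = MyEncoder().eval().cuda()

    # One Dim object shared by both inputs, mirroring torch.export.export(dynamic_shapes=...).
    batch = Dim("batch", min=1, max=8)

    trt_mod = torchtrt.compile(
        model,
        ir="dynamo",
        kwarg_inputs={
            "input_ids": torchtrt.Input(
                min_shape=(1, 16), opt_shape=(8, 16), max_shape=(8, 16), dtype=torch.int64
            ),
            "attention_mask": torchtrt.Input(
                min_shape=(1, 16), opt_shape=(8, 16), max_shape=(8, 16), dtype=torch.int64
            ),
        },
        dynamic_shapes={"input_ids": {0: batch}, "attention_mask": {0: batch}},
    )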
11 changes: 7 additions & 4 deletions py/torch_tensorrt/dynamo/_tracer.py
@@ -19,6 +19,7 @@ def trace(
*,
arg_inputs: Optional[Tuple[Any, ...]] = None,
kwarg_inputs: Optional[dict[Any, Any]] = None,
dynamic_shapes: Optional[Any] = None,
**kwargs: Any,
) -> torch.export.ExportedProgram:
"""Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT
@@ -65,17 +66,19 @@
raise AssertionError(
"'arg_inputs' and 'inputs' should not be used at the same time."
)
arg_inputs = inputs or arg_inputs
arg_inputs = inputs if inputs is not None else arg_inputs

if kwarg_inputs is None:
kwarg_inputs = {}

device = to_torch_device(kwargs.get("device", default_device()))
torch_arg_inputs = get_torch_inputs(arg_inputs, device)
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
# Constructing dynamic shape list as a nested dict
dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)
dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs))
if dynamic_shapes is None:
# Auto-inferred dims are independent per input; pass dynamic_shapes
# explicitly to share a Dim across inputs.
dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)
dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs))
exp_program = export(
mod,
tuple(torch_arg_inputs),
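To make the comment in the tracer change concrete, a sketch (input names and Dim names are illustrative) of the two ``dynamic_shapes`` structures that end up being passed to ``torch.export.export``:

    from torch.export import Dim

    # Auto-inferred path: each input gets its own independent symbol. For a forward()
    # that broadcasts the two batch axes together, torch.export cannot prove the
    # symbols are equal and fails its constraint check (ConstraintViolationError).
    auto_inferred = {
        "input_ids": {0: Dim("input_ids_0", min=1, max=4)},
        "attention_mask": {0: Dim("attention_mask_0", min=1, max=4)},
    }

    # Explicit passthrough: the same Dim object is reused, so the constraint solver
    # sees a single shared batch symbol and the export succeeds.
    batch = Dim("batch", min=1, max=4)
    shared = {"input_ids": {0: batch}, "attention_mask": {0: batch}}

    # exported = torch.export.export(model, (), kwargs=example_kwargs, dynamic_shapes=shared)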
250 changes: 250 additions & 0 deletions tests/py/dynamo/models/test_shared_dynamic_dim.py
@@ -0,0 +1,250 @@
# type: ignore
"""
Tests for the ``dynamic_shapes=`` passthrough kwarg on ``torch_tensorrt.compile``.

Background: when a model takes multiple inputs whose dynamic axes must be
**equal at runtime** (e.g. HF encoders with ``input_ids`` / ``attention_mask``
both shaped ``[B, S]``), the legacy auto-inference path in
``dynamo/_tracer.py`` mints an *independent* ``Dim`` per input. ``torch.export``
then fails its constraint check for any forward() that broadcasts across those
axes (here: ``embed(input_ids) * mask.unsqueeze(-1)``), raising
``ConstraintViolationError``.

These tests exercise the new ``dynamic_shapes=`` passthrough that lets the
caller supply a shared ``Dim`` directly to ``torch_tensorrt.compile`` --
mirroring the ``torch.export.export(dynamic_shapes=...)`` signature -- so the
shared-batch case compiles end to end without the caller having to pre-export
the module themselves.
"""
import unittest

import pytest
import torch
import torch.nn as nn
import torch_tensorrt as torchtrt
from torch.export import Dim
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

assertions = unittest.TestCase()


class _SharedBatchEncoder(nn.Module):
"""HF-style encoder stand-in: two int64 inputs sharing the batch axis.

The ``embed(input_ids) * mask.unsqueeze(-1)`` broadcast forces
``input_ids.size(0) == attention_mask.size(0)`` -- the relationship the
auto-inference path cannot express.
"""

def __init__(self, vocab: int = 1024, hidden: int = 32):
super().__init__()
self.embed = nn.Embedding(vocab, hidden)
self.proj = nn.Linear(hidden, hidden)

def forward(self, input_ids, attention_mask):
x = self.embed(input_ids)
mask = attention_mask.unsqueeze(-1).to(x.dtype)
return self.proj(x * mask)


def _kwarg_inputs(seq: int = 16, batch_min: int = 1, batch_max: int = 4):
return {
"input_ids": torchtrt.Input(
min_shape=(batch_min, seq),
opt_shape=(batch_max, seq),
max_shape=(batch_max, seq),
dtype=torch.int64,
name="input_ids",
),
"attention_mask": torchtrt.Input(
min_shape=(batch_min, seq),
opt_shape=(batch_max, seq),
max_shape=(batch_max, seq),
dtype=torch.int64,
name="attention_mask",
),
}


@pytest.mark.unit
@pytest.mark.critical
def test_dynamic_shapes_passthrough_with_shared_batch_dim():
"""With ``dynamic_shapes={..: {0: batch}, ..: {0: batch}}`` (one shared
``Dim``), compile succeeds and the engine matches the eager model."""
model = _SharedBatchEncoder().eval().cuda()

batch = Dim("batch", min=1, max=4)
dynamic_shapes = {
"input_ids": {0: batch},
"attention_mask": {0: batch},
}

trt_mod = torchtrt.compile(
model,
ir="dynamo",
kwarg_inputs=_kwarg_inputs(),
dynamic_shapes=dynamic_shapes,
min_block_size=1,
cache_built_engines=False,
reuse_cached_engines=False,
)

# Sample at the optimization shape and at a smaller batch within the range.
for bs in (4, 2):
ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda")
mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda")

with torch.no_grad():
ref = model(input_ids=ids, attention_mask=mask)
out = trt_mod(input_ids=ids, attention_mask=mask)

cos_sim = cosine_similarity(ref, out)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
f"Shared-batch encoder out-of-tolerance at bs={bs}: cos_sim={cos_sim}",
)


@pytest.mark.unit
def test_dynamic_shapes_passthrough_positional_tuple_form():
"""``torch.export`` also accepts ``dynamic_shapes`` as a tuple matching the
positional-args order. Verify the passthrough handles that form too."""
model = _SharedBatchEncoder().eval().cuda()

batch = Dim("batch", min=1, max=4)
seq = 16
positional_inputs = [
torchtrt.Input(
min_shape=(1, seq),
opt_shape=(4, seq),
max_shape=(4, seq),
dtype=torch.int64,
name="input_ids",
),
torchtrt.Input(
min_shape=(1, seq),
opt_shape=(4, seq),
max_shape=(4, seq),
dtype=torch.int64,
name="attention_mask",
),
]
# Tuple form: one entry per positional arg, in declaration order.
dynamic_shapes = ({0: batch}, {0: batch})

trt_mod = torchtrt.compile(
model,
ir="dynamo",
inputs=positional_inputs,
dynamic_shapes=dynamic_shapes,
min_block_size=1,
cache_built_engines=False,
reuse_cached_engines=False,
)

for bs in (4, 2):
ids = torch.randint(0, 1024, (bs, seq), dtype=torch.int64, device="cuda")
mask = torch.ones((bs, seq), dtype=torch.int64, device="cuda")

with torch.no_grad():
ref = model(ids, mask)
out = trt_mod(ids, mask)

cos_sim = cosine_similarity(ref, out)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
f"Tuple-form dynamic_shapes out-of-tolerance at bs={bs}: cos_sim={cos_sim}",
)


@pytest.mark.unit
def test_dynamic_shapes_passthrough_mixed_args_and_kwargs():
"""One positional input, one kwarg input, sharing a batch ``Dim``. Uses the
unified dict-by-parameter-name form, which spans both positional and keyword
parameters."""
model = _SharedBatchEncoder().eval().cuda()

batch = Dim("batch", min=1, max=4)
seq = 16

# input_ids passed positionally, attention_mask as a kwarg.
positional_inputs = [
torchtrt.Input(
min_shape=(1, seq),
opt_shape=(4, seq),
max_shape=(4, seq),
dtype=torch.int64,
name="input_ids",
),
]
kwarg_inputs = {
"attention_mask": torchtrt.Input(
min_shape=(1, seq),
opt_shape=(4, seq),
max_shape=(4, seq),
dtype=torch.int64,
name="attention_mask",
),
}
dynamic_shapes = {
"input_ids": {0: batch},
"attention_mask": {0: batch},
}

trt_mod = torchtrt.compile(
model,
ir="dynamo",
inputs=positional_inputs,
kwarg_inputs=kwarg_inputs,
dynamic_shapes=dynamic_shapes,
min_block_size=1,
cache_built_engines=False,
reuse_cached_engines=False,
)

for bs in (4, 2):
ids = torch.randint(0, 1024, (bs, seq), dtype=torch.int64, device="cuda")
mask = torch.ones((bs, seq), dtype=torch.int64, device="cuda")

with torch.no_grad():
ref = model(ids, attention_mask=mask)
out = trt_mod(ids, attention_mask=mask)

cos_sim = cosine_similarity(ref, out)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
f"Mixed args/kwargs out-of-tolerance at bs={bs}: cos_sim={cos_sim}",
)


@pytest.mark.unit
def test_dynamic_shapes_default_path_unchanged_for_static_inputs():
"""Sanity check: when ``dynamic_shapes=None`` and inputs are fully static,
behavior is unchanged from the legacy path."""

class StaticModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(8, 8)

def forward(self, x):
return self.linear(x)

model = StaticModel().eval().cuda()
trt_mod = torchtrt.compile(
model,
ir="dynamo",
inputs=[torchtrt.Input(shape=(2, 8), dtype=torch.float32, name="x")],
min_block_size=1,
cache_built_engines=False,
reuse_cached_engines=False,
)
x = torch.randn((2, 8), device="cuda")
with torch.no_grad():
ref = model(x)
out = trt_mod(x)
assertions.assertTrue(cosine_similarity(ref, out) > COSINE_THRESHOLD)


if __name__ == "__main__":
pytest.main([__file__, "-v"])