150 changes: 147 additions & 3 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -5,11 +5,13 @@
import os
import platform
import warnings
-from typing import Any, Collection, List, Optional, Sequence, Union
+from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union

import sympy
import torch
from torch.export import ExportedProgram
from torch.fx.node import Target
from torch.utils._sympy.numbers import int_oo
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt._features import needs_cross_compile
@@ -791,7 +793,7 @@ def _insert_complex_io_adapters(
Outputs: insert view_as_complex before the output node for each originally-complex
output that comes from a TRT block.

-    Leverages metadata that was captued when the complex rewriter pass was run
+    Leverages metadata that was captured when the complex rewriter pass was run
"""
complex_input_names = gm.meta.get("complex_input_names", [])
complex_input_dtypes = gm.meta.get("complex_input_dtypes", {})
@@ -875,6 +877,140 @@ def _insert_complex_io_adapters(
partitioned_module.recompile()


def _build_user_symbol_bounds(
gm: torch.fx.GraphModule,
sample_arg_inputs: Sequence[Input],
sample_kwarg_inputs: dict[Any, Any],
) -> Dict[sympy.Symbol, Tuple[int, int]]:
"""Map ``sympy.Symbol -> (min, max)`` from dynamic ``Input``s, used to
fill ``Dim.DYNAMIC`` upper bounds without mutating ``ShapeEnv``.

Validates against finite exporter bounds: ``user_max > exp_max`` and
``user_min < exp_min`` raise (TRT will reject those shapes at runtime);
a strict subset warns (engine profile widens to the exporter); the
``user_min=1, exp_min=2`` case warns only -- it's PyTorch's 0/1
specialization artifact, not a user error.
"""
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]

sample_by_name: dict[str, Input] = {}
for i, node in enumerate(placeholders):
if i < len(sample_arg_inputs):
inp = sample_arg_inputs[i]
if isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC:
sample_by_name[node.target] = inp
for name, inp in sample_kwarg_inputs.items():
if isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC:
sample_by_name[name] = inp

user_symbol_bounds: Dict[sympy.Symbol, Tuple[int, int]] = {}
if not sample_by_name:
return user_symbol_bounds

for node in placeholders:
if node.target not in sample_by_name:
continue
sample_input = sample_by_name[node.target]
fake_val = node.meta.get("val")
if not isinstance(fake_val, torch.Tensor):
continue

min_shape = sample_input.shape["min_shape"]
max_shape = sample_input.shape["max_shape"]

for d, dim in enumerate(fake_val.size()):
if not isinstance(dim, torch.SymInt) or d >= len(min_shape):
continue
expr = dim.node.expr
# Composite exprs (e.g. ``2*s0``) are recomputed by
# ``ShapeEnv.bound_sympy``; overriding them directly would lie.
if not isinstance(expr, sympy.Symbol):
continue
if expr in user_symbol_bounds:
continue
user_min = int(min_shape[d])
user_max = int(max_shape[d])
user_symbol_bounds[expr] = (user_min, user_max)
logger.debug(
"Recorded user-supplied bounds for %s: [%d, %d]",
expr,
user_min,
user_max,
)

# If exporter bounds are finite, ``extract_var_range_info`` keeps
# them (override is gated on ``max_val is None``). Catch the
# mismatch here so the user doesn't hit a runtime "shape outside
# profile" error on shapes they explicitly declared.
shape_env = getattr(dim.node, "shape_env", None)
if shape_env is None:
continue
exp_range = shape_env.var_to_range.get(expr)
if exp_range is None:
continue
exp_lower = exp_range.lower
exp_upper = exp_range.upper
exp_max_unbounded = exp_upper is int_oo or exp_upper == sympy.oo
if exp_max_unbounded:
# Dim.DYNAMIC: user fills the gap (intended use).
continue
try:
exp_min = int(exp_lower)
exp_max = int(exp_upper)
except (TypeError, ValueError):
continue
if user_min == exp_min and user_max == exp_max:
continue

mismatch = (
f"symbol {expr}: Input({user_min}, {user_max}) vs "
f"exporter({exp_min}, {exp_max})."
)
hint = (
f" Re-export with Dim('{expr}', min={user_min}, "
f"max={user_max}) or adjust Input to match."
)

if user_max > exp_max:
raise ValueError(
f"{mismatch} Input.max_shape exceeds the exporter's max "
f"({user_max} > {exp_max}); TRT will reject shapes above "
f"{exp_max} at runtime.{hint}"
)

if user_min < exp_min:
# 1->2 is the 0/1 specialization artifact, not a user error.
if user_min == 1 and exp_min == 2:
logger.warning(
"%s Input.min_shape=1 vs exporter min=2 is the "
"PyTorch 0/1 specialization artifact; TRT engine "
"min will be 2.",
mismatch,
)
continue
raise ValueError(
f"{mismatch} Input.min_shape is below the exporter's min "
f"({user_min} < {exp_min}); TRT will reject shapes "
f"below {exp_min} at runtime.{hint}"
)

# Strict subset: engine profile widens to the exporter.
logger.warning(
"%s Input bounds are a subset of the exporter's range; "
"TRT engine profile will use the wider [%d, %d]."
" Re-export with Dim('%s', min=%d, max=%d) for a "
"narrower profile.",
mismatch,
exp_min,
exp_max,
expr,
user_min,
user_max,
)

return user_symbol_bounds
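
Distilled, the validation above has four outcomes: matching bounds pass, a user max above the exporter's max raises, a user min below the exporter's min raises unless it is the 1-vs-2 case (which only warns), and a strict subset warns that the profile widens. The standalone sketch below restates those rules; ``check_bounds`` is a hypothetical helper written for illustration (``exp_max=None`` stands in for the exporter's unbounded ``int_oo`` upper), not part of torch_tensorrt.

import logging
from typing import Optional

logger = logging.getLogger(__name__)


def check_bounds(
    user_min: int, user_max: int, exp_min: int, exp_max: Optional[int]
) -> None:
    if exp_max is None:
        # Dim.DYNAMIC: the exporter upper is unbounded; user bounds fill the gap.
        return
    if user_max > exp_max:
        raise ValueError("Input.max_shape exceeds the exporter's max")
    if user_min < exp_min:
        if user_min == 1 and exp_min == 2:
            # PyTorch's 0/1 specialization artifact; the engine min will be 2.
            logger.warning("0/1 specialization artifact; TRT engine min will be 2")
            return
        raise ValueError("Input.min_shape is below the exporter's min")
    if (user_min, user_max) != (exp_min, exp_max):
        # Strict subset: the engine profile widens to the exporter's range.
        logger.warning("Input bounds are a subset; profile widens to the exporter")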


@fn_supports_debugger # type: ignore[misc]
def compile_module(
gm: torch.fx.GraphModule,
@@ -906,6 +1042,12 @@ def compile_module(
if sample_kwarg_inputs is None:
sample_kwarg_inputs = {}

# Forwarded to the partitioner to fill Dim.DYNAMIC upper bounds.
# Read-only w.r.t. ShapeEnv so range_constraints survive save/re-export.
user_symbol_bounds = _build_user_symbol_bounds(
gm, sample_arg_inputs, sample_kwarg_inputs
)

# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(settings)

@@ -1087,7 +1229,9 @@ def preserve_module_specs(
)

# Get the submodule inputs for min, opt, max shapes of the graph inputs
-        submodule_inputs = partitioning.construct_submodule_inputs(submodule)
+        submodule_inputs = partitioning.construct_submodule_inputs(
+            submodule, user_symbol_bounds=user_symbol_bounds
+        )

assert submodule_inputs is not None

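For context, a hypothetical end-to-end use of this path, shown below, makes the division of labor concrete: the exporter leaves the batch dimension's upper bound open via Dim.DYNAMIC, and the torch_tensorrt.Input supplies the [min, max] that _build_user_symbol_bounds records and forwards to the partitioner. MyModel is a stand-in module whose forward takes a single tensor x; everything else uses public torch.export / torch_tensorrt APIs.

import torch
import torch_tensorrt


class MyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


model = MyModel().eval().cuda()
ep = torch.export.export(
    model,
    (torch.randn(8, 3, 224, 224, device="cuda"),),
    # Upper bound intentionally left open; the Input below fills it.
    dynamic_shapes={"x": {0: torch.export.Dim.DYNAMIC}},
)
trt_mod = torch_tensorrt.dynamo.compile(
    ep,
    inputs=[
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(8, 3, 224, 224),
            max_shape=(32, 3, 224, 224),
            dtype=torch.float32,
        )
    ],
)

With the finite-bounds validation above, an Input whose max_shape exceeded an exporter-pinned max would now fail at compile time here rather than as a "shape outside profile" error at runtime.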
38 changes: 32 additions & 6 deletions py/torch_tensorrt/dynamo/partitioning/common.py
@@ -1,10 +1,10 @@
import logging
from typing import Any, Dict, Optional, Sequence, Set, Tuple

import sympy
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily

from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.utils import (
COMPLEX_TO_REAL_DTYPE,
@@ -20,11 +20,14 @@ def construct_dynamic_input(
input_dtype: torch.dtype,
name: str = "",
is_shape_tensor: bool = False,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Input:
"""
Constructs a torch_tensorrt.Input based on a symbolic input
Args:
input_shape: A symbolic shape / regular shape of a tensor (which can have a mix of SymInt nodes and static values)
user_symbol_bounds: Optional ``{sym: (min, max)}`` map; forwarded to
:func:`extract_var_range_info` to fill unbounded exporter uppers.
Returns:
A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input.
"""
@@ -33,7 +36,9 @@
max_shape = []
for d, dim in enumerate(input_shape):
if isinstance(dim, torch.SymInt):
-            min_max_opt = extract_var_range_info(dim)
+            min_max_opt = extract_var_range_info(
+                dim, user_symbol_bounds=user_symbol_bounds
+            )
unwrapped_min_max_opt: Dict[str, int] = {}
if "min" not in min_max_opt or min_max_opt["min"] is None:
logger.warning(
@@ -85,9 +90,12 @@ def get_input(
dtype: torch.dtype,
name: str = "",
is_shape_tensor: bool = False,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Input:
"""
-    Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs
+    Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs.
+
+    ``user_symbol_bounds`` is forwarded to :func:`construct_dynamic_input`.
"""
if dtype in COMPLEX_TO_REAL_DTYPE:
real_dtype = COMPLEX_TO_REAL_DTYPE[dtype]
@@ -106,19 +114,25 @@
dtype,
name=name,
is_shape_tensor=is_shape_tensor,
user_symbol_bounds=user_symbol_bounds,
)
else:
return Input(
shape=input_shape, dtype=dtype, name=name, is_shape_tensor=is_shape_tensor
)


-def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
+def construct_submodule_inputs(
+    module: torch.fx.GraphModule,
+    user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
+) -> Sequence[Input]:
"""
Construct torch_tensorrt Inputs based on the module inputs.
The module inputs will have meta data which has the shape and dtype info
Args:
module: Input FX GraphModule
user_symbol_bounds: Optional ``{sym: (min, max)}`` map; forwarded to
:func:`get_input` to fill unbounded exporter uppers.
Returns:
Sequence of torch_tensorrt.Input's representing inputs to given module
"""
@@ -134,7 +148,12 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
if isinstance(input_meta, (FakeTensor, torch.Tensor)):
input_shape = input_meta.size()
torchtrt_inputs.append(
-                get_input(input_shape, input_meta.dtype, name=input.name)
+                get_input(
+                    input_shape,
+                    input_meta.dtype,
+                    name=input.name,
+                    user_symbol_bounds=user_symbol_bounds,
+                )
)
elif isinstance(input_meta, torch.SymInt):
# Assuming sym_integers | shape inputs always have torch.int64 dtype
@@ -144,6 +163,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
torch.int64,
name=input.name,
is_shape_tensor=True,
user_symbol_bounds=user_symbol_bounds,
)
)
elif isinstance(input_meta, torch.SymFloat):
@@ -153,6 +173,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
torch.float32,
name=input.name,
is_shape_tensor=False, # Only SymInt inputs are treated as shape tensors
user_symbol_bounds=user_symbol_bounds,
)
)
else:
@@ -164,7 +185,12 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
input_meta = input.meta["tensor_meta"]
input_shape = input_meta.shape
torchtrt_inputs.append(
-                get_input(input_shape, input_meta.dtype, name=input.name)
+                get_input(
+                    input_shape,
+                    input_meta.dtype,
+                    name=input.name,
+                    user_symbol_bounds=user_symbol_bounds,
+                )
)
else:
raise AssertionError(
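To close the loop, extract_var_range_info (imported from torch_tensorrt.dynamo.utils) is where the forwarded bounds are consumed; its internals are not part of this diff. The sketch below only illustrates the gating described in the comments above, namely that user bounds apply solely when the exporter's upper bound is unbounded. The function name, the use of bound_sympy for composite expressions, and the min/opt/max packaging are assumptions for illustration, not the library's actual implementation.

from typing import Dict, Optional, Tuple

import sympy
import torch
from torch.utils._sympy.numbers import int_oo


def extract_var_range_info_sketch(
    symbolic_integer: torch.SymInt,
    user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Dict[str, int]:
    node = symbolic_integer.node
    expr = node.expr
    shape_env = node.shape_env
    # bound_sympy recomputes ranges for composite exprs such as 2*s0.
    var_range = shape_env.var_to_range.get(expr) or shape_env.bound_sympy(expr)
    min_val, max_val = var_range.lower, var_range.upper
    # Override is gated on an unbounded exporter upper (Dim.DYNAMIC);
    # finite exporter bounds always win, as validated at build time.
    if (
        user_symbol_bounds is not None
        and expr in user_symbol_bounds
        and (max_val is int_oo or max_val == sympy.oo)
    ):
        min_val, max_val = user_symbol_bounds[expr]
    if max_val is int_oo or max_val == sympy.oo:
        raise ValueError(f"no finite upper bound available for {expr}")
    # Returning max as "opt" is a placeholder choice for this sketch.
    return {"min": int(min_val), "max": int(max_val), "opt": int(max_val)}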