Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,38 +569,49 @@ def get_trt_tensor(


@overload
def get_positive_dim(
    dim: Union[int, torch.SymInt], dim_size: Union[int, torch.SymInt]
) -> Union[int, torch.SymInt]: ...


@overload
def get_positive_dim(
    dim: Sequence[Union[int, torch.SymInt]], dim_size: Union[int, torch.SymInt]
) -> Tuple[Union[int, torch.SymInt], ...]: ...


def get_positive_dim(
    dim: Union[int, torch.SymInt, Sequence[Union[int, torch.SymInt]]],
    dim_size: Union[int, torch.SymInt],
) -> Union[int, torch.SymInt, Tuple[Union[int, torch.SymInt], ...]]:
    """
    Given an integer/SymInt number or tuple that represents dimension(s) in the array,
    transform it to a positive integer dim if it's negative.
    Otherwise, truncate it to the dimension size.

    Args:
        dim (Union[int, torch.SymInt, Sequence[Union[int, torch.SymInt]]]):
            An integer/SymInt or Sequence of integer/SymInt values that represent
            dimension(s) in an array.
        dim_size (Union[int, torch.SymInt]): The size of the dimension in the array.
    Returns:
        A positive integer/SymInt or tuple of integer/SymInt values that represent
        the same dimension as the given dim.
    """

    def positive_dim(d: Union[int, torch.SymInt]) -> Union[int, torch.SymInt]:
        # Symbolic values cannot be branched on with a plain Python `if`;
        # torch.sym_ite keeps the selection traceable for dynamic shapes.
        if isinstance(d, torch.SymInt) or isinstance(dim_size, torch.SymInt):
            return torch.sym_ite(d < 0, d % dim_size, torch.sym_min(d, dim_size))
        if d < 0:
            # Wrap a negative index to its positive equivalent.
            return d % dim_size
        else:
            # Truncate an oversized index to the dimension size.
            return min(d, dim_size)

    return (
        positive_dim(dim)
        if isinstance(dim, (int, torch.SymInt))
        else tuple(positive_dim(d) for d in dim)
    )

Expand Down
39 changes: 23 additions & 16 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch._decomp import register_decomposition
Expand Down Expand Up @@ -190,28 +190,35 @@ def slice_scatter_decomposition(
input_tensor: torch.Tensor,
src_tensor: torch.Tensor,
dim: int,
start: Optional[int] = None,
end: Optional[int] = None,
step: Optional[int] = None,
start: Optional[Union[int, torch.SymInt]] = None,
end: Optional[Union[int, torch.SymInt]] = None,
step: Optional[Union[int, torch.SymInt]] = 1,
) -> torch.Tensor:
dim_size = input_tensor.shape[dim]
device_input_tensor = input_tensor.device

start = 0 if start is None else start # Ensure start is int
start = get_positive_dim(start, input_tensor.shape[dim])
if end is None: # Ensure end is int
end = dim_size
end = (
get_positive_dim(end, input_tensor.shape[dim]) if isinstance(end, int) else end
)
if step is None:
step = 1
if start is None:
start = 0
else:
start = get_positive_dim(start, dim_size)

# step == 0 is not a valid torch case
if start == 0 and end == dim_size and step == 1:
if end is None:
end = dim_size
else:
end = get_positive_dim(end, dim_size)

# step == 0 is not a valid torch case where start, end, dim_size, and step could be symbolic
if (
isinstance(start, int)
and isinstance(end, int)
and isinstance(step, int)
and start == 0
and end == dim_size
and step == 1
):
return src_tensor

# Ensure start, end, and step are all integers
# Ensure start, end, and step are all integers or SymInts
assert isinstance(start, (int, torch.SymInt)), "start must be an int or SymInt"
assert isinstance(end, (int, torch.SymInt)), "end must be an int or SymInt"
assert isinstance(step, (int, torch.SymInt)), "step must be an int or SymInt"
Expand Down
117 changes: 117 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,123 @@ def forward(self, x, src):
trt_model(*inputs), fx_graph(*inputs), rtol=RTOL, atol=ATOL
)

def test_lowering_slice_scatter_dynamic_symint_start_module(self):
    """slice_scatter with a symbolic `start` lowers to scatter.src and matches eager."""

    class SliceScatterDynStart(torch.nn.Module):
        def forward(self, x, src):
            # Under torch.export dynamic-shape tracing, `start` becomes a SymInt.
            start = x.shape[1] - src.shape[1]
            return torch.ops.aten.slice_scatter.default(x, src, 1, start, None, 1)

    dyn = torch.export.Dim("dim1", min=8, max=10)
    example_inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda())
    exported_program = torch.export.export(
        SliceScatterDynStart(),
        tuple(example_inputs),
        dynamic_shapes={
            "x": [torch.export.Dim.STATIC, dyn],
            "src": [torch.export.Dim.STATIC, torch.export.Dim.STATIC],
        },
    )
    fx_graph = exported_program.module()

    # The decomposition must replace aten.slice_scatter with aten.scatter.src.
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        list(example_inputs),
        expected_ops={torch.ops.aten.scatter.src},
        unexpected_ops={torch.ops.aten.slice_scatter.default},
        min_block_size=1,
    )

    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )
    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    compile_inputs = [
        torch_tensorrt.Input(min_shape=[8, 8], opt_shape=[8, 9], max_shape=[8, 10]),
        torch_tensorrt.Input(min_shape=[8, 2], opt_shape=[8, 2], max_shape=[8, 2]),
    ]
    trt_model = torch_tensorrt.dynamo.compile(exported_program, compile_inputs)

    # Run at a shape other than the example to exercise the dynamic dimension.
    check_inputs = (torch.zeros(8, 10).cuda(), torch.ones(8, 2).cuda())
    torch.testing.assert_close(
        trt_model(*check_inputs), fx_graph(*check_inputs), rtol=RTOL, atol=ATOL
    )

def test_lowering_slice_scatter_dynamic_symint_end_module(self):
    """slice_scatter with a symbolic `end` lowers to scatter.src and matches eager."""

    class SliceScatterDynEnd(torch.nn.Module):
        def forward(self, x):
            # Under torch.export dynamic-shape tracing, `end` becomes a SymInt.
            end = x.shape[1] - 1
            src = torch.zeros_like(x[:, 1:-1])
            return torch.ops.aten.slice_scatter.default(x, src, 1, 1, end, 1)

    dyn = torch.export.Dim("dim1", min=8, max=10)
    example_inputs = (torch.zeros(8, 8).cuda(),)
    exported_program = torch.export.export(
        SliceScatterDynEnd(),
        tuple(example_inputs),
        dynamic_shapes={"x": [torch.export.Dim.STATIC, dyn]},
    )
    fx_graph = exported_program.module()

    # The decomposition must replace aten.slice_scatter with aten.scatter.src.
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        list(example_inputs),
        expected_ops={torch.ops.aten.scatter.src},
        unexpected_ops={torch.ops.aten.slice_scatter.default},
        min_block_size=1,
    )

    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )
    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    trt_model = torch_tensorrt.dynamo.compile(
        exported_program,
        [
            torch_tensorrt.Input(
                min_shape=[8, 8], opt_shape=[8, 9], max_shape=[8, 10]
            )
        ],
    )

    # Run at a shape other than the example to exercise the dynamic dimension.
    check_inputs = (torch.rand(8, 10).cuda(),)
    torch.testing.assert_close(
        trt_model(*check_inputs), fx_graph(*check_inputs), rtol=RTOL, atol=ATOL
    )

def test_lowering_select_scatter_dimZero_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
Expand Down
Loading