Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ def embedding(
) -> TRTTensor:
indices_tensor = input
embedding_tensor = weight
if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
raise RuntimeError(
"The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT."
)
# Note: TensorRT's Gather layer supports both int32 and int64 indices
# https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/c-api/classnvinfer1_1_1_i_gather_layer.html
indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor")
embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor")
# unsupported parameters
Expand Down
9 changes: 5 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,8 +494,7 @@ def scatter(
input_shape = input.shape
index_shape = index.shape
index_shape_list = list(index_shape)
if index.dtype == trt.int64:
index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
# Note: TensorRT's Scatter layer supports both int32 and int64 indices
dim = get_positive_dim(dim, len(input_shape))
src_tensor = src
# scatter.value
Expand Down Expand Up @@ -530,7 +529,9 @@ def gather(
) -> TRTTensor:
input_shape = input.shape
dim = get_positive_dim(dim, len(input_shape))
index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
# Note: TensorRT's Gather layer supports both int32 and int64 indices
# https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/c-api/classnvinfer1_1_1_i_gather_layer.html
index = get_trt_tensor(ctx, index, name + "_index_tensor")
gather_layer = ctx.net.add_gather(input, index, axis=dim)
gather_layer.mode = trt.GatherMode.ELEMENT
set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
Expand Down Expand Up @@ -857,7 +858,7 @@ def index_put_converter(
values_expanded,
(-1,),
)
indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
# Note: TensorRT's Scatter layer supports both int32 and int64 indices
if accumulate:
zero_tensor = impl.full.full(
ctx,
Expand Down
21 changes: 21 additions & 0 deletions tests/py/dynamo/conversion/test_embedding_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,27 @@ class TestEmbeddingConverter(DispatchTestCase):
weights_tensor=torch.randn((5, 10), dtype=torch.float32),
sparse=True,
),
# int64 indices - TensorRT now supports int64 for gather operations
param(
test_name="1d_indices_int64",
indices_tensor=torch.tensor([3, 1, 2], dtype=torch.int64),
weights_tensor=torch.randn((5, 10), dtype=torch.float32),
sparse=False,
),
param(
test_name="2d_indices_int64",
indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]], dtype=torch.int64),
weights_tensor=torch.randn((5, 10), dtype=torch.float32),
sparse=True,
),
param(
test_name="3d_indices_int64",
indices_tensor=torch.tensor(
[[[0, 1], [2, 3]], [[3, 4], [4, 0]]], dtype=torch.int64
),
weights_tensor=torch.randn((5, 10), dtype=torch.float32),
sparse=True,
),
]
)
def test_embedding(
Expand Down
109 changes: 109 additions & 0 deletions tests/py/dynamo/conversion/test_gather_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,112 @@ def forward(self, input, index):
input = torch.zeros(3, 5, dtype=torch.int32)
inputs = [input, index]
self.run_test(TestModule(), inputs)


class TestGatherInt64IndexConverter(DispatchTestCase):
"""Test cases for gather with int64 indices.
TensorRT now supports int64 indices for gather operations.
https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/c-api/classnvinfer1_1_1_i_gather_layer.html
"""

@parameterized.expand(
[
(
"gather_zero_dim_indexOne_int64",
0,
torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
),
(
"gather_zero_dim_indexTwo_int64",
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]], dtype=torch.int64),
),
(
"gather_one_dim_indexOne_int64",
1,
torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
),
(
"gather_one_dim_indexTwo_int64",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]], dtype=torch.int64),
),
]
)
def test_gather_index_int64_constant(self, _, dim, index):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.gather.default(input, dim, index)

input = torch.zeros(3, 5, dtype=torch.int32)
inputs = [input]
self.run_test(TestModule(), inputs)

@parameterized.expand(
[
(
"gather_zero_dim_indexOne_int64_input",
0,
torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
),
(
"gather_zero_dim_indexTwo_int64_input",
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]], dtype=torch.int64),
),
(
"gather_one_dim_indexOne_int64_input",
1,
torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
),
(
"gather_one_dim_indexTwo_int64_input",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]], dtype=torch.int64),
),
]
)
def test_gather_index_int64_input(self, _, dim, index):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, index):
return torch.ops.aten.gather.default(input, dim, index)

input = torch.zeros(3, 5, dtype=torch.int32)
inputs = [input, index]
self.run_test(TestModule(), inputs)

@parameterized.expand(
[
(
"gather_float_input_int64_index",
0,
torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
),
(
"gather_float_input_int64_index_dim1",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]], dtype=torch.int64),
),
]
)
def test_gather_float_input_int64_index(self, _, dim, index):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, index):
return torch.ops.aten.gather.default(input, dim, index)

input = torch.randn(3, 5, dtype=torch.float32)
inputs = [input, index]
self.run_test(TestModule(), inputs)


if __name__ == "__main__":
run_tests()
109 changes: 109 additions & 0 deletions tests/py/dynamo/conversion/test_index_select_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,114 @@ def forward(self, source_tensor, indice_tensor):
)


class TestIndexSelectInt64Converter(DispatchTestCase):
"""Test cases for index_select with int64 indices.
TensorRT now supports int64 indices for gather operations.
https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/c-api/classnvinfer1_1_1_i_gather_layer.html
"""

@parameterized.expand(
[
("1d_input_int64", (10,), 0, (1,)),
("2d_input_dim_0_int64", (10, 3), 0, (0, 2)),
("2d_input_dim_1_int64", (5, 10), 1, (1, 2, 3)),
("3d_input_dim_0_int64", (10, 5, 10), 0, (0, 5)),
("3d_input_dim_2_int64", (10, 5, 10), 2, (3, 3, 4)),
("3d_input_dim_-1_int64", (10, 5, 10), -1, (3, 3, 4)),
]
)
def test_index_select_int64(self, _, source_shape, dim, indices_val):
class TestIndexSelect(torch.nn.Module):
def forward(self, source_tensor, indices_tensor):
return torch.ops.aten.index_select.default(
source_tensor, dim, indices_tensor
)

input = [
torch.randn(*source_shape, dtype=torch.float32),
torch.tensor([*indices_val], dtype=torch.int64),
]

self.run_test(
TestIndexSelect(),
input,
)

@parameterized.expand(
[
param(
# 1d_source_tensor_int64_index
source_tensor=torch.randn((3,), dtype=torch.float32),
source_tensor_1=torch.randn((5,), dtype=torch.float32),
dynamic_shapes={
"source_tensor": {0: torch.export.Dim("dyn_dim", min=3, max=6)},
"indice_tensor": {},
},
dim=0,
indice_tensor=torch.tensor(
[
1,
],
dtype=torch.int64,
),
),
param(
# 2d_source_tensor_int64_index
source_tensor=torch.randn((3, 3), dtype=torch.float32),
source_tensor_1=torch.randn((4, 6), dtype=torch.float32),
dynamic_shapes={
"source_tensor": {
0: torch.export.Dim("dyn_dim1", min=3, max=6),
1: torch.export.Dim("dyn_dim2", min=2, max=7),
},
"indice_tensor": {},
},
dim=-1,
indice_tensor=torch.tensor([0, 2], dtype=torch.int64),
),
]
)
def test_index_select_int64_dynamic_shape(
self, source_tensor, source_tensor_1, dynamic_shapes, dim, indice_tensor
):
class IndexSelect(torch.nn.Module):
def forward(self, source_tensor, indice_tensor):
return torch.ops.aten.index_select.default(
source_tensor,
dim,
indice_tensor,
)

inputs = (source_tensor, indice_tensor)
mod = IndexSelect()

fx_mod = torch.export.export(mod, inputs, dynamic_shapes=dynamic_shapes)
trt_mod = torch_tensorrt.dynamo.compile(
fx_mod,
inputs=inputs,
enable_precisions=torch.float32,
min_block_size=1,
cache_built_engines=False,
reuse_cached_engines=False,
)
# use different shape of inputs for inference:
inputs = (source_tensor_1, indice_tensor)
with torch.no_grad():
cuda_inputs = []
for i in inputs:
cuda_inputs.append(i.cuda())
ref_outputs = mod(*cuda_inputs)
outputs = trt_mod(*cuda_inputs)
for out, ref in zip(outputs, ref_outputs):
torch.testing.assert_close(
out,
ref,
rtol=RTOL,
atol=ATOL,
equal_nan=True,
check_dtype=True,
)


if __name__ == "__main__":
run_tests()
Loading
Loading