Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import numpy as np
import torch
from tensorrt import ITensor as TRTTensor
from torch.fx.node import Argument, Node, Target
from torch_tensorrt import ENABLED_FEATURES
from torch_tensorrt._features import needs_not_tensorrt_rtx
Expand All @@ -28,6 +27,8 @@
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM

from tensorrt import ITensor as TRTTensor

_LOGGER: logging.Logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -577,6 +578,16 @@ def index_has_bool_indices(
for ind in index:
if ind is not None:
val = ind.meta.get("val")
if val is None and ind.op == "get_attr":
# fx.symbolic_trace embeds constant tensors as get_attr nodes
# without meta["val"]; fetch the actual tensor from the module.
try:
attr = ind.graph.owning_module
for part in ind.target.split("."):
attr = getattr(attr, part)
val = attr
except AttributeError:
pass
if val is not None and val.dtype == torch.bool:
return True
return False
Expand Down
5 changes: 2 additions & 3 deletions tests/py/dynamo/conversion/test_index_bool_split_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
3. Boolean-indexed `aten.index.Tensor` routes to the converter WITH output allocator.
4. Both paths produce correct results.
"""

import unittest
from unittest.mock import MagicMock

Expand Down Expand Up @@ -60,9 +61,7 @@ def test_none_with_bool_indices_returns_true(self):

def test_mixed_int_and_bool_returns_true(self):
"""If any index is bool, the function should return True."""
node = _make_index_node(
[torch.tensor([0, 1]), torch.tensor([True, False])]
)
node = _make_index_node([torch.tensor([0, 1]), torch.tensor([True, False])])
self.assertTrue(index_has_bool_indices(node))

def test_all_none_returns_false(self):
Expand Down
3 changes: 0 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading