Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 75 additions & 17 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,12 +1254,23 @@ def aten_binary_cross_entropy_with_logits(
raise NotImplementedError()


@torch_op("aten::bincount", trace_only=True)
def aten_bincount(
self: TensorType, weights: Optional[TensorType] = None, minlength: int = 0
self: INT64, weights: Optional[TensorType] = None, minlength: int = 0
) -> TensorType:
"""bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor"""
if weights is not None:
raise NotImplementedError("aten::bincount with weights is not supported.")

raise NotImplementedError()
axis_0 = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
max_val = op.Unsqueeze(op.ReduceMax(self, keepdims=0), axis_0)
depth = op.Add(max_val, one)
if minlength > 0:
depth = op.Max(depth, op.Constant(value_ints=[minlength]))
Comment on lines +1266 to +1270

one_hot = op.OneHot(self, depth, op.Constant(value_ints=[0, 1]), axis=-1)
return op.ReduceSum(one_hot, axis_0, keepdims=0)


def aten_binomial(
Expand Down Expand Up @@ -4976,8 +4987,14 @@ def is_advanced_index(index):
# will invalidate equality-based check.
first_shape = indices[advanced_indices[0]].shape

def same_shape(other_shape: ir.Shape) -> bool:
return (not any(d is None for d in other_shape)) and other_shape == first_shape
def same_shape(other_shape: Optional[ir.Shape]) -> bool:
return (
first_shape is not None
and other_shape is not None
and not any(d is None for d in first_shape)
and not any(d is None for d in other_shape)
and other_shape == first_shape
)

all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices)
if not all_same_shape:
Expand Down Expand Up @@ -5071,24 +5088,65 @@ def same_shape(other_shape: ir.Shape) -> bool:

def _aten_index_put_bool(
self: TReal,
indices: Sequence[BOOL],
indices: Sequence[Optional[Union[INT64, BOOL]]],
values: TReal,
accumulate: bool = False,
) -> TReal:
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""

# TODO: Support indices with more than 1 elements
index = indices[0]
# accumulate should be always False, True does not make sense but an assert would be great
# Reshape indices so it can be properly broadcasted
self_rank = len(self.shape)
index_rank = len(index.shape)
if self_rank > index_rank:
index_shape = op.Shape(index)
padding = op.Constant(value_ints=[1 for _ in range(self_rank - index_rank)])
padded_shape = op.Concat(index_shape, padding, axis=0)
index = op.Reshape(index, padded_shape)
return op.Where(index, values, self)
bool_mask = indices[0]
if len(indices) > 1:
if any(index is None for index in indices):
raise NotImplementedError(
"Boolean index_put with multiple indices does not support None indices."
)

advanced_indices = []
selected_positions = []
minus_one = op.Constant(value_ints=[-1])
for index in indices:
if index.dtype != BOOL.dtype or len(index.shape) != 1:
raise NotImplementedError(
"Boolean index_put with multiple indices supports only 1-D boolean masks."
)
positions = op.Reshape(op.Transpose(op.NonZero(index), perm=[1, 0]), minus_one)
selected_positions.append(positions)
advanced_indices.append(op.Unsqueeze(positions, minus_one))
onnx_index = op.Concat(*advanced_indices, axis=-1)
target_shape = op.Concat(
op.Shape(selected_positions[0]),
op.Slice(op.Shape(self), starts=[len(indices)], ends=[len(self.shape)], axes=[0]),
axis=0,
)
expanded_values = op.Expand(values, target_shape)
return op.ScatterND(
self, onnx_index, expanded_values, reduction="add" if accumulate else None
)

del accumulate # Boolean masks index each position at most once.

if bool_mask is None or bool_mask.dtype != BOOL.dtype:
raise NotImplementedError(
"Boolean index_put expects a boolean mask as the first index."
)

for _ in range(len(self.shape) - len(bool_mask.shape)):
bool_mask = op.Unsqueeze(bool_mask, op.Constant(value_ints=[-1]))

expanded_mask = op.Expand(bool_mask, op.Shape(self))
flat_mask = op.Reshape(expanded_mask, op.Constant(value_ints=[-1]))
flat_mask_int = op.Cast(flat_mask, to=INT64.dtype)
positions = op.Clip(
op.Sub(
op.CumSum(flat_mask_int, op.Constant(value_ints=[0])), op.Constant(value_ints=[1])
),
op.Constant(value_ints=[0]),
)
flat_values = op.Reshape(values, op.Constant(value_ints=[-1]))
gathered_values = op.Gather(flat_values, positions)
flat_self = op.Reshape(self, op.Constant(value_ints=[-1]))
result = op.Where(flat_mask, gathered_values, flat_self)
return op.Reshape(result, op.Shape(self))


def aten_index_reduce(
Expand Down
50 changes: 50 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
_testing.assert_onnx_program(onnx_program)

def test_bincount(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.bincount(x, minlength=6)

onnx_program = torch.onnx.export(
Model(),
(torch.tensor([0, 1, 1, 3, 5], dtype=torch.int64),),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)

Comment on lines +87 to +94
def test_repeat_interleave_integer_1(self):
class Model(torch.nn.Module):
def forward(self, x):
Expand Down Expand Up @@ -902,6 +915,43 @@ def forward(self, x, index, update):
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_bool_mask(self):
class Model(torch.nn.Module):
def forward(self, x, mask, update):
return torch.ops.aten.index_put(x, [mask], update)

x = torch.zeros((2, 3), dtype=torch.float32)
mask = torch.tensor([[True, False, True], [False, True, False]], dtype=torch.bool)
update = torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32)
onnx_program = torch.onnx.export(
Model(),
(x, mask, update),
input_names=["x", "mask", "update"],
output_names=["output"],
opset_version=18,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_bool_multi_mask(self):
class Model(torch.nn.Module):
def forward(self, x, mask0, mask1, update):
return torch.ops.aten.index_put(x, [mask0, mask1], update)

x = torch.zeros((3, 4), dtype=torch.float32)
mask0 = torch.tensor([True, False, True], dtype=torch.bool)
mask1 = torch.tensor([True, False, True, False], dtype=torch.bool)
update = torch.tensor([10.0, 20.0], dtype=torch.float32)
onnx_program = torch.onnx.export(
Model(),
(x, mask0, mask1, update),
input_names=["x", "mask0", "mask1", "update"],
output_names=["output"],
opset_version=18,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_std_mean(self):
"""Test torch.std_mean which will be decomposed into prims.sum."""

Expand Down