Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6120,6 +6120,64 @@ def _translate_torch_args(dim) -> Var:
context.add(res, torch_name=node.name)


@register_torch_op
def tensor_split(context, node):
def _parse_positional_args(context, node) -> Tuple[Var]:
inputs = _get_inputs(context, node, min_expected=2)
nargs = len(inputs)

x = inputs[0]
indices_or_sections = inputs[1]
dim = inputs[2] if nargs > 2 else 0
return x, indices_or_sections, dim

def _parse_keyword_args(context, node, dim) -> Var:
dim = _get_kwinputs(context, node, "dim", default=[dim])[0]
return dim

def _translate_torch_args(dim) -> Var:
if isinstance(dim, Var):
dim = dim.val
return dim

x, indices_or_sections, dim = _parse_positional_args(context, node)
dim = _parse_keyword_args(context, node, dim)
dim = _translate_torch_args(dim)

if indices_or_sections.val is None:
raise NotImplementedError(
"tensor_split is only supported for a constant indices_or_sections."
)

dim_size = x.shape[dim]
if is_symbolic(dim_size):
raise NotImplementedError(
"tensor_split is not supported when the split dimension has a symbolic size."
)

if isinstance(indices_or_sections.val, np.ndarray):
# boundaries: 0, each index (clamped into range), dim_size; sizes are the gaps between them
boundaries = [0]
for index in indices_or_sections.val:
index = int(index)
if index < 0:
index += dim_size
boundaries.append(builtins.min(builtins.max(index, 0), dim_size))
boundaries.append(dim_size)
split_sizes = [boundaries[i + 1] - boundaries[i] for i in range(len(boundaries) - 1)]
if builtins.min(split_sizes) < 0:
raise NotImplementedError(
"tensor_split with non-monotonic indices is not supported."
)
else:
sections = int(indices_or_sections.val)
chunk, remainder = divmod(dim_size, sections)
split_sizes = [chunk + 1] * remainder + [chunk] * (sections - remainder)

res = mb.split(x=x, split_sizes=split_sizes, axis=dim, name=node.name)
context.add(res, torch_name=node.name)


@register_torch_op(torch_alias=["unbind_copy"])
def unbind(context, node):
def _parse_positional_args(context, node) -> Tuple[Var]:
Expand Down
47 changes: 47 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7342,6 +7342,53 @@ def forward(self, x):
)


class TestTensorSplit(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, frontend, sections, dim",
itertools.product(compute_units, backends, frontends, [2, 3], [0, -1]),
)
def test_tensor_split_sections(self, compute_unit, backend, frontend, sections, dim):
class TestModel(torch.nn.Module):
def forward(self, x):
return torch.tensor_split(x, sections, dim=dim)

self.run_compare_torch(
(10, 7), TestModel(), frontend=frontend, backend=backend, compute_unit=compute_unit
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend, indices, dim",
itertools.product(compute_units, backends, frontends, [[2, 5], [1, 3], [-2]], [0, -1]),
)
def test_tensor_split_indices(self, compute_unit, backend, frontend, indices, dim):
class TestModel(torch.nn.Module):
def forward(self, x):
return torch.tensor_split(x, indices, dim=dim)

self.run_compare_torch(
(10, 7), TestModel(), frontend=frontend, backend=backend, compute_unit=compute_unit
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend",
itertools.product(compute_units, backends, frontends),
)
def test_tensor_split_non_monotonic_indices(self, compute_unit, backend, frontend):
if frontend in TORCH_EXPORT_BASED_FRONTENDS:
pytest.skip("torch.export decomposes tensor_split into slice ops")

class TestModel(torch.nn.Module):
def forward(self, x):
return torch.tensor_split(x, [5, 2], dim=0)

with pytest.raises(
NotImplementedError, match="non-monotonic indices is not supported"
):
self.run_compare_torch(
(10, 7), TestModel(), frontend=frontend, backend=backend, compute_unit=compute_unit
)


class TestUnbind(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, frontend, dim",
Expand Down