Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5773,7 +5773,11 @@ def bitwise_and(context, node):
# which torch.export emits for `tensor | tensor` / `tensor ^ tensor`. These are
# common when building boolean attention masks (e.g. Gemma combines a causal
# mask with a padding mask via __or__).
@register_torch_op(torch_alias=["or"])
#
# "ior" is the post-sanitize form of "aten::__ior__", the in-place `|=`.
# `sanitize_op_kind` only strips a trailing `_`, so the leading `i` in the
# `__i<op>__` family is preserved -- the alias has to be listed explicitly.
@register_torch_op(torch_alias=["or", "ior"])
def bitwise_or(context, node):
_bitwise_as_logical_if_boolean(context, node, "bitwise_or", logical_or)

Expand Down
35 changes: 35 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13495,6 +13495,41 @@ def forward(self, x, y):
input_as_shape=False,
)

@pytest.mark.parametrize(
"compute_unit, backend",
itertools.product(compute_units, backends),
)
def test_ior_operator(self, compute_unit, backend):
# Regression test for issue #2584: TorchScript trace of `z |= y`
# records `aten::__ior__`, which sanitizes to "ior" and must be
# registered as an alias of bitwise_or. Previously gemma-3-1b-it
# conversion failed with
# "PyTorch convert function for op '__ior__' not implemented."
#
# Notes on scope:
# * Core ML inputs are immutable, so the model clones an input
# before mutating it; without the clone, `run_compare_torch`
# would raise the unrelated user-input-mutation guard.
# * torch.export decomposes `__ior__` into `clone + bitwise_or`,
# so the new alias is only reachable via the TorchScript path.
class TestModel(torch.nn.Module):
def forward(self, x, y):
z = x.clone()
z |= y
return z

input_shape = (2, 3)
input_data_x = torch.rand(*input_shape) > 0.2
input_data_y = torch.rand(*input_shape) < 0.8
self.run_compare_torch(
[input_data_x, input_data_y],
TestModel(),
frontend=TorchFrontend.TORCHSCRIPT,
backend=backend,
compute_unit=compute_unit,
input_as_shape=False,
)


class TestBitwiseXor(TorchBaseTest):
@pytest.mark.parametrize(
Expand Down