Skip to content

Commit 99f1f0b

Browse files
authored
Add MLX integer support for aten.bitwise_and (pytorch#18979)
### Summary Fixes pytorch#18925 `aten.bitwise_and` currently goes through the bool-only `LogicalAndNode` path in the MLX delegate, which means integer tensors do not lower correctly. This switches `aten.bitwise_and` to a dedicated `BitwiseAndNode` while keeping `aten.logical_and` on the existing logical path. This change: - adds `BitwiseAndNode` to the MLX schema - adds `exec_bitwise_and()` to the MLX interpreter runtime - registers `aten.bitwise_and.Tensor` and `aten.bitwise_and.Scalar` as table-driven MLX binary ops - keeps `aten.logical_and` on the existing bool-only logical path - adds MLX op tests for bool, integer, and scalar `bitwise_and` ### Test plan - `python3 -m py_compile backends/mlx/ops.py backends/mlx/test/test_ops.py` - `python3 backends/mlx/serialization/generate.py` - `PYTHONPATH=src python3 -m executorch.backends.mlx.test.run_all_tests --list | rg 'bitwise_and'` - in a local dev env, this registers: - `bitwise_and_bool` - `bitwise_and_int` - `bitwise_and_scalar` cc @metascroy
1 parent bd5752a commit 99f1f0b

4 files changed

Lines changed: 75 additions & 29 deletions

File tree

backends/mlx/ops.py

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
AsStridedNode,
5151
AsTypeNode,
5252
Atan2Node,
53+
BitwiseAndNode,
5354
BitwiseInvertNode,
5455
BroadcastToNode,
5556
CeilNode,
@@ -481,7 +482,14 @@ def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot:
481482
([torch.ops.aten.minimum.default], MinimumNode, "aten.minimum", False),
482483
([torch.ops.aten.atan2.default], Atan2Node, "aten.atan2", False),
483484
([torch.ops.aten.logaddexp.default], LogAddExpNode, "aten.logaddexp", False),
485+
([torch.ops.aten.logical_and.default], LogicalAndNode, "aten.logical_and", False),
484486
([torch.ops.aten.logical_or.default], LogicalOrNode, "aten.logical_or", False),
487+
(
488+
[torch.ops.aten.bitwise_and.Tensor, torch.ops.aten.bitwise_and.Scalar],
489+
BitwiseAndNode,
490+
"aten.bitwise_and",
491+
True,
492+
),
485493
(
486494
[torch.ops.aten.lt.Tensor, torch.ops.aten.lt.Scalar],
487495
LessNode,
@@ -3143,34 +3151,6 @@ def _bitwise_not_handler(P: MLXProgramBuilder, n: Node) -> Slot:
31433151
return out
31443152

31453153

3146-
@REGISTRY.register(
3147-
target=[torch.ops.aten.logical_and.default, torch.ops.aten.bitwise_and.Tensor]
3148-
)
3149-
def _logical_and_handler(P: MLXProgramBuilder, n: Node) -> Slot:
3150-
"""Handle aten.logical_and / aten.bitwise_and on bool tensors."""
3151-
args = P.args(n)
3152-
require_args(args, 2, 2, "aten.logical_and/bitwise_and")
3153-
require_kwargs(P.kwargs(n), set(), "aten.logical_and/bitwise_and")
3154-
3155-
# bitwise_and is only equivalent to logical_and for bool tensors.
3156-
if n.target == torch.ops.aten.bitwise_and.Tensor:
3157-
dtype = n.args[0].meta.get("val", None)
3158-
if dtype is not None and hasattr(dtype, "dtype") and dtype.dtype != torch.bool:
3159-
raise ValueError(
3160-
f"aten.bitwise_and on non-bool dtype {dtype.dtype} is not supported; "
3161-
"only bool tensors can be lowered via LogicalAndNode"
3162-
)
3163-
out = P.make_or_get_slot(n)
3164-
P.emit(
3165-
LogicalAndNode(
3166-
a=P.slot_to_tid(args[0]),
3167-
b=P.slot_to_tid(args[1]),
3168-
out=P.slot_to_tid(out),
3169-
)
3170-
)
3171-
return out
3172-
3173-
31743154
@REGISTRY.register(target=[torch.ops.aten.scalar_tensor.default])
31753155
def _scalar_tensor_handler(P: MLXProgramBuilder, n: Node) -> Slot:
31763156
"""This is equivalent to torch.full([], scalar, dtype=dtype)."""

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,15 @@ exec_logical_or(const LogicalOrNode& n, ExecutionState& st, StreamOrDevice s) {
14021402
n.out, logical_or(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s));
14031403
}
14041404

1405+
inline void exec_bitwise_and(
1406+
const BitwiseAndNode& n,
1407+
ExecutionState& st,
1408+
StreamOrDevice s) {
1409+
st.set_tensor(
1410+
n.out,
1411+
bitwise_and(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s));
1412+
}
1413+
14051414
inline void exec_tri(const TriNode& n, ExecutionState& st, StreamOrDevice s) {
14061415
int rows = resolve_int(n.n, st);
14071416
int cols = resolve_int(n.m, st);
@@ -2052,6 +2061,9 @@ class Interpreter {
20522061
case OpCode::LOGICAL_OR:
20532062
ops::exec_logical_or(std::get<LogicalOrNode>(instr.node), st, s);
20542063
break;
2064+
case OpCode::BITWISE_AND:
2065+
ops::exec_bitwise_and(std::get<BitwiseAndNode>(instr.node), st, s);
2066+
break;
20552067
case OpCode::TRI:
20562068
ops::exec_tri(std::get<TriNode>(instr.node), st, s);
20572069
break;

backends/mlx/serialization/schema.fbs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,12 @@ table LogicalOrNode {
579579
out: Tid (required);
580580
}
581581

582+
table BitwiseAndNode {
583+
a: Tid (required);
584+
b: Tid (required);
585+
out: Tid (required);
586+
}
587+
582588
// Triangular matrix ops
583589
table TriNode {
584590
out: Tid (required);
@@ -1130,7 +1136,8 @@ union OpNode {
11301136
ScanNode,
11311137
MetalKernelNode,
11321138
BitwiseInvertNode,
1133-
RollNode
1139+
RollNode,
1140+
BitwiseAndNode
11341141
// BC: Add new op nodes here (append only)
11351142
}
11361143

backends/mlx/test/test_ops.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4274,6 +4274,8 @@ def create_model(self) -> nn.Module:
42744274
{"op_name": "equal", "op_fn": torch.eq, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]},
42754275
{"op_name": "not_equal", "op_fn": torch.ne, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]},
42764276
# logical
4277+
{"op_name": "bitwise_and_bool", "op_fn": torch.bitwise_and, "shapes": _SHAPES_3, "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()},
4278+
{"op_name": "bitwise_and_int", "op_fn": torch.bitwise_and, "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], "input_fn_a": _int_input_fn(0, 256), "input_fn_b": _int_input_fn(0, 256)},
42774279
{"op_name": "logical_and", "op_fn": torch.logical_and, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()},
42784280
{"op_name": "logical_or", "op_fn": torch.logical_or, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()},
42794281
]
@@ -4286,6 +4288,51 @@ def create_model(self) -> nn.Module:
42864288
globals()[_cls.__name__] = _cls
42874289

42884290

4291+
class BitwiseAndScalarModel(nn.Module):
4292+
def __init__(self, scalar):
4293+
super().__init__()
4294+
self.scalar = scalar
4295+
4296+
def forward(self, a: torch.Tensor) -> torch.Tensor:
4297+
return torch.bitwise_and(a, self.scalar)
4298+
4299+
4300+
@register_test
4301+
class BitwiseAndScalarTest(OpTestCase):
4302+
"""Test case for aten.bitwise_and op (Tensor_Scalar variant)."""
4303+
4304+
name = "bitwise_and_scalar"
4305+
4306+
def __init__(
4307+
self,
4308+
shape: Tuple[int, ...],
4309+
dtype: torch.dtype,
4310+
scalar,
4311+
):
4312+
self.shape = shape
4313+
self.dtype = dtype
4314+
self.scalar = scalar
4315+
shape_str = "x".join(str(s) for s in shape)
4316+
dtype_str = str(dtype).replace("torch.", "")
4317+
self.name = f"bitwise_and_scalar_{shape_str}_{dtype_str}"
4318+
4319+
@classmethod
4320+
def get_test_configs(cls) -> List["BitwiseAndScalarTest"]:
4321+
return [
4322+
cls(shape=(16,), dtype=torch.bool, scalar=True),
4323+
cls(shape=(4, 4), dtype=torch.int32, scalar=7),
4324+
cls(shape=(2, 3, 4), dtype=torch.int64, scalar=13),
4325+
]
4326+
4327+
def create_inputs(self) -> Tuple[torch.Tensor, ...]:
4328+
if self.dtype == torch.bool:
4329+
return _bool_input_fn()(self.shape, self.dtype)
4330+
return _int_input_fn(0, 256)(self.shape, self.dtype)
4331+
4332+
def create_model(self) -> nn.Module:
4333+
return BitwiseAndScalarModel(self.scalar)
4334+
4335+
42894336
@register_test
42904337
class PowerScalarTest(OpTestCase):
42914338
"""Test case for aten.pow op (Tensor_Scalar variant)."""

0 commit comments

Comments
 (0)