Skip to content

Commit 59838fc

Browse files
authored
Arm backend: Add limited boolean mask support to index_put (pytorch#18396)
index_put with a single boolean mask and scalar values can be normalized to a where operator. This is also an important case that has shown up in models, i.e. x[mask] = 0. Since the mask is not guaranteed to be constant, require the value to be scalar, that can be broadcasted for any mask. With multiple boolean masks or mixed integer, shapes and ranks become data dependent. Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent 0fc4d6d commit 59838fc

7 files changed

Lines changed: 286 additions & 2 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@
126126
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
127127
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
128128
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
129+
from .normalize_index_put_bool_index_tensor_pass import ( # noqa
130+
NormalizeIndexPutBoolIndexTensorPass,
131+
)
129132
from .normalize_index_put_none_indices_pass import ( # noqa
130133
NormalizeIndexPutNoneIndicesPass,
131134
)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113
InsertTableOpsPass,
114114
MatchArgDtypePass,
115115
MatchArgRanksPass,
116+
NormalizeIndexPutBoolIndexTensorPass,
116117
NormalizeIndexPutNoneIndicesPass,
117118
NormalizeWhileInitialArgsPass,
118119
PromoteBoolOperandsPass,
@@ -450,6 +451,7 @@ def _tosa_pipeline(
450451
self.add_passes(
451452
[
452453
NormalizeIndexPutNoneIndicesPass(),
454+
NormalizeIndexPutBoolIndexTensorPass(),
453455
RewriteIndexPutPass(),
454456
RewriteBoolBitwiseToLogicalPass(),
455457
DecomposeRemainderPass(),
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import Set, Type
6+
7+
import torch
8+
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass
13+
14+
15+
class NormalizeIndexPutBoolIndexTensorPass(ArmPass):
16+
"""Normalize single boolean mask index_put scalar to where.
17+
In the general case, boolean masks are complex and data dependent. The simple case
18+
x[mask] = scalar
19+
Can however be directly translated to a where operation:
20+
21+
out = index_put(destination, [mask], data, accumulate=False)
22+
becomes
23+
mask = reshape(mask, mask_shape_padded)
24+
data = reshape(data, data_shape_padded)
25+
out = where(mask, data, destination)
26+
27+
Where the padded shapes are right-padded with ones to match the rank of destination (if needed).
28+
`data` must be a scalar, to ensure data_padded can be broadcasted to any destination shape
29+
depending on the (non-constant) mask.
30+
"""
31+
32+
_passes_required_after: Set[Type[ExportPass]] = {RewriteIndexPutPass}
33+
34+
def __init__(self):
35+
super().__init__()
36+
self.reshape_op = exir_ops.edge.aten.view_copy.default
37+
self.where_op = exir_ops.edge.aten.where.self
38+
39+
def _is_valid_bool_mask(
40+
self,
41+
indices_tensor_list,
42+
data,
43+
accumulate: bool,
44+
) -> bool:
45+
46+
indices = indices_tensor_list[0]
47+
if indices is None or indices.data.dtype != torch.bool:
48+
return False
49+
50+
# We have a boolean mask, validate that the args are supported.
51+
if accumulate or len(indices_tensor_list) != 1 or data.data.numel() != 1:
52+
raise RuntimeError(
53+
f"Got unsupported args for bool mask index_put: {accumulate=}, num indices={len(indices_tensor_list)}!=1, data shape {data.data.shape} not scalar.\n"
54+
"This is a bug, the operator should not have been delegated."
55+
)
56+
57+
return True
58+
59+
def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
60+
if op not in (exir_ops.edge.aten.index_put.default,):
61+
return super().call_operator(op, args, kwargs, meta, updated)
62+
63+
destination, indices_tensor_list, data = args[:3]
64+
accumulate = len(args) > 3 and bool(args[3])
65+
indices_tensor_list = list(indices_tensor_list)
66+
if not self._is_valid_bool_mask(indices_tensor_list, data, accumulate):
67+
return super().call_operator(op, args, kwargs, meta, updated)
68+
69+
mask = indices_tensor_list[0]
70+
destination_shape = tuple(destination.data.shape)
71+
mask_shape = tuple(mask.data.shape)
72+
padded_mask_shape = (
73+
*mask_shape,
74+
*([1] * (len(destination_shape) - len(mask_shape))),
75+
)
76+
77+
if len(mask_shape) < len(destination_shape):
78+
mask = super().call_operator(
79+
self.reshape_op,
80+
(mask, padded_mask_shape),
81+
{},
82+
meta,
83+
True,
84+
)
85+
86+
if len(destination_shape) != len(data.data.shape):
87+
data = super().call_operator(
88+
self.reshape_op,
89+
(data, [1] * len(destination_shape)),
90+
{},
91+
meta,
92+
True,
93+
)
94+
95+
return super().call_operator(
96+
self.where_op,
97+
(mask, data, destination),
98+
kwargs,
99+
meta,
100+
True,
101+
)

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
embedding_support,
1414
ethos_u55_support,
1515
gather_support,
16+
index_put_support,
1617
index_select_support,
1718
index_tensor_support,
1819
minmax_support,
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
"""Declare operator support for ``aten.index_put``."""
6+
7+
from typing import cast
8+
9+
import torch
10+
import torch.fx as fx
11+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
12+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
13+
register_tosa_support_check,
14+
SupportedTOSAOperatorCheck,
15+
)
16+
17+
from executorch.backends.arm.tosa import TosaSpecification
18+
from executorch.exir.dialects._ops import ops as exir_ops
19+
20+
21+
@register_tosa_support_check
22+
class IndexPutSupported(SupportedTOSAOperatorCheck):
23+
"""Reject unsupported ``index_put`` cases.
24+
25+
Explicit integer indices are fully supported.
26+
27+
For boolean mask, there are limitations:
28+
- boolean index cases only supports one bool index
29+
- boolean index cases must use a scalar ``values`` tensor
30+
- boolean index cases don't support accumulate = True.
31+
32+
"""
33+
34+
targets = [exir_ops.edge.aten.index_put.default]
35+
36+
def is_node_tosa_supported(
37+
self, node: fx.Node, tosa_spec: TosaSpecification
38+
) -> bool:
39+
indices_tensors = cast(list[fx.Node], node.args[1])
40+
41+
# None indexes mean "select whole dim", we can handle that.
42+
explicit_indices = [index for index in indices_tensors if index is not None]
43+
has_bool_index = any(
44+
get_first_fake_tensor(index).dtype == torch.bool
45+
for index in explicit_indices
46+
)
47+
has_non_bool_index = any(
48+
get_first_fake_tensor(index).dtype != torch.bool
49+
for index in explicit_indices
50+
)
51+
52+
if has_bool_index and has_non_bool_index:
53+
self.reporter.report_reject(
54+
node,
55+
(
56+
"Mixed boolean mask and integer indices in "
57+
"index_put are not supported."
58+
),
59+
)
60+
return False
61+
62+
if has_bool_index and len(explicit_indices) != 1:
63+
self.reporter.report_reject(
64+
node,
65+
"Boolean mask index_put only supports a single explicit bool index.",
66+
)
67+
return False
68+
69+
if has_bool_index:
70+
values = cast(fx.Node, node.args[2])
71+
values_tensor = get_first_fake_tensor(values)
72+
if values_tensor.numel() != 1:
73+
self.reporter.report_reject(
74+
node,
75+
"Boolean mask index_put only supports scalar values.",
76+
)
77+
return False
78+
79+
if len(node.args) > 3 and node.args[3]:
80+
self.reporter.report_reject(
81+
node,
82+
"Bool-mask index_put not supported with accumulate = True.",
83+
)
84+
return False
85+
86+
return True

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@
124124
exir_ops.edge.aten.bitwise_not.default,
125125
exir_ops.edge.aten.copy.default,
126126
exir_ops.edge.aten.tan.default,
127-
exir_ops.edge.aten.index_put.default,
128127
exir_ops.edge.aten.silu.default,
129128
exir_ops.edge.aten.detach_copy.default,
130129
}
@@ -249,7 +248,6 @@
249248
exir_ops.edge.aten.copy.default,
250249
exir_ops.edge.aten.floor_divide.default,
251250
exir_ops.edge.aten.tan.default,
252-
exir_ops.edge.aten.index_put.default,
253251
exir_ops.edge.aten.detach_copy.default,
254252
}
255253

backends/arm/test/ops/test_index_put.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,18 @@
156156
),
157157
0,
158158
),
159+
"bool_mask_scalar": (
160+
lambda: (
161+
torch.randn((2, 3, 4), dtype=torch.float32),
162+
(
163+
torch.arange(3).expand(2, 3)
164+
>= torch.tensor([3, 2], dtype=torch.int64)[:, None],
165+
),
166+
torch.tensor(0.0, dtype=torch.float32),
167+
False,
168+
),
169+
0,
170+
),
159171
"none_indices": (
160172
lambda: (
161173
torch.ones((5, 3, 2, 2), dtype=torch.float32),
@@ -210,6 +222,62 @@
210222
),
211223
0,
212224
),
225+
"none_and_bool_indices_scalar": (
226+
lambda: (
227+
torch.randn((2, 3, 4), dtype=torch.float32),
228+
(None, torch.tensor([True, False, True]), None),
229+
torch.tensor(0.0, dtype=torch.float32),
230+
False,
231+
),
232+
0,
233+
),
234+
}
235+
mixed_indices_not_supported = {
236+
"bool_and_tensor_indices_scalar": (
237+
lambda: (
238+
torch.randn((2, 3, 4), dtype=torch.float32),
239+
(
240+
torch.tensor([True, False]),
241+
torch.tensor([1, 2], dtype=torch.int64),
242+
),
243+
torch.tensor(0.0, dtype=torch.float32),
244+
False,
245+
),
246+
0,
247+
),
248+
"bool_mask_tensor": (
249+
lambda: (
250+
torch.randn((2, 3, 4), dtype=torch.float32),
251+
(torch.tensor([True, False]),),
252+
torch.rand((1, 3, 4), dtype=torch.float32),
253+
False,
254+
),
255+
0,
256+
),
257+
"two_bool_mask_scalar": (
258+
lambda: (
259+
torch.randn((2, 3, 4), dtype=torch.float32),
260+
(
261+
torch.tensor([False, True]),
262+
torch.tensor([True, False, False]),
263+
),
264+
torch.tensor(0.0, dtype=torch.float32),
265+
False,
266+
),
267+
0,
268+
),
269+
"two_bool_mask_tensor": (
270+
lambda: (
271+
torch.randn((2, 3, 4), dtype=torch.float32),
272+
(
273+
torch.tensor([False, True]),
274+
torch.tensor([True, False, False]),
275+
),
276+
torch.rand((1, 4), dtype=torch.float32),
277+
False,
278+
),
279+
0,
280+
),
213281
}
214282
test_data_int = {
215283
"rank3_zeros_int8": (
@@ -385,3 +453,28 @@ def test_index_put_vgf_quant(test_module: input_t):
385453
exir_op=IndexPut.exir_op,
386454
)
387455
pipeline.run()
456+
457+
458+
@common.parametrize("test_module", mixed_indices_not_supported)
459+
def test_index_put_tosa_FP_not_delegated(test_module: input_t):
460+
pipeline = OpNotSupportedPipeline[input_t](
461+
IndexPut(),
462+
test_module[0](),
463+
{IndexPut.exir_op: 1},
464+
quantize=False,
465+
u55_subset=False,
466+
n_expected_delegates=0,
467+
)
468+
pipeline.run()
469+
470+
471+
@common.parametrize("test_module", mixed_indices_not_supported)
472+
def test_index_put_tosa_INT_not_delegated(test_module: input_t):
473+
pipeline = OpNotSupportedPipeline[input_t](
474+
IndexPut(),
475+
test_module[0](),
476+
{IndexPut.exir_op: 1},
477+
quantize=True,
478+
n_expected_delegates=0,
479+
)
480+
pipeline.run()

0 commit comments

Comments
 (0)