Skip to content

Commit afd32cc

Browse files
Add a16w8 per-op test for bmm (pytorch#19599)
Summary: Add int16 activation / int8 weight (a16w8) quantization tests for `aten.bmm` on Ethos-U55 and Ethos-U85.

## Context

Batch matrix multiply (`bmm`) implements the core `Q @ K^T` and `attn_weights @ V` operations in the multi-head attention of the EMG2Pose Conformer. At int16 IO precision the accumulator width and rescale path differ between U55 and U85, so dedicated per-op coverage is needed to catch numerics divergence before it surfaces as an end-to-end SNR regression. The test matrix includes square, rectangular, and large-batch configurations to exercise different tiling strategies in the Vela backend.

Also removes unused `aten_op_mm` / `exir_op_mm` variables that were dead code in `test_bmm.py`.

## Changes

- Add `a16w8_bmm_test_parameters` dict with 5 test configurations covering same-shape, different-shape, rectangular, batch-10, and negative-value tensors
- Add `test_bmm_a16w8_u55_INT` using `EthosU55PipelineINT` with `a16w8_quantization=True, symmetric_io_quantization=True, qtol=128, epsilon=2**-16`
- Add `test_bmm_a16w8_u85_INT` using `EthosU85PipelineINT` with same kwargs
- Remove unused `aten_op_mm` and `exir_op_mm` variables
- Register `ops/test_bmm.py` in `fbcode/` and `xplat/` `targets.bzl`

bypass-pytorch-oss-checks

Reviewed By: Ninja91

Differential Revision: D104532363
1 parent 41a38d8 commit afd32cc

2 files changed

Lines changed: 53 additions & 5 deletions

File tree

backends/arm/test/ops/test_bmm.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -10,11 +10,12 @@
1010

1111
import torch
1212

13+
from executorch.backends.arm.quantizer import get_symmetric_a16w8_quantization_config
1314
from executorch.backends.arm.test import common
14-
1515
from executorch.backends.arm.test.tester.test_pipeline import (
1616
EthosU55PipelineINT,
1717
EthosU85PipelineINT,
18+
OpNotSupportedPipeline,
1819
TosaPipelineFP,
1920
TosaPipelineINT,
2021
VgfPipeline,
@@ -23,9 +24,6 @@
2324
aten_op_bmm = "torch.ops.aten.bmm.default"
2425
exir_op_bmm = "executorch_exir_dialects_edge__ops_aten_bmm_default"
2526

26-
aten_op_mm = "torch.ops.aten.matmul.default"
27-
exir_op_mm = "executorch_exir_dialects_edge__ops_aten_matmul_default"
28-
2927
input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x
3028

3129

@@ -191,3 +189,52 @@ def test_bmm_vgf_quant_single_input(test_data: input_t1):
191189
quantize=True,
192190
)
193191
pipeline.run()
192+
193+
194+
# Input factories for the a16w8 (int16 activation / int8 weight) bmm tests.
# Each value is a zero-argument callable that builds a fresh (x, y) pair of
# 3-D tensors with bmm-compatible shapes, i.e. (B, N, M) and (B, M, P).
# Factories are used instead of ready-made tensors so that the random data is
# only generated when a test case actually calls it.
a16w8_bmm_test_parameters = {
    "rand_same": lambda: (torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
    "rand_diff": lambda: (torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
    "rand_rect": lambda: (torch.rand(1, 55, 3), torch.rand(1, 3, 44)),
    "rand_batch10": lambda: (torch.rand(10, 1, 10), torch.rand(10, 10, 5)),
    # Normally distributed inputs: x is scaled by -10, y is scaled by 5 and
    # shifted by +5, so both operands contain negative as well as positive
    # values (unlike the torch.rand cases above, which are all in [0, 1)).
    "rand_neg": lambda: (
        -10 * torch.randn(2, 32, 64),
        5 + 5 * torch.randn(2, 64, 32),
    ),
}
204+
205+
206+
@common.parametrize("test_data", a16w8_bmm_test_parameters)
@common.XfailIfNoCorstone300
def test_bmm_a16w8_u55_INT(test_data: input_t1):
    """U55 does not support bmm with INT16 inputs.

    Verify bmm is rejected: the pipeline is told to expect exactly one
    non-delegated bmm edge op and zero delegates.

    Args:
        test_data: Zero-argument factory from ``a16w8_bmm_test_parameters``;
            it is called here to build the (x, y) input pair.

    """
    # NOTE(review): XfailIfNoCorstone300 is kept as in the original; confirm
    # whether a not-supported check needs the Corstone-300 FVP at all.
    pipeline = OpNotSupportedPipeline[input_t1](
        BMM(),
        test_data(),
        non_delegated_ops={exir_op_bmm: 1},
        n_expected_delegates=0,
        u55_subset=True,
        quantize=True,
        tosa_extensions=["int16"],
    )
    # Replace the quantizer's global config with the symmetric a16w8 config
    # before running the pipeline.
    pipeline.quantizer.set_global(get_symmetric_a16w8_quantization_config())
    pipeline.run()
225+
226+
227+
@common.parametrize("test_data", a16w8_bmm_test_parameters)
@common.XfailIfNoCorstone320
def test_bmm_a16w8_u85_INT(test_data: input_t1):
    """Run bmm through the Ethos-U85 INT pipeline with a16w8 quantization.

    The pipeline is built with ``a16w8_quantization`` and
    ``symmetric_io_quantization`` enabled and checks for the aten and edge
    bmm ops named by ``aten_op_bmm`` and ``exir_op_bmm``.

    Args:
        test_data: Zero-argument factory from ``a16w8_bmm_test_parameters``;
            it is called here to build the (x, y) input pair.

    """
    pipeline = EthosU85PipelineINT[input_t1](
        BMM(),
        test_data(),
        aten_op_bmm,
        exir_op_bmm,
        a16w8_quantization=True,
        symmetric_io_quantization=True,
        # NOTE(review): qtol and epsilon are presumably output-comparison
        # tolerances -- confirm against EthosU85PipelineINT before tuning.
        qtol=1,
        epsilon=2**-16,
    )
    pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def define_arm_tests():
4242
"ops/test_var.py",
4343
"ops/test_conv1d.py",
4444
"ops/test_gelu.py",
45+
"ops/test_bmm.py",
4546
]
4647

4748
# Quantization

0 commit comments

Comments
 (0)