Skip to content

Commit 92b7411

Browse files
Add a16w8 per-op test for split (pytorch#19600)
Summary: Add int16 activation / int8 weight (a16w8) quantization tests for `aten.split` on Ethos-U55 and Ethos-U85. ## Changes - Add `a16w8_split_test_parameters` dict with 3 test configurations covering 1D, 2D, and 3D splits along different axes - Add `test_split_a16w8_u55_INT` using `EthosU55PipelineINT` with `a16w8_quantization=True, symmetric_io_quantization=True` - Add `test_split_a16w8_u85_INT` using `EthosU85PipelineINT` with same kwargs - Register `ops/test_split.py` in `fbcode/` and `xplat/` `targets.bzl` Differential Revision: D104533281
1 parent afd32cc commit 92b7411

3 files changed

Lines changed: 41 additions & 0 deletions

File tree

backends/arm/test/ops/test_split.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,38 @@ def test_split_tensor_vgf_quant(test_data: Tuple):
310310
quantize=True,
311311
)
312312
pipeline.run()
313+
314+
315+
a16w8_split_test_parameters = {
316+
"a16w8_1d_split_2": lambda: (torch.rand(10), 2, 0),
317+
"a16w8_2d_split_4": lambda: (torch.rand(8, 4), 4, 0),
318+
"a16w8_3d_split_4": lambda: (torch.rand(4, 4, 8), 4, 2),
319+
}
320+
321+
322+
@common.parametrize("test_data", a16w8_split_test_parameters)
323+
@common.XfailIfNoCorstone300
324+
def test_split_a16w8_u55_INT(test_data: input_t1):
325+
pipeline = EthosU55PipelineINT[input_t1](
326+
Split(),
327+
test_data(),
328+
aten_ops=[],
329+
exir_ops=exir_op,
330+
a16w8_quantization=True,
331+
symmetric_io_quantization=True,
332+
)
333+
pipeline.run()
334+
335+
336+
@common.parametrize("test_data", a16w8_split_test_parameters)
337+
@common.XfailIfNoCorstone320
338+
def test_split_a16w8_u85_INT(test_data: input_t1):
339+
pipeline = EthosU85PipelineINT[input_t1](
340+
Split(),
341+
test_data(),
342+
aten_ops=[],
343+
exir_ops=exir_op,
344+
a16w8_quantization=True,
345+
symmetric_io_quantization=True,
346+
)
347+
pipeline.run()

backends/arm/test/ops/test_var.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ class Var(torch.nn.Module):
3232
),
3333
}
3434

35+
test_parameters_ethosu = {
36+
"var_4d_keep_dim_0_correction": lambda: (torch.randn(1, 50, 10, 20), True, 0),
37+
"var_4d_keep_dim_1_correction": lambda: (torch.randn(1, 30, 15, 20), True, 1),
38+
}
39+
3540
def __init__(self, keepdim: bool = True, correction: int = 0):
3641
super().__init__()
3742
self.keepdim = keepdim

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def define_arm_tests():
4343
"ops/test_conv1d.py",
4444
"ops/test_gelu.py",
4545
"ops/test_bmm.py",
46+
"ops/test_split.py",
4647
]
4748

4849
# Quantization

0 commit comments

Comments
 (0)