Skip to content

Commit 0cb9f38

Browse files
rascani and claude authored
Cortex-M: Fix pad op to support channels_last memory format (pytorch#18429)
### Summary

Fix `pad_meta` to propagate channels_last from input to output tensor. Fix `pad_out` (C++) to use `dim_order()` to permute logical dims and padding into physical memory order for `arm_pad_s8`. Add channels_last test cases to `test_pad`.

### Test Plan

```
pytest backends/cortex_m/test/ops/test_pad.py
```

---------

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 7c79395 commit 0cb9f38

4 files changed

Lines changed: 53 additions & 6 deletions

File tree

backends/cortex_m/ops/op_pad.cpp

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -48,14 +48,19 @@ Tensor& pad_out(
4848
return out;
4949
}
5050

51+
// Permute logical sizes to physical memory order.
52+
// Padding is already in physical order from the AOT pass.
53+
constexpr size_t kNhwcDimOrder[] = {0, 2, 3, 1};
5154
const size_t offset = kMaxSupportedDims - rank;
55+
const bool nhwc = is_channels_last_tensor(input);
5256

53-
cmsis_nn_dims input_dims = {1, 1, 1, 1};
54-
int32_t* d = &input_dims.n;
57+
int32_t dims[kMaxSupportedDims] = {1, 1, 1, 1};
5558
for (size_t i = 0; i < rank; ++i) {
56-
d[offset + i] = static_cast<int32_t>(input.size(i));
59+
const size_t src = nhwc ? kNhwcDimOrder[offset + i] : i;
60+
dims[offset + i] = static_cast<int32_t>(input.size(src));
5761
}
5862

63+
cmsis_nn_dims input_dims = {dims[0], dims[1], dims[2], dims[3]};
5964
cmsis_nn_dims cmsis_pre_pad = {
6065
static_cast<int32_t>(pre_pad[0]),
6166
static_cast<int32_t>(pre_pad[1]),

backends/cortex_m/ops/operators.py

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from executorch.backends.cortex_m.passes.passes_utils import (
1515
dequantize_per_tensor_cmsis,
1616
is_channel_broadcast,
17+
is_channels_last,
1718
quantize_per_tensor_cmsis,
1819
requantize_cmsis,
1920
SHIFT_INT8,
@@ -564,6 +565,16 @@ def transpose_impl(input: torch.Tensor, perm: Sequence[int]) -> torch.Tensor:
564565
)
565566

566567

568+
_NHWC_INV_ORDER = [0, 3, 1, 2]
569+
570+
571+
def _pad_to_logical_order(physical_pad: list[int], input: torch.Tensor) -> list[int]:
572+
"""Inverse of _to_physical_order: map physical-order padding back to logical."""
573+
if not is_channels_last(input):
574+
return list(physical_pad)
575+
return [physical_pad[_NHWC_INV_ORDER[i]] for i in range(4)]
576+
577+
567578
@register_fake("cortex_m::pad") # type: ignore[misc]
568579
def pad_meta(
569580
input: torch.Tensor,
@@ -573,10 +584,16 @@ def pad_meta(
573584
) -> torch.Tensor:
574585
rank = input.dim()
575586
offset = 4 - rank
587+
logical_pre = _pad_to_logical_order(pre_pad, input)
588+
logical_post = _pad_to_logical_order(post_pad, input)
589+
576590
output_shape = list(input.shape)
577591
for i in range(rank):
578-
output_shape[i] += pre_pad[offset + i] + post_pad[offset + i]
579-
return torch.empty(output_shape, dtype=input.dtype, device=input.device)
592+
output_shape[i] += logical_pre[offset + i] + logical_post[offset + i]
593+
result = torch.empty(output_shape, dtype=input.dtype, device=input.device)
594+
if is_channels_last(input):
595+
result = result.to(memory_format=torch.channels_last)
596+
return result
580597

581598

582599
@impl(lib, "pad", "CompositeExplicitAutograd") # type: ignore[misc]
@@ -588,9 +605,12 @@ def pad_impl(
588605
) -> torch.Tensor:
589606
rank = input.dim()
590607
offset = 4 - rank
608+
logical_pre = _pad_to_logical_order(pre_pad, input)
609+
logical_post = _pad_to_logical_order(post_pad, input)
610+
591611
padding = []
592612
for i in reversed(range(rank)):
593-
padding.extend([pre_pad[offset + i], post_pad[offset + i]])
613+
padding.extend([logical_pre[offset + i], logical_post[offset + i]])
594614
return F.pad(input, padding, mode="constant", value=pad_value)
595615

596616

backends/cortex_m/passes/quantized_op_fusion_pass.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
from executorch.backends.cortex_m.passes.passes_utils import (
13+
is_channels_last,
1314
quantize_multiplier_aot,
1415
quantize_val,
1516
SHIFT_INT8,
@@ -37,6 +38,14 @@ class QuantizedOpFusionPass(ExportPass):
3738

3839
_SOFTMAX_INPUT_INTEGER_BITS = 5
3940

41+
_NHWC_DIM_ORDER = [0, 2, 3, 1]
42+
43+
def _to_physical_order(self, logical_pad: list[int], tensor_data) -> list[int]:
44+
"""Permute a 4-element logical-dim-order list to physical memory order."""
45+
if not is_channels_last(tensor_data):
46+
return logical_pad
47+
return [logical_pad[self._NHWC_DIM_ORDER[i]] for i in range(4)]
48+
4049
def _get_add_replacement(self, args, meta):
4150
if (
4251
meta.data.get("input_qparams", {}) == {}
@@ -329,6 +338,8 @@ def _get_avg_pool2d_replacement(self, args, meta):
329338
pad_h, pad_w = padding
330339
pre_pad = [0, 0, pad_h, pad_w]
331340
post_pad = [0, 0, pad_h, pad_w]
341+
pre_pad = self._to_physical_order(pre_pad, args[0].data)
342+
post_pad = self._to_physical_order(post_pad, args[0].data)
332343
input_arg = super().call_operator(
333344
exir_ops.edge.cortex_m.pad.default,
334345
(input_arg, pre_pad, post_pad, int(zero_point)),
@@ -379,6 +390,9 @@ def _get_pad_replacement(self, args, meta):
379390
pre_pad[dim_4d] = int(padding[2 * i])
380391
post_pad[dim_4d] = int(padding[2 * i + 1])
381392

393+
pre_pad = self._to_physical_order(pre_pad, args[0].data)
394+
post_pad = self._to_physical_order(post_pad, args[0].data)
395+
382396
new_args = (args[0], pre_pad, post_pad, int(quantized_pad_value))
383397
return exir_ops.edge.cortex_m.pad.default, new_args
384398

backends/cortex_m/test/ops/test_pad.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,14 @@ def forward(self, x):
6969
CortexMPad((0, 0, 0, 0)),
7070
(ramp_tensor(-0.5, 0.5, (2, 3, 4, 5)),),
7171
),
72+
"pad_rank4_all_dims_channels_last": McuTestCase(
73+
CortexMPad((1, 1, 2, 2, 1, 0, 0, 1)),
74+
(ramp_tensor(-0.5, 0.5, (1, 2, 3, 4)).to(memory_format=torch.channels_last),),
75+
),
76+
"pad_rank4_last_two_dims_channels_last": McuTestCase(
77+
CortexMPad((1, 2, 3, 4)),
78+
(ramp_tensor(-1.0, 1.0, (1, 3, 4, 5)).to(memory_format=torch.channels_last),),
79+
),
7280
}
7381

7482

0 commit comments

Comments
 (0)