Skip to content

Commit a3b93f5

Browse files
facebook-github-bot
authored and committed
Eliminate redundant NCHW↔NHWC permute_copy and NHWC-safe view_copy transposes in ToTosaMemoryFormatPass (#18167)
Summary: Two optimizations in ToTosaMemoryFormatPass to reduce TOSA TRANSPOSE nodes: 1. **NHWC-safe reshape detection:** When a 4D→4D view_copy has monotonic shape_indices on the raw shapes and preserves the last dimension (NHWC channel), skip inserting input/output transposes. The view_copy can operate directly on NHWC data. 2. **Redundant permute_copy elimination:** Model-level permute_copy ops whose permutation matches channels_last_order (NCHW→NHWC) or its inverse (NHWC→NCHW) are redundant with the tosa_dim_order annotation that already handles format conversion. Replace them with view_copy (identity reshape) to avoid generating TOSA TRANSPOSE nodes. Reviewed By: digantdesai Differential Revision: D96432610
1 parent 7134708 commit a3b93f5

2 files changed

Lines changed: 284 additions & 2 deletions

File tree

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 156 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER
1818
from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node
19+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1920
from executorch.exir import ExportedProgram
2021
from executorch.exir.dialects._ops import ops as exir_ops
2122
from executorch.exir.pass_base import ExportPass, PassResult
@@ -271,7 +272,7 @@ def insert_output_transpose(node, graph_module):
271272
# Guard: mem_format must be a true permutation for the current rank
272273
assert sorted(mem_format) == list(
273274
range(rank)
274-
), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"
275+
), f"bad perm {mem_format} for rank {rank} in insert_output_transpose"
275276

276277
with graph_module.graph.inserting_after(node):
277278
permute_node = create_node(
@@ -296,6 +297,104 @@ def insert_output_transpose(node, graph_module):
296297
for user in users:
297298
user.replace_input_with(node, permute_node)
298299

300+
@staticmethod
301+
def _get_shape_indices(
302+
src_shape: list[int], tgt_shape: list[int]
303+
) -> list[list[int]] | None:
304+
"""Greedy dimension matching for reshape operations.
305+
306+
For each target dimension, greedily consumes contiguous source
307+
dimensions whose product equals the target size. Size-1 target
308+
dimensions that do not correspond to any source dimension produce
309+
empty index lists (inserted dims).
310+
311+
Returns ``None`` when no valid mapping exists.
312+
"""
313+
src_idx = 0
314+
result: list[list[int]] = []
315+
316+
for tgt_dim in tgt_shape:
317+
if tgt_dim <= 0:
318+
return None
319+
320+
indices: list[int] = []
321+
remaining = tgt_dim
322+
323+
while src_idx < len(src_shape) and remaining % src_shape[src_idx] == 0:
324+
indices.append(src_idx)
325+
remaining //= src_shape[src_idx]
326+
src_idx += 1
327+
if remaining == 1:
328+
break
329+
330+
if remaining != 1:
331+
return None
332+
333+
result.append(indices)
334+
335+
if src_idx != len(src_shape):
336+
return None
337+
338+
return result
339+
340+
@staticmethod
341+
def _is_monotonic(indices: list[list[int]]) -> bool:
342+
"""Return ``True`` when all non-empty index groups are strictly
343+
ordered — i.e. each group's indices follow the previous group's.
344+
"""
345+
last_max = -1
346+
for group in indices:
347+
if not group:
348+
continue
349+
if group[0] <= last_max:
350+
return False
351+
last_max = group[-1]
352+
return True
353+
354+
@staticmethod
355+
def _is_nhwc_safe_reshape(
356+
input_shape, output_shape, input_sr, output_sr # noqa: ARG004
357+
) -> bool:
358+
"""Detect whether a 4-D+ reshape can operate directly on NHWC data.
359+
360+
By the time ``ToTosaMemoryFormatPass`` runs, 4-D tensor shapes in
361+
``meta["val"]`` are already in NHWC physical order (the channel
362+
dimension sits at position ``rank - spatial_rank - 1``, not at
363+
position 1 as in NCHW). We therefore check the shape indices on
364+
the **raw** input/output shapes — no extra permutation is needed.
365+
366+
Returns ``True`` when:
367+
1. The reshape has monotonic shape_indices (each output dim maps
368+
to a contiguous, in-order group of input dims), AND
369+
2. The channel dimension is preserved alone (not merged with
370+
spatial dims).
371+
"""
372+
rank_in = len(input_shape)
373+
rank_out = len(output_shape)
374+
if rank_in < 4 or rank_out < 4:
375+
return False
376+
377+
indices = ToTosaMemoryFormatPass._get_shape_indices(
378+
list(input_shape), list(output_shape)
379+
)
380+
if indices is None:
381+
return False
382+
383+
if not ToTosaMemoryFormatPass._is_monotonic(indices):
384+
return False
385+
386+
# In the TOSA pipeline the physical memory order is NHWC.
387+
# The channel dimension in NHWC is always the **last** axis
388+
# (position ``rank - 1``). It must appear *alone* in its
389+
# output group — if it is merged with spatial dims the reshape
390+
# would reorder channel data and the optimisation is invalid.
391+
channel_idx = rank_in - 1
392+
for group in indices:
393+
if channel_idx in group:
394+
return len(group) == 1
395+
# Channel dim not consumed by any group — conservative reject.
396+
return False
397+
299398
@staticmethod
300399
def _insert_view_transpose(
301400
input_shape, output_shape, node, input_node, graph_module
@@ -317,6 +416,14 @@ def _insert_view_transpose(
317416
output_sr,
318417
)
319418

419+
# When the NHWC-space reshape has monotonic shape_indices the
420+
# view_copy can operate directly on NHWC data — no transposes
421+
# are needed.
422+
if channel_reshape and ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
423+
input_shape, output_shape, input_sr, output_sr
424+
):
425+
return
426+
320427
if (
321428
channel_reshape or nhwc_to_nchw
322429
) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr):
@@ -345,10 +452,57 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
345452
- 1D/2D tensors
346453
347454
"""
348-
for node in graph_module.graph.nodes:
455+
for node in list(graph_module.graph.nodes):
349456
if node.op != "call_function":
350457
continue
351458

459+
# Eliminate model-level permute_copy ops that are redundant
460+
# with the tosa_dim_order annotation. When a permute_copy's
461+
# permutation matches the channels-last order (or its
462+
# inverse), the permute does the same NCHW<>NHWC conversion
463+
# that tosa_dim_order already handles -- keeping both would
464+
# double-convert. Replace with view_copy (identity reshape).
465+
if node.target in (
466+
exir_ops.edge.aten.permute_copy.default,
467+
exir_ops.edge.aten.permute.default,
468+
):
469+
perm = list(node.args[1])
470+
rank = len(perm)
471+
sr = node.meta.get("tosa_spatial_rank", 0)
472+
473+
if rank >= 3 and sr >= 1:
474+
cl_order = list(
475+
self._channels_last_order(rank, sr)
476+
)
477+
cl_inv = list(
478+
self._channels_last_inverse_order(rank, sr)
479+
)
480+
if perm == cl_order or perm == cl_inv:
481+
input_node = node.args[0]
482+
output_shape = list(node.meta["val"].shape)
483+
with graph_module.graph.inserting_before(node):
484+
# Create a CONST_SHAPE node for the shape arg,
485+
# matching what InsertConstShapesPass does for
486+
# normal view_copy nodes. This ensures
487+
# op_view.py sees inputs[1].name as expected.
488+
const_shape_node = graph_module.graph.call_function(
489+
exir_ops.backend.tosa.CONST_SHAPE.default,
490+
(output_shape,),
491+
)
492+
const_shape_node.meta["val"] = output_shape
493+
const_shape_node.meta["tosa_dim_order"] = node.meta.get(
494+
"tosa_dim_order", tuple(range(rank))
495+
)
496+
const_shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
497+
view_node = graph_module.graph.call_function(
498+
exir_ops.edge.aten.view_copy.default,
499+
(input_node, const_shape_node),
500+
)
501+
view_node.meta = dict(node.meta)
502+
node.replace_all_uses_with(view_node)
503+
graph_module.graph.erase_node(node)
504+
continue
505+
352506
# Transpose views
353507
elif node.target == exir_ops.edge.aten.view_copy.default:
354508
input_node = node.args[0]

backends/arm/test/passes/test_to_tosa_memory_format.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,76 @@ def get_inputs(self) -> input_t:
177177
return (torch.rand(4, 4, 4, 4),)
178178

179179

180+
class NHWCSafeSpatialMerge(torch.nn.Module):
181+
"""Test-module with a 4D->4D reshape that merges spatial dims H*W while
182+
preserving the last-dim channel.
183+
184+
For models with view_copy shapes [1,2,14,72]->[1,28,1,72] where C=2
185+
sits at NCHW position 1 and the last dim (72) is the NHWC channel that gets
186+
preserved. ``_is_nhwc_safe_reshape`` detects that shape_indices on the raw
187+
shapes are monotonic with the last dim alone, so no transposes are inserted
188+
around the view_copy.
189+
190+
Setup: conv2d (forces NHWC, C=2) -> view_copy -> add (keeps in NHWC).
191+
"""
192+
193+
ops_before_pass: Dict[str, int] = {}
194+
# Only the 2 I/O transposes for the conv, NO extra transposes from view_copy
195+
ops_after_pass: Dict[str, int] = {
196+
"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2
197+
}
198+
ops_not_after_pass: List[str] = []
199+
200+
def __init__(self):
201+
super().__init__()
202+
self.conv = torch.nn.Conv2d(
203+
in_channels=2, out_channels=2, kernel_size=1, bias=False
204+
)
205+
206+
def forward(self, x: torch.Tensor) -> torch.Tensor:
207+
x = self.conv(x) # forces NHWC path; output [1, 2, 14, 72]
208+
x = x.view(1, 28, 1, 72) # spatial merge: H*W=2*14->28, last dim 72 preserved
209+
return x + x # keep result 4-D in NHWC
210+
211+
def get_inputs(self) -> input_t:
212+
return (torch.randn(1, 2, 14, 72),)
213+
214+
215+
class NHWCUnsafeChannelChange(torch.nn.Module):
216+
"""Test-module with a 4D->4D reshape that is NOT NHWC-safe because the
217+
target shape cannot be produced by a monotonic merge of NHWC input dims.
218+
The pass MUST still insert transposes around the view_copy.
219+
"""
220+
221+
ops_before_pass: Dict[str, int] = {}
222+
# conv I/O transposes (2) + view_copy transposes (2) = 4
223+
ops_after_pass: Dict[str, int] = {
224+
"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4
225+
}
226+
ops_not_after_pass: List[str] = []
227+
228+
def __init__(self):
229+
super().__init__()
230+
self.conv = torch.nn.Conv2d(
231+
in_channels=72, out_channels=72, kernel_size=1, bias=False
232+
)
233+
234+
def forward(self, x: torch.Tensor) -> torch.Tensor:
235+
x = self.conv(x) # output [1, 72, 2, 14]
236+
x = x.view(1, 14, 2, 72) # not NHWC-safe (channels shuffled)
237+
return x + x
238+
239+
def get_inputs(self) -> input_t:
240+
return (torch.randn(1, 72, 2, 14),)
241+
242+
180243
modules: Dict[str, ModuleMetadata] = {
181244
"no_nhwc": NoNHWC(),
182245
"parallel_clusters": ParallelClusters(),
183246
"serial_clusters": SerialClusters(),
184247
"reshapes": Reshapes(),
248+
"nhwc_safe_spatial_merge": NHWCSafeSpatialMerge(),
249+
"nhwc_unsafe_channel_change": NHWCUnsafeChannelChange(),
185250
}
186251

187252

@@ -209,3 +274,66 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No
209274
module_nn = cast(torch.nn.Module, module)
210275
pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), [])
211276
pipeline.run()
277+
278+
279+
# --- Direct unit tests for NHWC-safe reshape helpers ---
280+
281+
282+
def test_get_shape_indices_spatial_merge():
283+
"""[1,2,14,72] -> [1,28,1,72]: merge H*W, insert size-1 dim, preserve C."""
284+
indices = ToTosaMemoryFormatPass._get_shape_indices(
285+
[1, 2, 14, 72], [1, 28, 1, 72]
286+
)
287+
assert indices == [[0], [1, 2], [], [3]]
288+
289+
290+
def test_get_shape_indices_identity():
291+
"""Same shape => each dim maps to itself."""
292+
indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [2, 3, 4])
293+
assert indices == [[0], [1], [2]]
294+
295+
296+
def test_get_shape_indices_full_merge():
297+
"""[2, 3, 4] -> [24]: merge all dims into one."""
298+
indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [24])
299+
assert indices == [[0, 1, 2]]
300+
301+
302+
def test_get_shape_indices_incompatible():
303+
"""Sizes that don't divide => None."""
304+
indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 5], [6, 4])
305+
assert indices is None
306+
307+
308+
def test_get_shape_indices_size_one_insert():
309+
"""[6, 4] -> [6, 1, 4]: inserted size-1 dim in the middle."""
310+
indices = ToTosaMemoryFormatPass._get_shape_indices([6, 4], [6, 1, 4])
311+
assert indices is not None
312+
assert indices == [[0], [], [1]]
313+
314+
315+
def test_is_monotonic_true():
316+
assert ToTosaMemoryFormatPass._is_monotonic([[0], [1, 2], [], [3]])
317+
assert ToTosaMemoryFormatPass._is_monotonic([[0], [], [1], [2, 3]])
318+
assert ToTosaMemoryFormatPass._is_monotonic([[], [0, 1, 2]])
319+
320+
321+
def test_is_monotonic_false():
322+
assert not ToTosaMemoryFormatPass._is_monotonic([[1], [0]])
323+
assert not ToTosaMemoryFormatPass._is_monotonic([[0, 2], [1]])
324+
325+
326+
def test_is_nhwc_safe_forward():
327+
"""Shapes already NHWC by the time the pass runs.
328+
[1,2,14,72] -> [1,28,1,72], sr=2 -> NHWC-safe (spatial merge, C=72 preserved).
329+
"""
330+
assert ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
331+
[1, 2, 14, 72], [1, 28, 1, 72], input_sr=2, output_sr=2
332+
)
333+
334+
335+
def test_is_nhwc_safe_non_4d():
336+
"""Reshapes below rank 4 are never NHWC-safe."""
337+
assert not ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
338+
[6, 4], [24], input_sr=0, output_sr=0
339+
)

0 commit comments

Comments
 (0)