Skip to content

Commit ec1f9ae

Browse files
3l1facebook-github-bot
authored andcommitted
Eliminate redundant NCHW↔NHWC permute_copy and NHWC-safe view_copy transposes in ToTosaMemoryFormatPass (#18167)
Summary: Two optimizations in ToTosaMemoryFormatPass to reduce TOSA TRANSPOSE nodes: 1. **NHWC-safe reshape detection:** When a 4D→4D view_copy has monotonic shape_indices on the raw shapes and preserves the last dimension (NHWC channel), skip inserting input/output transposes. The view_copy can operate directly on NHWC data. 2. **Redundant permute_copy elimination:** Model-level permute_copy ops whose permutation matches channels_last_order (NCHW→NHWC) or its inverse (NHWC→NCHW) are redundant with the tosa_dim_order annotation that already handles format conversion. Replace them with view_copy (identity reshape) to avoid generating TOSA TRANSPOSE nodes. Reviewed By: digantdesai Differential Revision: D96432610
1 parent 7134708 commit ec1f9ae

2 files changed

Lines changed: 304 additions & 3 deletions

File tree

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 172 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER
1818
from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node
19+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1920
from executorch.exir import ExportedProgram
2021
from executorch.exir.dialects._ops import ops as exir_ops
2122
from executorch.exir.pass_base import ExportPass, PassResult
@@ -271,7 +272,7 @@ def insert_output_transpose(node, graph_module):
271272
# Guard: mem_format must be a true permutation for the current rank
272273
assert sorted(mem_format) == list(
273274
range(rank)
274-
), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"
275+
), f"bad perm {mem_format} for rank {rank} in insert_output_transpose"
275276

276277
with graph_module.graph.inserting_after(node):
277278
permute_node = create_node(
@@ -296,6 +297,110 @@ def insert_output_transpose(node, graph_module):
296297
for user in users:
297298
user.replace_input_with(node, permute_node)
298299

300+
@staticmethod
301+
def _get_shape_indices(
302+
src_shape: list[int], tgt_shape: list[int]
303+
) -> list[list[int]] | None:
304+
"""Greedy dimension matching for reshape operations.
305+
306+
For each target dimension, greedily consumes contiguous source
307+
dimensions whose product equals the target size. Size-1 target
308+
dimensions that do not correspond to any source dimension produce
309+
empty index lists (inserted dims).
310+
311+
Returns ``None`` when no valid mapping exists.
312+
313+
"""
314+
src_idx = 0
315+
result: list[list[int]] = []
316+
317+
for tgt_dim in tgt_shape:
318+
if tgt_dim <= 0:
319+
return None
320+
321+
indices: list[int] = []
322+
remaining = tgt_dim
323+
324+
while src_idx < len(src_shape):
325+
if src_shape[src_idx] == 0:
326+
return None
327+
if remaining % src_shape[src_idx] != 0:
328+
break
329+
indices.append(src_idx)
330+
remaining //= src_shape[src_idx]
331+
src_idx += 1
332+
if remaining == 1:
333+
break
334+
335+
if remaining != 1:
336+
return None
337+
338+
result.append(indices)
339+
340+
if src_idx != len(src_shape):
341+
return None
342+
343+
return result
344+
345+
@staticmethod
346+
def _is_monotonic(indices: list[list[int]]) -> bool:
347+
"""Return ``True`` when all non-empty index groups are strictly ordered
348+
— i.e. each group's indices follow the previous group's.
349+
"""
350+
last_max = -1
351+
for group in indices:
352+
if not group:
353+
continue
354+
if group[0] <= last_max:
355+
return False
356+
last_max = group[-1]
357+
return True
358+
359+
@staticmethod
360+
def _is_nhwc_safe_reshape(
361+
input_shape, output_shape, input_sr, output_sr # noqa: ARG004
362+
) -> bool:
363+
"""Detect whether a 4-D+ reshape can operate directly on NHWC data.
364+
365+
By the time ``ToTosaMemoryFormatPass`` runs, 4-D tensor shapes in
366+
``meta["val"]`` are already in NHWC physical order (the channel
367+
dimension sits at position ``rank - spatial_rank - 1``, not at
368+
position 1 as in NCHW). We therefore check the shape indices on
369+
the **raw** input/output shapes — no extra permutation is needed.
370+
371+
Returns ``True`` when:
372+
1. The reshape has monotonic shape_indices (each output dim maps
373+
to a contiguous, in-order group of input dims), AND
374+
2. The channel dimension is preserved alone (not merged with
375+
spatial dims).
376+
377+
"""
378+
rank_in = len(input_shape)
379+
rank_out = len(output_shape)
380+
if rank_in < 4 or rank_out < 4:
381+
return False
382+
383+
indices = ToTosaMemoryFormatPass._get_shape_indices(
384+
list(input_shape), list(output_shape)
385+
)
386+
if indices is None:
387+
return False
388+
389+
if not ToTosaMemoryFormatPass._is_monotonic(indices):
390+
return False
391+
392+
# In the TOSA pipeline the physical memory order is NHWC.
393+
# The channel dimension in NHWC is always the **last** axis
394+
# (position ``rank - 1``). It must appear *alone* in its
395+
# output group — if it is merged with spatial dims the reshape
396+
# would reorder channel data and the optimisation is invalid.
397+
channel_idx = rank_in - 1
398+
for group in indices:
399+
if channel_idx in group:
400+
return len(group) == 1
401+
# Channel dim not consumed by any group — conservative reject.
402+
return False
403+
299404
@staticmethod
300405
def _insert_view_transpose(
301406
input_shape, output_shape, node, input_node, graph_module
@@ -317,6 +422,14 @@ def _insert_view_transpose(
317422
output_sr,
318423
)
319424

425+
# When the NHWC-space reshape has monotonic shape_indices the
426+
# view_copy can operate directly on NHWC data — no transposes
427+
# are needed.
428+
if channel_reshape and ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
429+
input_shape, output_shape, input_sr, output_sr
430+
):
431+
return
432+
320433
if (
321434
channel_reshape or nhwc_to_nchw
322435
) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr):
@@ -329,6 +442,59 @@ def _insert_view_transpose(
329442
) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr):
330443
ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
331444

445+
def _try_replace_redundant_permute(
446+
self, node: torch.fx.Node, graph_module: torch.fx.GraphModule
447+
) -> bool:
448+
"""Replace a permute_copy with view_copy if it duplicates tosa_dim_order.
449+
450+
When a permute_copy's permutation matches the channels-last order
451+
(or its inverse), the permute does the same NCHW<>NHWC conversion
452+
that tosa_dim_order already handles — keeping both would
453+
double-convert. Replace with view_copy (identity reshape).
454+
455+
Returns ``True`` if the node was replaced and erased.
456+
"""
457+
if node.target not in (
458+
exir_ops.edge.aten.permute_copy.default,
459+
exir_ops.edge.aten.permute.default,
460+
):
461+
return False
462+
463+
perm = list(node.args[1])
464+
rank = len(perm)
465+
sr = node.meta.get("tosa_spatial_rank", 0)
466+
467+
if rank < 3 or sr < 1:
468+
return False
469+
470+
cl_order = list(self._channels_last_order(rank, sr))
471+
cl_inv = list(self._channels_last_inverse_order(rank, sr))
472+
if perm != cl_order and perm != cl_inv:
473+
return False
474+
475+
input_node = node.args[0]
476+
output_shape = list(node.meta["val"].shape)
477+
with graph_module.graph.inserting_before(node):
478+
const_shape_node = graph_module.graph.call_function(
479+
exir_ops.backend.tosa.CONST_SHAPE.default,
480+
(output_shape,),
481+
)
482+
const_shape_node.meta["val"] = output_shape
483+
const_shape_node.meta["tosa_dim_order"] = node.meta.get(
484+
"tosa_dim_order", tuple(range(rank))
485+
)
486+
const_shape_node.meta[TosaSpecialDtype.meta_key()] = (
487+
TosaSpecialDtype.SHAPE
488+
)
489+
view_node = graph_module.graph.call_function(
490+
exir_ops.edge.aten.view_copy.default,
491+
(input_node, const_shape_node),
492+
)
493+
view_node.meta = dict(node.meta)
494+
node.replace_all_uses_with(view_node)
495+
graph_module.graph.erase_node(node)
496+
return True
497+
332498
def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
333499
"""Transposes are needed for operators transforming the input to a
334500
different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC-
@@ -345,12 +511,15 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
345511
- 1D/2D tensors
346512
347513
"""
348-
for node in graph_module.graph.nodes:
514+
for node in list(graph_module.graph.nodes):
349515
if node.op != "call_function":
350516
continue
351517

518+
if self._try_replace_redundant_permute(node, graph_module):
519+
continue
520+
352521
# Transpose views
353-
elif node.target == exir_ops.edge.aten.view_copy.default:
522+
if node.target == exir_ops.edge.aten.view_copy.default:
354523
input_node = node.args[0]
355524
input_shape = input_node.meta["val"].shape
356525
output_shape = node.meta["val"].shape

backends/arm/test/passes/test_to_tosa_memory_format.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,79 @@ def get_inputs(self) -> input_t:
177177
return (torch.rand(4, 4, 4, 4),)
178178

179179

180+
class NHWCSafeSpatialMerge(torch.nn.Module):
181+
"""Test-module with a 4D->4D reshape that merges spatial dims H*W while
182+
preserving the last-dim channel.
183+
184+
For models with view_copy shapes [1,2,14,72]->[1,28,1,72] where C=2
185+
sits at NCHW position 1 and the last dim (72) is the NHWC channel that gets
186+
preserved. ``_is_nhwc_safe_reshape`` detects that shape_indices on the raw
187+
shapes are monotonic with the last dim alone, so no transposes are inserted
188+
around the view_copy.
189+
190+
Setup: conv2d (forces NHWC, C=2) -> view_copy -> add (keeps in NHWC).
191+
192+
"""
193+
194+
ops_before_pass: Dict[str, int] = {}
195+
# Only the 2 I/O transposes for the conv, NO extra transposes from view_copy
196+
ops_after_pass: Dict[str, int] = {
197+
"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2
198+
}
199+
ops_not_after_pass: List[str] = []
200+
201+
def __init__(self):
202+
super().__init__()
203+
self.conv = torch.nn.Conv2d(
204+
in_channels=2, out_channels=2, kernel_size=1, bias=False
205+
)
206+
207+
def forward(self, x: torch.Tensor) -> torch.Tensor:
208+
x = self.conv(x) # forces NHWC path; output [1, 2, 14, 72]
209+
x = x.view(1, 28, 1, 72) # spatial merge: H*W=2*14->28, last dim 72 preserved
210+
return x + x # keep result 4-D in NHWC
211+
212+
def get_inputs(self) -> input_t:
213+
return (torch.randn(1, 2, 14, 72),)
214+
215+
216+
class NHWCUnsafeChannelChange(torch.nn.Module):
217+
"""Test-module with a 4D->4D reshape that is NOT NHWC-safe because the
218+
target shape cannot be produced by a monotonic merge of NHWC input dims.
219+
220+
The pass MUST still insert transposes around the view_copy.
221+
222+
"""
223+
224+
ops_before_pass: Dict[str, int] = {}
225+
# conv I/O transposes (2) + view_copy transposes (2) = 4
226+
ops_after_pass: Dict[str, int] = {
227+
"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4
228+
}
229+
ops_not_after_pass: List[str] = []
230+
231+
def __init__(self):
232+
super().__init__()
233+
self.conv = torch.nn.Conv2d(
234+
in_channels=72, out_channels=72, kernel_size=1, bias=False
235+
)
236+
237+
def forward(self, x: torch.Tensor) -> torch.Tensor:
238+
x = self.conv(x) # output [1, 72, 2, 14]
239+
x = x.view(1, 14, 2, 72) # not NHWC-safe (channels shuffled)
240+
return x + x
241+
242+
def get_inputs(self) -> input_t:
243+
return (torch.randn(1, 72, 2, 14),)
244+
245+
180246
modules: Dict[str, ModuleMetadata] = {
181247
"no_nhwc": NoNHWC(),
182248
"parallel_clusters": ParallelClusters(),
183249
"serial_clusters": SerialClusters(),
184250
"reshapes": Reshapes(),
251+
"nhwc_safe_spatial_merge": NHWCSafeSpatialMerge(),
252+
"nhwc_unsafe_channel_change": NHWCUnsafeChannelChange(),
185253
}
186254

187255

@@ -209,3 +277,67 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No
209277
module_nn = cast(torch.nn.Module, module)
210278
pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), [])
211279
pipeline.run()
280+
281+
282+
# --- Direct unit tests for NHWC-safe reshape helpers ---
283+
284+
285+
def test_get_shape_indices_spatial_merge():
286+
"""[1,2,14,72] -> [1,28,1,72]: merge H*W, insert size-1 dim, preserve C."""
287+
indices = ToTosaMemoryFormatPass._get_shape_indices([1, 2, 14, 72], [1, 28, 1, 72])
288+
assert indices == [[0], [1, 2], [], [3]]
289+
290+
291+
def test_get_shape_indices_identity():
292+
"""Same shape => each dim maps to itself."""
293+
indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [2, 3, 4])
294+
assert indices == [[0], [1], [2]]
295+
296+
297+
def test_get_shape_indices_full_merge():
298+
"""[2, 3, 4] -> [24]: merge all dims into one."""
299+
indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [24])
300+
assert indices == [[0, 1, 2]]
301+
302+
303+
def test_get_shape_indices_incompatible():
304+
"""Sizes that don't divide => None."""
305+
indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 5], [6, 4])
306+
assert indices is None
307+
308+
309+
def test_get_shape_indices_size_one_insert():
310+
"""[6, 4] -> [6, 1, 4]: inserted size-1 dim in the middle."""
311+
indices = ToTosaMemoryFormatPass._get_shape_indices([6, 4], [6, 1, 4])
312+
assert indices is not None
313+
assert indices == [[0], [], [1]]
314+
315+
316+
def test_is_monotonic_true():
317+
assert ToTosaMemoryFormatPass._is_monotonic([[0], [1, 2], [], [3]])
318+
assert ToTosaMemoryFormatPass._is_monotonic([[0], [], [1], [2, 3]])
319+
assert ToTosaMemoryFormatPass._is_monotonic([[], [0, 1, 2]])
320+
321+
322+
def test_is_monotonic_false():
323+
assert not ToTosaMemoryFormatPass._is_monotonic([[1], [0]])
324+
assert not ToTosaMemoryFormatPass._is_monotonic([[0, 2], [1]])
325+
326+
327+
def test_is_nhwc_safe_forward():
328+
"""Shapes already NHWC by the time the pass runs.
329+
330+
[1,2,14,72] -> [1,28,1,72], sr=2 -> NHWC-safe (spatial merge, C=72
331+
preserved).
332+
333+
"""
334+
assert ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
335+
[1, 2, 14, 72], [1, 28, 1, 72], input_sr=2, output_sr=2
336+
)
337+
338+
339+
def test_is_nhwc_safe_non_4d():
340+
"""Reshapes below rank 4 are never NHWC-safe."""
341+
assert not ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
342+
[6, 4], [24], input_sr=0, output_sr=0
343+
)

0 commit comments

Comments
 (0)