Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 174 additions & 3 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER
from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down Expand Up @@ -271,7 +272,7 @@ def insert_output_transpose(node, graph_module):
# Guard: mem_format must be a true permutation for the current rank
assert sorted(mem_format) == list(
range(rank)
), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"
), f"bad perm {mem_format} for rank {rank} in insert_output_transpose"

with graph_module.graph.inserting_after(node):
permute_node = create_node(
Expand All @@ -296,6 +297,110 @@ def insert_output_transpose(node, graph_module):
for user in users:
user.replace_input_with(node, permute_node)

@staticmethod
def _get_shape_indices(
src_shape: list[int], tgt_shape: list[int]
) -> list[list[int]] | None:
"""Greedy dimension matching for reshape operations.

For each target dimension, greedily consumes contiguous source
dimensions whose product equals the target size. Size-1 target
dimensions that do not correspond to any source dimension produce
empty index lists (inserted dims).

Returns ``None`` when no valid mapping exists.

"""
src_idx = 0
result: list[list[int]] = []

for tgt_dim in tgt_shape:
if tgt_dim <= 0:
return None

indices: list[int] = []
remaining = tgt_dim

while src_idx < len(src_shape):
if src_shape[src_idx] == 0:
return None
if remaining % src_shape[src_idx] != 0:
break
indices.append(src_idx)
remaining //= src_shape[src_idx]
src_idx += 1
if remaining == 1:
break

if remaining != 1:
return None

result.append(indices)

if src_idx != len(src_shape):
return None

return result

@staticmethod
def _is_monotonic(indices: list[list[int]]) -> bool:
"""Return ``True`` when all non-empty index groups are strictly ordered
— i.e. each group's indices follow the previous group's.
"""
last_max = -1
for group in indices:
if not group:
continue
if group[0] <= last_max:
return False
last_max = group[-1]
return True

@staticmethod
def _is_nhwc_safe_reshape(
    input_shape, output_shape, input_sr, output_sr  # noqa: ARG004
) -> bool:
    """Detect whether a 4-D+ reshape can operate directly on NHWC data.

    By the time ``ToTosaMemoryFormatPass`` runs, 4-D tensor shapes in
    ``meta["val"]`` are already in NHWC physical order, so the shape
    indices are computed on the raw input/output shapes without any
    extra permutation.

    The reshape is NHWC-safe when:
    1. every output dim maps to a contiguous, in-order group of input
       dims (monotonic shape indices), AND
    2. the channel dimension — the last axis in NHWC — is preserved
       alone, not merged with spatial dims.

    """
    if len(input_shape) < 4 or len(output_shape) < 4:
        return False

    groups = ToTosaMemoryFormatPass._get_shape_indices(
        list(input_shape), list(output_shape)
    )
    if groups is None:
        return False
    if not ToTosaMemoryFormatPass._is_monotonic(groups):
        return False

    # In the TOSA pipeline the physical memory order is NHWC, so the
    # channel axis is the last input dim. It must sit alone in its
    # output group; merging it with spatial dims would reorder channel
    # data and invalidate the optimisation.
    channel_axis = len(input_shape) - 1
    channel_group = next(
        (grp for grp in groups if channel_axis in grp), None
    )
    if channel_group is None:
        # Channel dim not consumed by any group — conservative reject.
        return False
    return len(channel_group) == 1

@staticmethod
def _insert_view_transpose(
input_shape, output_shape, node, input_node, graph_module
Expand All @@ -317,6 +422,14 @@ def _insert_view_transpose(
output_sr,
)

# When the NHWC-space reshape has monotonic shape_indices the
# view_copy can operate directly on NHWC data — no transposes
# are needed.
if channel_reshape and ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
input_shape, output_shape, input_sr, output_sr
):
return

if (
channel_reshape or nhwc_to_nchw
) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr):
Expand All @@ -329,6 +442,61 @@ def _insert_view_transpose(
) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr):
ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)

def _try_replace_redundant_permute(
    self, node: torch.fx.Node, graph_module: torch.fx.GraphModule
) -> bool:
    """Replace a permute_copy with view_copy if it duplicates
    tosa_dim_order.

    When a permute_copy's permutation matches the channels-last order
    (or its inverse), the permute does the same NCHW<>NHWC conversion
    that tosa_dim_order already handles — keeping both would
    double-convert. Replace with view_copy (identity reshape).

    Returns ``True`` if the node was replaced and erased.

    """
    # Only permute nodes are candidates; everything else is untouched.
    if node.target not in (
        exir_ops.edge.aten.permute_copy.default,
        exir_ops.edge.aten.permute.default,
    ):
        return False

    perm_arg = node.args[1]
    assert isinstance(perm_arg, (list, tuple))
    perm = list(perm_arg)
    rank = len(perm)
    # NOTE(review): presumably a prior pass stores the spatial rank in
    # node.meta["tosa_spatial_rank"]; absent metadata defaults to 0 and
    # disqualifies the node below — confirm against the pass pipeline.
    sr = node.meta.get("tosa_spatial_rank", 0)

    # Rank < 3 or no spatial dims: no NCHW<>NHWC conversion to dedupe.
    if rank < 3 or sr < 1:
        return False

    # Compare against both directions of the channels-last permutation;
    # _channels_last_order/_channels_last_inverse_order are defined
    # elsewhere in this class (not visible in this chunk).
    cl_order = list(self._channels_last_order(rank, sr))
    cl_inv = list(self._channels_last_inverse_order(rank, sr))
    if perm != cl_order and perm != cl_inv:
        return False

    input_node = node.args[0]
    output_shape = list(node.meta["val"].shape)
    with graph_module.graph.inserting_before(node):
        # view_copy needs its target shape as a TOSA CONST_SHAPE node
        # rather than a plain literal argument.
        const_shape_node = graph_module.graph.call_function(
            exir_ops.backend.tosa.CONST_SHAPE.default,
            (output_shape,),
        )
        # NOTE(review): "val" is set to a plain python list here, not a
        # FakeTensor — verify downstream consumers accept that.
        const_shape_node.meta["val"] = output_shape
        const_shape_node.meta["tosa_dim_order"] = node.meta.get(
            "tosa_dim_order", tuple(range(rank))
        )
        # Mark the node as carrying shape data for TOSA serialization.
        const_shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
        view_node = graph_module.graph.call_function(
            exir_ops.edge.aten.view_copy.default,
            (input_node, const_shape_node),
        )
    # The replacement inherits the permute's metadata (copied so later
    # edits to one node's meta don't alias the other's).
    view_node.meta = dict(node.meta)
    node.replace_all_uses_with(view_node)
    graph_module.graph.erase_node(node)
    return True

def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
"""Transposes are needed for operators transforming the input to a
different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC-
Expand All @@ -345,12 +513,15 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
- 1D/2D tensors

"""
for node in graph_module.graph.nodes:
for node in list(graph_module.graph.nodes):
if node.op != "call_function":
continue

if self._try_replace_redundant_permute(node, graph_module):
continue

# Transpose views
elif node.target == exir_ops.edge.aten.view_copy.default:
if node.target == exir_ops.edge.aten.view_copy.default:
input_node = node.args[0]
input_shape = input_node.meta["val"].shape
output_shape = node.meta["val"].shape
Expand Down
132 changes: 132 additions & 0 deletions backends/arm/test/passes/test_to_tosa_memory_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,79 @@ def get_inputs(self) -> input_t:
return (torch.rand(4, 4, 4, 4),)


class NHWCSafeSpatialMerge(torch.nn.Module):
    """Test-module with a 4D->4D reshape that merges spatial dims H*W while
    preserving the last-dim channel.

    For models with view_copy shapes [1,2,14,72]->[1,28,1,72] where C=2
    sits at NCHW position 1 and the last dim (72) is the NHWC channel that gets
    preserved. ``_is_nhwc_safe_reshape`` detects that shape_indices on the raw
    shapes are monotonic with the last dim alone, so no transposes are inserted
    around the view_copy.

    Setup: conv2d (forces NHWC, C=2) -> view_copy -> add (keeps in NHWC).

    """

    # No op-count expectations before the pass runs.
    ops_before_pass: Dict[str, int] = {}
    # Only the 2 I/O transposes for the conv, NO extra transposes from view_copy
    ops_after_pass: Dict[str, int] = {
        "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2
    }
    # No ops are required to be absent after the pass.
    ops_not_after_pass: List[str] = []

    def __init__(self):
        super().__init__()
        # 1x1 conv: keeps H/W unchanged, exists only to force the NHWC path.
        self.conv = torch.nn.Conv2d(
            in_channels=2, out_channels=2, kernel_size=1, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)  # forces NHWC path; output [1, 2, 14, 72]
        x = x.view(1, 28, 1, 72)  # spatial merge: H*W=2*14->28, last dim 72 preserved
        return x + x  # keep result 4-D in NHWC

    def get_inputs(self) -> input_t:
        # Shape chosen so the view above is a pure spatial merge.
        return (torch.randn(1, 2, 14, 72),)


class NHWCUnsafeChannelChange(torch.nn.Module):
    """Test-module with a 4D->4D reshape that is NOT NHWC-safe because the
    target shape cannot be produced by a monotonic merge of NHWC input dims.

    The pass MUST still insert transposes around the view_copy.

    """

    # No op-count expectations before the pass runs.
    ops_before_pass: Dict[str, int] = {}
    # conv I/O transposes (2) + view_copy transposes (2) = 4
    ops_after_pass: Dict[str, int] = {
        "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4
    }
    # No ops are required to be absent after the pass.
    ops_not_after_pass: List[str] = []

    def __init__(self):
        super().__init__()
        # 1x1 conv: keeps H/W unchanged, exists only to force the NHWC path.
        self.conv = torch.nn.Conv2d(
            in_channels=72, out_channels=72, kernel_size=1, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)  # output [1, 72, 2, 14]
        x = x.view(1, 14, 2, 72)  # not NHWC-safe (channels shuffled)
        return x + x

    def get_inputs(self) -> input_t:
        return (torch.randn(1, 72, 2, 14),)


# Registry of test modules, keyed by test-case id; consumed by the
# parametrized pipeline tests below.
modules: Dict[str, ModuleMetadata] = {
    "no_nhwc": NoNHWC(),
    "parallel_clusters": ParallelClusters(),
    "serial_clusters": SerialClusters(),
    "reshapes": Reshapes(),
    "nhwc_safe_spatial_merge": NHWCSafeSpatialMerge(),
    "nhwc_unsafe_channel_change": NHWCUnsafeChannelChange(),
}


Expand Down Expand Up @@ -209,3 +277,67 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No
module_nn = cast(torch.nn.Module, module)
pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), [])
pipeline.run()


# --- Direct unit tests for NHWC-safe reshape helpers ---


def test_get_shape_indices_spatial_merge():
    """[1,2,14,72] -> [1,28,1,72]: merge H*W, insert size-1 dim, preserve C."""
    src, tgt = [1, 2, 14, 72], [1, 28, 1, 72]
    assert ToTosaMemoryFormatPass._get_shape_indices(src, tgt) == [
        [0],
        [1, 2],
        [],
        [3],
    ]


def test_get_shape_indices_identity():
    """Same shape => each dim maps to itself."""
    expected = [[0], [1], [2]]
    assert ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [2, 3, 4]) == expected


def test_get_shape_indices_full_merge():
    """[2, 3, 4] -> [24]: merge all dims into one."""
    assert ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [24]) == [[0, 1, 2]]


def test_get_shape_indices_incompatible():
    """Sizes that don't divide => None."""
    assert ToTosaMemoryFormatPass._get_shape_indices([2, 3, 5], [6, 4]) is None


def test_get_shape_indices_size_one_insert():
    """[6, 4] -> [6, 1, 4]: inserted size-1 dim in the middle."""
    groups = ToTosaMemoryFormatPass._get_shape_indices([6, 4], [6, 1, 4])
    assert groups == [[0], [], [1]]


def test_is_monotonic_true():
    """Strictly-ordered groups (empty groups ignored) are monotonic."""
    for groups in (
        [[0], [1, 2], [], [3]],
        [[0], [], [1], [2, 3]],
        [[], [0, 1, 2]],
    ):
        assert ToTosaMemoryFormatPass._is_monotonic(groups)


def test_is_monotonic_false():
    """Out-of-order or overlapping groups are not monotonic."""
    for groups in ([[1], [0]], [[0, 2], [1]]):
        assert not ToTosaMemoryFormatPass._is_monotonic(groups)


def test_is_nhwc_safe_forward():
    """Shapes already NHWC by the time the pass runs.

    [1,2,14,72] -> [1,28,1,72], sr=2 -> NHWC-safe (spatial merge, C=72
    preserved).

    """
    safe = ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
        [1, 2, 14, 72], [1, 28, 1, 72], input_sr=2, output_sr=2
    )
    assert safe


def test_is_nhwc_safe_non_4d():
    """Reshapes below rank 4 are never NHWC-safe."""
    safe = ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
        [6, 4], [24], input_sr=0, output_sr=0
    )
    assert not safe
Loading