Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions backends/arm/_passes/size_adjust_input_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import cast, Sequence, Set, Type, TypeAlias

import torch.fx
Expand All @@ -13,10 +11,12 @@
expand_around_channel,
)
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
from executorch.backends.arm.tosa.specification import get_context_shape_env
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

Slices: TypeAlias = list[tuple[int, int, int]]
SymIntLike = int | torch.SymInt

conv2d_op = exir_ops.edge.aten.convolution.default
max_pooling_op = exir_ops.edge.aten.max_pool2d.default
Expand All @@ -26,20 +26,34 @@
valid_operators = [conv2d_op, max_pooling_op, avg_pooling_op]


def conv_remainder(input_length, pad, dilation, weight, stride) -> int:
def conv_remainder(
input_length: SymIntLike, pad: int, dilation: int, weight: int, stride: int
) -> SymIntLike:
"""Returns the remainder of input_length; given the padding, dilation,
stride, and kernel size.
"""
return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride


def pooling_remainder(input_size, pad, kernel_size, stride) -> int:
def pooling_remainder(
input_size: SymIntLike, pad: int, kernel_size: int, stride: int
) -> SymIntLike:
"""Returns the remainder of input_length; given the padding, stride, and
kernel size.
"""
return (input_size + 2 * pad - kernel_size) % stride


def _greater_than(input: SymIntLike, other: int) -> bool | torch.SymBool:
"""Returns whether an int or SymInt is greater than another value."""
if isinstance(input, torch.SymInt):
shape_env = get_context_shape_env()
value_ranges = shape_env.bound_sympy(input.node.expr)
return value_ranges.upper > other
else:
return input > other


def get_slices_convolution(conv_node: torch.fx.Node) -> Slices:
slices = []

Expand All @@ -59,7 +73,7 @@ def get_slices_convolution(conv_node: torch.fx.Node) -> Slices:
remainder = conv_remainder(
input_shape[dim], pad, dilation, weight_shape[dim], stride
)
if remainder > pad:
if _greater_than(remainder, pad):
adjustment = remainder - pad
args = (dim, 0, input_shape[dim] - adjustment)
slices.append(args)
Expand Down Expand Up @@ -87,7 +101,7 @@ def get_slices_pooling(pooling_node: torch.fx.Node) -> Slices:
remainder = pooling_remainder(
input_shape[dim], pad_size, kernel_length, stride_length
)
if remainder > pad_size:
if _greater_than(remainder, pad_size):
adjustment = remainder - pad_size
args = (dim, 0, input_shape[dim] - adjustment)
slices.append(args)
Expand Down
Loading
Loading