Skip to content

Commit d130d50

Browse files
authored
Fix depthwise conv detection for groups==in_channels==1
Differential Revision: D93869048 Pull Request resolved: pytorch#17590
1 parent e1ecac0 commit d130d50

8 files changed

Lines changed: 276 additions & 38 deletions

File tree

backends/cadence/aot/BUCK

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ fbcode_target(_kind = runtime.python_library,
132132
typing = True,
133133
deps = [
134134
"fbcode//caffe2:torch",
135+
"fbcode//executorch/backends/cadence/aot:utils",
135136
"fbcode//executorch/exir:scalar_type",
136137
"fbcode//executorch/kernels/quantized:custom_ops_generated_lib",
137138
],
@@ -374,6 +375,7 @@ fbcode_target(_kind = runtime.python_library,
374375
deps = [
375376
"//caffe2:torch",
376377
"//executorch/backends/cadence/aot:pass_utils",
378+
"//executorch/backends/cadence/aot:utils",
377379
"//executorch/exir:pass_base",
378380
],
379381
)

backends/cadence/aot/ops_registrations.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_conv1d_output_size,
1515
get_conv2d_output_size,
1616
get_im2row_output_size,
17+
is_depthwise_conv,
1718
)
1819
from executorch.exir.scalar_type import ScalarType
1920
from torch._meta_registrations import _linalg_svd_meta
@@ -1034,11 +1035,8 @@ def quantized_conv2d_nhwc_meta(
10341035
assert len(in_size) < 6
10351036

10361037
# Determine weight layout based on depthwise vs regular conv.
1037-
# Depthwise is defined by in_channels == groups, where in_channels
1038-
# is the last dim of the NHWC input.
10391038
in_channels = in_size[-1]
1040-
is_depthwise = in_channels == groups
1041-
if is_depthwise:
1039+
if is_depthwise_conv(groups, in_channels):
10421040
# Depthwise conv: weight is [*kernel_size, OC]
10431041
*kernel_size, out_channels = weight.shape
10441042
else:
@@ -1177,12 +1175,8 @@ def quantized_conv2d_nhwc_per_tensor_meta(
11771175
assert len(in_size) < 6
11781176

11791177
# Determine weight layout based on depthwise vs regular conv.
1180-
# Depthwise is defined by in_channels == groups, where in_channels
1181-
# is the last dim of the NHWC input.
11821178
in_channels = in_size[-1]
1183-
is_depthwise = in_channels == groups
1184-
if is_depthwise:
1185-
# Depthwise conv: weight is [*kernel_size, OC]
1179+
if is_depthwise_conv(groups, in_channels):
11861180
*kernel_size, out_channels = weight.shape
11871181
elif len(in_size) == 3:
11881182
# 1D conv: weight is [OC, K, IC]
@@ -1336,12 +1330,9 @@ def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_meta(
13361330
assert len(in_size) > 2
13371331
assert len(in_size) < 6
13381332

1339-
# Determine weight layout based on input and weight dimensions:
1340-
# - Depthwise conv: input is 3D/4D, weight is 2/3D [K, OC]/[KH, KW, OC]
1341-
# - 1D conv: input is 3D, weight is 3D [OC, K, IC]
1342-
# - 2D regular conv: input is 4D, weight is 4D [OC, KH, KW, IC]
1343-
if len(weight.shape) == 3:
1344-
# 2D depthwise conv: weight is [KH, KW, OC]
1333+
# Determine weight layout based on depthwise vs regular conv.
1334+
in_channels = in_size[-1]
1335+
if is_depthwise_conv(groups, in_channels):
13451336
*kernel_size, out_channels = weight.shape
13461337
elif len(in_size) == 3:
13471338
# 1D conv: weight is [OC, K, IC]
@@ -1397,12 +1388,9 @@ def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_meta(
13971388
assert len(in_size) > 2
13981389
assert len(in_size) < 6
13991390

1400-
# Determine weight layout based on input and weight dimensions:
1401-
# - Depthwise conv: input is 3D/4D, weight is 3D [KH, KW, OC]
1402-
# - 1D conv: input is 3D, weight is 3D [OC, K, IC]
1403-
# - 2D regular conv: input is 4D, weight is 4D [OC, KH, KW, IC]
1404-
if len(weight.shape) == 3:
1405-
# 2D depthwise conv: weight is [KH, KW, OC]
1391+
# Determine weight layout based on depthwise vs regular conv.
1392+
in_channels = in_size[-1]
1393+
if is_depthwise_conv(groups, in_channels):
14061394
*kernel_size, out_channels = weight.shape
14071395
elif len(in_size) == 3:
14081396
# 1D conv: weight is [OC, K, IC]

backends/cadence/aot/ref_implementations.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
import torch.nn as nn
1414
import torch.nn.functional as F
15+
from executorch.backends.cadence.aot.utils import is_depthwise_conv
1516
from executorch.exir.scalar_type import ScalarType
1617
from torch.library import impl, Library
1718

@@ -1104,17 +1105,12 @@ def quantized_conv2d_nhwc_per_tensor(
11041105

11051106
# Convert to NCHW format to reuse the existing implementation
11061107
in_channels = input_tensor.shape[-1]
1107-
# Depthwise weights have one fewer dimension than the input because the IC
1108-
# dimension (always 1) was squeezed out during the NCHW->NHWC conversion in
1109-
# replace_ops.py. E.g. 2D depthwise: weight is [KH, KW, OC] (3D) while
1110-
# input is [N, H, W, C] (4D). A regular conv with in_channels==groups==1
1111-
# still has 4D weights [OC, KH, KW, IC].
1112-
is_depthwise = in_channels == groups and weight.dim() < input_tensor.dim()
1108+
depthwise = is_depthwise_conv(groups, in_channels)
11131109

11141110
if len(input_tensor.shape) == 3:
11151111
# 1D conv: input is [N, L, C] -> [N, C, L]
11161112
input_tensor = input_tensor.movedim(-1, 1).contiguous()
1117-
if is_depthwise:
1113+
if depthwise:
11181114
# 1D depthwise: weight is [K, OC] -> [OC, 1, K]
11191115
weight = weight.permute(1, 0).unsqueeze(1).contiguous()
11201116
else:
@@ -1124,7 +1120,7 @@ def quantized_conv2d_nhwc_per_tensor(
11241120
else:
11251121
# 2D conv: input is [N, H, W, C] -> [N, C, H, W]
11261122
input_tensor = input_tensor.movedim(-1, -3)
1127-
if is_depthwise:
1123+
if depthwise:
11281124
# 2D depthwise: weight is [KH, KW, OC] -> [OC, 1, KH, KW]
11291125
weight = weight.permute(2, 0, 1).unsqueeze(1).contiguous()
11301126
else:

backends/cadence/aot/replace_ops.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
register_cadence_pass,
2727
RemoveOrReplacePassInterface,
2828
)
29+
from executorch.backends.cadence.aot.utils import is_depthwise_conv
2930
from executorch.backends.transforms.replace_scalar_with_tensor import (
3031
ReplaceScalarWithTensorArgPass,
3132
)
@@ -1138,19 +1139,19 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
11381139

11391140
# Check if this is a depthwise convolution (groups == input_channels)
11401141
# and weight is 4D with shape [OC, 1, KH, KW]
1141-
groups = node.args[6]
1142+
groups = cast(int, node.args[6])
11421143
input_shape = input_node.meta["val"].shape
11431144
weight_shape = weight_node.meta["val"].shape
11441145
input_channels = input_shape[1] # NCHW format, channels at index 1
1145-
# Depthwise conv has 4D weight [OC, 1, KH, KW] where the IC dim is 1
1146-
is_depthwise = groups == input_channels and weight_shape[1] == 1
1146+
# NCHW: also verify weight IC dim == 1.
1147+
depthwise = is_depthwise_conv(groups, input_channels) and weight_shape[1] == 1
11471148
is_2d = len(input_shape) == 4
11481149
# Insert transpose operations before the node
11491150
with graph.inserting_before(node):
11501151
# Convert input from NCHW to NHWC
11511152
input_nhwc = self._change_nchw_to_nhwc(graph, input_node)
11521153
# Convert weight from NCHW to the appropriate format
1153-
if is_depthwise:
1154+
if depthwise:
11541155
# For depthwise: [OC, 1, KH, KW] -> [KH, KW, OC] for NNLib
11551156
weight_nhwc = self._change_depthwise_weight_to_hwc(
11561157
graph, weight_node, is_2d

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import numpy as np
1515
import torch
1616
from executorch.backends.cadence.aot.typing_stubs import expand
17+
from executorch.backends.cadence.aot.utils import is_depthwise_conv
1718

1819
from executorch.exir.scalar_type import ScalarType
1920

@@ -942,12 +943,22 @@ def test_quantized_conv_per_tensor(
942943
assert memory_format in [torch.contiguous_format, torch.channels_last]
943944

944945
if memory_format == torch.channels_last:
946+
in_channels = input_tensor.shape[1] # NCHW still at this point
947+
depthwise = is_depthwise_conv(groups, in_channels)
945948
if input_tensor.ndim == 3:
946949
input_tensor = input_tensor.movedim(1, -1)
947-
weight = weight.movedim(1, -1)
950+
if depthwise:
951+
# [OC, 1, K] -> [K, OC] (squeeze IC, move OC to end)
952+
weight = weight.squeeze(1).movedim(0, -1)
953+
else:
954+
weight = weight.movedim(1, -1)
948955
else:
949956
input_tensor = input_tensor.movedim(-3, -1)
950-
weight = weight.movedim(-3, -1)
957+
if depthwise:
958+
# [OC, 1, KH, KW] -> [KH, KW, OC] (squeeze IC, move OC to end)
959+
weight = weight.squeeze(1).movedim(0, -1)
960+
else:
961+
weight = weight.movedim(-3, -1)
951962

952963
convs = [
953964
(

0 commit comments

Comments
 (0)