Skip to content

Commit 5123efe

Browse files
authored
Arm backend: Add partial support for grouped transposed convs (pytorch#17702)
Grouped transposed convs with per-channel quantization and/or dilation are not yet supported. cc @digantdesai @SS-JIA @freddan80 @per @zingo @oscarandersson8218 @Sebastian-Larsson @robell
1 parent 17cb87c commit 5123efe

4 files changed

Lines changed: 302 additions & 134 deletions

File tree

backends/arm/_passes/decompose_grouped_conv_pass.py

Lines changed: 201 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
from copy import copy
7-
from typing import Set, Type
7+
from typing import Literal, Protocol, Set, Type, TypeGuard
88

99
import torch
1010
from executorch.backends.arm._passes.arm_pass import ArmPass
@@ -14,6 +14,16 @@
1414
from executorch.exir.pass_base import ExportPass
1515

1616

17+
class _PerChannelQuantArgs(Protocol):
18+
scale: list[float]
19+
zp: list[int]
20+
qmin: int
21+
qmax: int
22+
dtype: torch.dtype
23+
axis: int
24+
per_channel: Literal[True]
25+
26+
1727
class DecomposeGroupedConvPass(ArmPass):
1828
"""Splits a grouped convolution which is not supported by TOSA into multiple
1929
convolutions using slice->conv->cat.
@@ -47,6 +57,12 @@ def _get_decomposition(op):
4757
exir_ops.edge.aten.convolution.default,
4858
exir_ops.edge.aten.cat.default,
4959
)
60+
case torch.ops.aten.conv_transpose2d.input:
61+
return (
62+
torch.ops.aten.slice_copy.Tensor,
63+
torch.ops.aten.conv_transpose2d.input,
64+
torch.ops.aten.cat.default,
65+
)
5066
case torch.ops.aten.conv2d.default:
5167
return (
5268
torch.ops.aten.slice_copy.Tensor,
@@ -57,131 +73,233 @@ def _get_decomposition(op):
5773
raise RuntimeError("Invalid op for grouped conv decomposition")
5874

5975
@staticmethod
60-
def _split_per_channel_qparams(qarg, index, output_slice_size):
61-
if qarg is not None and qarg.per_channel:
62-
start_index = index * output_slice_size
63-
stop_index = (index + 1) * output_slice_size
64-
return QuantArgs(
65-
scale=qarg.scale[start_index:stop_index],
66-
zp=qarg.zp[start_index:stop_index],
67-
qmin=qarg.qmin,
68-
qmax=qarg.qmax,
69-
dtype=qarg.dtype,
70-
axis=qarg.axis,
71-
per_channel=qarg.per_channel,
76+
def _get_groups_and_transposed(op, args):
77+
if op == exir_ops.edge.aten.convolution.default:
78+
return args[8], args[6]
79+
if op == torch.ops.aten.conv_transpose2d.input:
80+
return args[6], True
81+
if op == torch.ops.aten.conv2d.default:
82+
return args[6], False
83+
return None, None
84+
85+
@staticmethod
86+
def _is_depthwise_conv(input_node, groups, transposed):
87+
return (not transposed) and input_node.data.shape[1] == groups
88+
89+
@staticmethod
90+
def _get_slice_sizes(weight_node, groups, transposed):
91+
if transposed:
92+
input_slice_size = weight_node.data.shape[0] // groups
93+
output_slice_size = weight_node.data.shape[1]
94+
else:
95+
input_slice_size = weight_node.data.shape[1]
96+
output_slice_size = weight_node.data.shape[0] // groups
97+
return input_slice_size, output_slice_size
98+
99+
def _slice_inputs(
100+
self, slice_op, input_node, input_slice_size, groups, meta, kwargs
101+
):
102+
input_slices = []
103+
for i in range(groups):
104+
start_index = i * input_slice_size
105+
stop_index = (i + 1) * input_slice_size
106+
slice_args = (input_node, 1, start_index, stop_index)
107+
input_slices.append(
108+
super().call_operator(slice_op, slice_args, kwargs, meta, updated=True)
72109
)
73-
return qarg
110+
return input_slices
111+
112+
def _slice_weights(
113+
self,
114+
slice_op,
115+
weight_node,
116+
groups,
117+
input_slice_size,
118+
output_slice_size,
119+
transposed,
120+
meta,
121+
kwargs,
122+
):
123+
weight_slices = []
124+
for i in range(groups):
125+
if transposed:
126+
start_index = i * input_slice_size
127+
stop_index = (i + 1) * input_slice_size
128+
else:
129+
start_index = i * output_slice_size
130+
stop_index = (i + 1) * output_slice_size
131+
slice_args = (weight_node, 0, start_index, stop_index)
132+
weight_slices.append(
133+
super().call_operator(slice_op, slice_args, kwargs, meta, updated=True)
134+
)
135+
return weight_slices
136+
137+
def _slice_biases(
138+
self, slice_op, bias_node, groups, output_slice_size, meta, kwargs
139+
):
140+
bias_slices = []
141+
for i in range(groups):
142+
if bias_node is None:
143+
bias_slices.append(None)
144+
continue
145+
start_index = i * output_slice_size
146+
stop_index = (i + 1) * output_slice_size
147+
slice_args = (bias_node, 0, start_index, stop_index)
148+
bias_slices.append(
149+
super().call_operator(slice_op, slice_args, kwargs, meta, updated=True)
150+
)
151+
return bias_slices
74152

75153
@staticmethod
76-
def _get_meta_copy(meta, i, output_slice_size):
154+
def _build_conv_args(op, args, input_slice, filter_slice, bias_slice):
155+
if op == exir_ops.edge.aten.convolution.default:
156+
return (input_slice, filter_slice, bias_slice, *args[3:8], 1)
157+
if op == torch.ops.aten.conv_transpose2d.input:
158+
return (
159+
input_slice,
160+
filter_slice,
161+
bias_slice,
162+
args[3],
163+
args[4],
164+
args[5],
165+
1,
166+
args[7],
167+
)
168+
if op == torch.ops.aten.conv2d.default:
169+
return (input_slice, filter_slice, bias_slice, *args[3:6], 1)
170+
raise RuntimeError("Invalid op for grouped conv decomposition")
171+
172+
@staticmethod
173+
def _is_per_channel_qparams(
174+
qarg: QuantArgs | None,
175+
) -> TypeGuard[_PerChannelQuantArgs]:
176+
return qarg is not None and qarg.per_channel
177+
178+
@staticmethod
179+
def _split_per_channel_qparams(
180+
qarg: _PerChannelQuantArgs, start_index, stop_index
181+
) -> QuantArgs:
182+
return QuantArgs(
183+
scale=qarg.scale[start_index:stop_index],
184+
zp=qarg.zp[start_index:stop_index],
185+
qmin=qarg.qmin,
186+
qmax=qarg.qmax,
187+
dtype=qarg.dtype,
188+
axis=qarg.axis,
189+
per_channel=qarg.per_channel,
190+
)
191+
192+
@staticmethod
193+
def _get_meta_copy(
194+
meta,
195+
i,
196+
input_slice_size,
197+
output_slice_size,
198+
transposed,
199+
):
77200
meta_copy = meta.copy()
78201

79202
if "input_qparams" in meta.data and len(meta.data["input_qparams"]) > 0:
80203
# Handle per-channel quantization by splitting quantization params
81204
# similarly to how activations/weights/biases are split.
82205
new_qparams = meta.data.get("input_qparams").copy()
206+
83207
# Get quantization params of the weights and slice them.
84208
w_qarg = new_qparams[1]
85-
new_qparams[1] = DecomposeGroupedConvPass._split_per_channel_qparams(
86-
w_qarg, index=i, output_slice_size=output_slice_size
87-
)
88-
# Special case for int16, grouped conv2d when bias is included.
89-
# As we add bias after in the DecomposeConv2dWithInt16ActivationPass we must
90-
# also split the bias quantization parameters for bias.
91-
if new_qparams[0].dtype == torch.int16 and len(new_qparams) > 2:
92-
b_qarg = new_qparams[2]
93-
new_qparams[2] = DecomposeGroupedConvPass._split_per_channel_qparams(
94-
b_qarg, index=i, output_slice_size=output_slice_size
209+
if DecomposeGroupedConvPass._is_per_channel_qparams(w_qarg):
210+
211+
# For transpose conv, axis=1 corresponds to output channels and
212+
# does not align with grouped slicing.
213+
# Per-channel quantization on axis=0 on the other hand could align here but
214+
# per-channel quant on axis 0 is very uncommon.
215+
if transposed:
216+
raise RuntimeError(
217+
"Grouped transpose conv with per-channel quantization is unsupported"
218+
)
219+
220+
slice_size = output_slice_size
221+
start_index = i * slice_size
222+
stop_index = (i + 1) * slice_size
223+
new_qparams[1] = DecomposeGroupedConvPass._split_per_channel_qparams(
224+
w_qarg, start_index=start_index, stop_index=stop_index
95225
)
96226

227+
# Split per-channel bias qparams to match per-group output slices.
228+
if len(new_qparams) > 2:
229+
b_qarg = new_qparams[2]
230+
if DecomposeGroupedConvPass._is_per_channel_qparams(b_qarg):
231+
start_index = i * output_slice_size
232+
stop_index = (i + 1) * output_slice_size
233+
new_qparams[2] = (
234+
DecomposeGroupedConvPass._split_per_channel_qparams(
235+
b_qarg, start_index=start_index, stop_index=stop_index
236+
)
237+
)
238+
97239
meta_copy.data["input_qparams"] = new_qparams
98240

99241
return meta_copy
100242

101243
def call_operator(self, op, args, kwargs, meta):
102-
if op == exir_ops.edge.aten.convolution.default:
103-
groups = args[8]
104-
transposed = args[6]
105-
elif op == torch.ops.aten.conv2d.default:
106-
groups = args[6]
107-
transposed = False
108-
else:
244+
groups, transposed = DecomposeGroupedConvPass._get_groups_and_transposed(
245+
op, args
246+
)
247+
if groups is None:
109248
return super().call_operator(op, args, kwargs, meta)
110249

111-
if groups == 1 or transposed:
250+
if groups == 1:
112251
return super().call_operator(op, args, kwargs, meta)
113252

114253
input_node = args[0]
115-
if input_node.data.shape[1] == groups:
254+
if DecomposeGroupedConvPass._is_depthwise_conv(input_node, groups, transposed):
116255
# This is a depthwise convolution which is handled elsewhere
117256
return super().call_operator(op, args, kwargs, meta)
118257

119258
weight_node = args[1]
120259
bias_node = args[2]
121260

122-
input_slice_size = weight_node.data.shape[1]
123-
output_slice_size = weight_node.data.shape[0] // groups
261+
input_slice_size, output_slice_size = DecomposeGroupedConvPass._get_slice_sizes(
262+
weight_node, groups, transposed
263+
)
124264

125265
no_q_dq_meta = copy(meta)
126266
no_q_dq_meta.data = {}
127-
no_q_dq_meta.data = {}
128267

129268
slice_op, conv_op, cat_op = DecomposeGroupedConvPass._get_decomposition(op)
130269

131-
input_slices = []
132-
for i in range(groups):
133-
start_index = i * input_slice_size
134-
stop_index = (i + 1) * input_slice_size
135-
slice_args = (input_node, 1, start_index, stop_index)
136-
137-
input_slices.append(
138-
super().call_operator(
139-
slice_op, slice_args, kwargs, no_q_dq_meta, updated=True
140-
)
141-
)
142-
143-
filter_slices = []
144-
for i in range(groups):
145-
start_index = i * output_slice_size
146-
stop_index = (i + 1) * output_slice_size
147-
slice_args = (weight_node, 0, start_index, stop_index)
148-
149-
filter_slices.append(
150-
super().call_operator(
151-
slice_op, slice_args, kwargs, no_q_dq_meta, updated=True
152-
)
153-
)
154-
155-
bias_slices = []
156-
for i in range(groups):
157-
if bias_node is None:
158-
bias_slices.append(None)
159-
else:
160-
start_index = i * output_slice_size
161-
stop_index = (i + 1) * output_slice_size
162-
slice_args = (bias_node, 0, start_index, stop_index)
163-
164-
bias_slices.append(
165-
super().call_operator(
166-
slice_op, slice_args, kwargs, no_q_dq_meta, updated=True
167-
)
168-
)
270+
input_slices = self._slice_inputs(
271+
slice_op, input_node, input_slice_size, groups, no_q_dq_meta, kwargs
272+
)
273+
weight_slices = self._slice_weights(
274+
slice_op,
275+
weight_node,
276+
groups,
277+
input_slice_size,
278+
output_slice_size,
279+
transposed,
280+
no_q_dq_meta,
281+
kwargs,
282+
)
283+
bias_slices = self._slice_biases(
284+
slice_op, bias_node, groups, output_slice_size, no_q_dq_meta, kwargs
285+
)
169286

170287
output_slices = []
171288
for i, (input_slice, filter_slice, bias_slice) in enumerate(
172-
zip(input_slices, filter_slices, bias_slices)
289+
zip(input_slices, weight_slices, bias_slices)
173290
):
174291

175292
meta_copy = DecomposeGroupedConvPass._get_meta_copy(
176-
meta, i, output_slice_size
293+
meta,
294+
i,
295+
input_slice_size,
296+
output_slice_size,
297+
transposed,
177298
)
178299

179-
if op == exir_ops.edge.aten.convolution.default:
180-
conv_args = (input_slice, filter_slice, bias_slice, *args[3:8], 1)
181-
elif op == torch.ops.aten.conv2d.default:
182-
conv_args = (input_slice, filter_slice, bias_slice, *args[3:6], 1)
183-
else:
184-
raise RuntimeError("Invalid op for grouped conv decomposition")
300+
conv_args = DecomposeGroupedConvPass._build_conv_args(
301+
op, args, input_slice, filter_slice, bias_slice
302+
)
185303

186304
output_slices.append(
187305
super().call_operator(

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
266266
raise RuntimeError(
267267
"Only 2D transpose convolutions are supported in the Arm backend."
268268
)
269-
if group != 1:
270-
raise RuntimeError(
271-
"Grouped transpose convolutions are not supported in the Arm backend."
272-
)
273269
if any(d != 1 for d in dilation_list):
274270
raise RuntimeError(
275271
"Transpose convolutions with dilation are not supported in the Arm backend."

0 commit comments

Comments
 (0)