44# LICENSE file in the root directory of this source tree.
55
66from copy import copy
7- from typing import Set , Type
7+ from typing import Literal , Protocol , Set , Type , TypeGuard
88
99import torch
1010from executorch .backends .arm ._passes .arm_pass import ArmPass
1414from executorch .exir .pass_base import ExportPass
1515
1616
class _PerChannelQuantArgs(Protocol):
    """Structural type for quantization arguments known to be per-channel.

    ``per_channel`` is pinned to ``Literal[True]`` so that a ``TypeGuard``
    returning this protocol narrows a quantization-argument object to the
    per-channel case, where ``scale`` and ``zp`` are lists.
    """

    scale: list[float]  # list of scales; sliced by index range when split
    zp: list[int]  # list of zero points; sliced with the same index range
    qmin: int
    qmax: int
    dtype: torch.dtype
    axis: int
    per_channel: Literal[True]
25+
26+
1727class DecomposeGroupedConvPass (ArmPass ):
1828 """Splits a grouped convolution which is not supported by TOSA into multiple
1929 convolutions using slice->conv->cat.
@@ -47,6 +57,12 @@ def _get_decomposition(op):
4757 exir_ops .edge .aten .convolution .default ,
4858 exir_ops .edge .aten .cat .default ,
4959 )
60+ case torch .ops .aten .conv_transpose2d .input :
61+ return (
62+ torch .ops .aten .slice_copy .Tensor ,
63+ torch .ops .aten .conv_transpose2d .input ,
64+ torch .ops .aten .cat .default ,
65+ )
5066 case torch .ops .aten .conv2d .default :
5167 return (
5268 torch .ops .aten .slice_copy .Tensor ,
@@ -57,131 +73,233 @@ def _get_decomposition(op):
5773 raise RuntimeError ("Invalid op for grouped conv decomposition" )
5874
5975 @staticmethod
60- def _split_per_channel_qparams (qarg , index , output_slice_size ):
61- if qarg is not None and qarg .per_channel :
62- start_index = index * output_slice_size
63- stop_index = (index + 1 ) * output_slice_size
64- return QuantArgs (
65- scale = qarg .scale [start_index :stop_index ],
66- zp = qarg .zp [start_index :stop_index ],
67- qmin = qarg .qmin ,
68- qmax = qarg .qmax ,
69- dtype = qarg .dtype ,
70- axis = qarg .axis ,
71- per_channel = qarg .per_channel ,
def _get_groups_and_transposed(op, args):
    """Extract the group count and transposed flag for a convolution op.

    Args:
        op: The operator being visited.
        args: Positional arguments of the operator call.

    Returns:
        A ``(groups, transposed)`` tuple. For the edge ``convolution`` op
        both values are read from ``args`` (indices 8 and 6). For
        ``conv_transpose2d.input`` and ``conv2d.default`` the group count is
        ``args[6]`` and the flag is the constant ``True`` / ``False``.
        ``(None, None)`` is returned for any other op.
    """
    if op == exir_ops.edge.aten.convolution.default:
        return args[8], args[6]
    if op == torch.ops.aten.conv_transpose2d.input:
        return args[6], True
    if op == torch.ops.aten.conv2d.default:
        return args[6], False
    # Not a convolution this pass knows about.
    return None, None
84+
85+ @staticmethod
def _is_depthwise_conv(input_node, groups, transposed):
    """Tell whether the convolution is a non-transposed depthwise one.

    Args:
        input_node: Convolution input; ``input_node.data.shape[1]`` is
            compared against ``groups``.
        groups: Number of convolution groups.
        transposed: Whether the convolution is transposed.

    Returns:
        ``False`` for every transposed convolution, otherwise whether
        dimension 1 of the input equals ``groups``.
    """
    if transposed:
        return False
    return input_node.data.shape[1] == groups
88+
89+ @staticmethod
def _get_slice_sizes(weight_node, groups, transposed):
    """Compute the per-group input and output slice sizes from the weight.

    Args:
        weight_node: Convolution weight; only ``weight_node.data.shape`` is
            read.
        groups: Number of convolution groups.
        transposed: Whether the convolution is transposed.

    Returns:
        ``(input_slice_size, output_slice_size)``. For a transposed
        convolution the input size is ``shape[0] // groups`` and the output
        size is ``shape[1]``; otherwise the input size is ``shape[1]`` and
        the output size is ``shape[0] // groups``.
    """
    weight_shape = weight_node.data.shape
    if transposed:
        return weight_shape[0] // groups, weight_shape[1]
    return weight_shape[1], weight_shape[0] // groups
98+
def _slice_inputs(
    self, slice_op, input_node, input_slice_size, groups, meta, kwargs
):
    """Emit one slice of the input per group.

    Each slice covers ``input_slice_size`` consecutive indices along
    dimension 1 of ``input_node`` (presumably the channel dimension -
    confirm against the slice op's signature).

    Args:
        slice_op: Slice operator to emit.
        input_node: Value to slice.
        input_slice_size: Extent of each slice along dimension 1.
        groups: Number of slices to produce.
        meta: Metadata attached to every emitted slice.
        kwargs: Keyword arguments forwarded to the parent ``call_operator``.

    Returns:
        A list with ``groups`` slice results, ordered by group index.
    """
    input_slices = []
    # A plain loop rather than a comprehension: zero-argument ``super()``
    # is only reliable directly in the method body on older interpreters.
    for group in range(groups):
        start_index = group * input_slice_size
        stop_index = start_index + input_slice_size
        slice_args = (input_node, 1, start_index, stop_index)
        input_slices.append(
            super().call_operator(slice_op, slice_args, kwargs, meta, updated=True)
        )
    return input_slices
111+
def _slice_weights(
    self,
    slice_op,
    weight_node,
    groups,
    input_slice_size,
    output_slice_size,
    transposed,
    meta,
    kwargs,
):
    """Emit one slice of the weight per group along dimension 0.

    Args:
        slice_op: Slice operator to emit.
        weight_node: Weight value to slice.
        groups: Number of slices to produce.
        input_slice_size: Per-group extent used when ``transposed`` is true.
        output_slice_size: Per-group extent used otherwise.
        transposed: Selects which of the two sizes is the step along
            dimension 0.
        meta: Metadata attached to every emitted slice.
        kwargs: Keyword arguments forwarded to the parent ``call_operator``.

    Returns:
        A list with ``groups`` slice results, ordered by group index.
    """
    # The step along dimension 0 does not depend on the group index, so it
    # is chosen once instead of once per iteration.
    slice_size = input_slice_size if transposed else output_slice_size

    weight_slices = []
    for group in range(groups):
        start_index = group * slice_size
        stop_index = start_index + slice_size
        slice_args = (weight_node, 0, start_index, stop_index)
        weight_slices.append(
            super().call_operator(slice_op, slice_args, kwargs, meta, updated=True)
        )
    return weight_slices
136+
def _slice_biases(
    self, slice_op, bias_node, groups, output_slice_size, meta, kwargs
):
    """Emit one slice of the bias per group, or ``None`` placeholders.

    Args:
        slice_op: Slice operator to emit.
        bias_node: Bias value to slice along dimension 0, or ``None`` when
            the convolution has no bias.
        groups: Number of entries to produce.
        output_slice_size: Extent of each bias slice.
        meta: Metadata attached to every emitted slice.
        kwargs: Keyword arguments forwarded to the parent ``call_operator``.

    Returns:
        A list with ``groups`` entries: slice results ordered by group
        index, or all ``None`` when ``bias_node`` is ``None``.
    """
    # No bias: nothing to emit, every group simply gets ``None``.
    if bias_node is None:
        return [None] * groups

    bias_slices = []
    for group in range(groups):
        start_index = group * output_slice_size
        stop_index = start_index + output_slice_size
        slice_args = (bias_node, 0, start_index, stop_index)
        bias_slices.append(
            super().call_operator(slice_op, slice_args, kwargs, meta, updated=True)
        )
    return bias_slices
74152
75153 @staticmethod
76- def _get_meta_copy (meta , i , output_slice_size ):
def _build_conv_args(op, args, input_slice, filter_slice, bias_slice):
    """Build the argument tuple for one per-group convolution.

    The original input, weight and bias are replaced by their per-group
    slices and the group count is set to ``1``; the remaining arguments are
    copied from ``args``.

    Args:
        op: The convolution operator being decomposed.
        args: Positional arguments of the original grouped convolution.
        input_slice: Input slice for this group.
        filter_slice: Weight slice for this group.
        bias_slice: Bias slice for this group, or ``None``.

    Returns:
        The positional-argument tuple for the per-group convolution.

    Raises:
        RuntimeError: If ``op`` is not one of the supported convolutions.
    """
    if op == exir_ops.edge.aten.convolution.default:
        # args[3:8] are forwarded unchanged; groups comes last.
        return (input_slice, filter_slice, bias_slice, *args[3:8], 1)
    if op == torch.ops.aten.conv_transpose2d.input:
        # groups sits at index 6, between args[5] and args[7].
        return (
            input_slice,
            filter_slice,
            bias_slice,
            args[3],
            args[4],
            args[5],
            1,
            args[7],
        )
    if op == torch.ops.aten.conv2d.default:
        # args[3:6] are forwarded unchanged; groups comes last.
        return (input_slice, filter_slice, bias_slice, *args[3:6], 1)
    raise RuntimeError("Invalid op for grouped conv decomposition")
171+
172+ @staticmethod
def _is_per_channel_qparams(
    qarg: QuantArgs | None,
) -> TypeGuard[_PerChannelQuantArgs]:
    """Type guard: is ``qarg`` present and flagged as per-channel?

    Args:
        qarg: Quantization arguments, or ``None`` when there are none.

    Returns:
        ``False`` when ``qarg`` is ``None``; otherwise the value of
        ``qarg.per_channel``.
    """
    if qarg is None:
        return False
    return qarg.per_channel
177+
178+ @staticmethod
def _split_per_channel_qparams(
    qarg: _PerChannelQuantArgs, start_index, stop_index
) -> QuantArgs:
    """Return a copy of ``qarg`` restricted to ``[start_index, stop_index)``.

    Only ``scale`` and ``zp`` are sliced; ``qmin``, ``qmax``, ``dtype``,
    ``axis`` and ``per_channel`` are carried over unchanged.

    Args:
        qarg: Per-channel quantization arguments to split.
        start_index: First index to keep (inclusive).
        stop_index: End of the kept range (exclusive).

    Returns:
        A new ``QuantArgs`` covering only the selected index range.
    """
    channels = slice(start_index, stop_index)
    return QuantArgs(
        scale=qarg.scale[channels],
        zp=qarg.zp[channels],
        qmin=qarg.qmin,
        qmax=qarg.qmax,
        dtype=qarg.dtype,
        axis=qarg.axis,
        per_channel=qarg.per_channel,
    )
191+
192+ @staticmethod
def _get_meta_copy(
    meta,
    i,
    input_slice_size,
    output_slice_size,
    transposed,
):
    """Copy ``meta`` and narrow per-channel input qparams to group ``i``.

    When ``meta.data`` has a non-empty ``"input_qparams"`` entry, the entry
    is copied and the per-channel weight (index 1) and bias (index 2, when
    present) quantization parameters are replaced by the part covering the
    output range ``[i * output_slice_size, (i + 1) * output_slice_size)``.
    Parameters that are not per-channel are left untouched.

    Args:
        meta: Metadata of the grouped convolution; must provide ``copy()``
            and a ``data`` mapping.
        i: Index of the group the copy is made for.
        input_slice_size: Currently unused; accepted so the signature
            carries both slice sizes.
        output_slice_size: Number of output entries per group.
        transposed: Whether the convolution is transposed.

    Returns:
        The copied metadata, with ``"input_qparams"`` replaced when present.

    Raises:
        RuntimeError: If the weight is per-channel quantized and the
            convolution is transposed.
    """
    meta_copy = meta.copy()

    if "input_qparams" in meta.data and len(meta.data["input_qparams"]) > 0:
        # Handle per-channel quantization by splitting quantization params
        # similarly to how activations/weights/biases are split.
        new_qparams = meta.data.get("input_qparams").copy()

        # Weight and bias qparams are both split over the same output range,
        # so the bounds are computed once.
        start_index = i * output_slice_size
        stop_index = (i + 1) * output_slice_size

        # Get quantization params of the weights and slice them.
        w_qarg = new_qparams[1]
        if DecomposeGroupedConvPass._is_per_channel_qparams(w_qarg):
            # For transpose conv, axis=1 corresponds to output channels and
            # does not align with grouped slicing.
            # Per-channel quantization on axis=0 could align here, but
            # per-channel quantization on axis 0 is very uncommon.
            if transposed:
                raise RuntimeError(
                    "Grouped transpose conv with per-channel quantization is unsupported"
                )

            new_qparams[1] = DecomposeGroupedConvPass._split_per_channel_qparams(
                w_qarg, start_index=start_index, stop_index=stop_index
            )

        # Split per-channel bias qparams to match per-group output slices.
        if len(new_qparams) > 2:
            b_qarg = new_qparams[2]
            if DecomposeGroupedConvPass._is_per_channel_qparams(b_qarg):
                new_qparams[2] = DecomposeGroupedConvPass._split_per_channel_qparams(
                    b_qarg, start_index=start_index, stop_index=stop_index
                )

        meta_copy.data["input_qparams"] = new_qparams

    return meta_copy
100242
101243 def call_operator (self , op , args , kwargs , meta ):
102- if op == exir_ops .edge .aten .convolution .default :
103- groups = args [8 ]
104- transposed = args [6 ]
105- elif op == torch .ops .aten .conv2d .default :
106- groups = args [6 ]
107- transposed = False
108- else :
244+ groups , transposed = DecomposeGroupedConvPass ._get_groups_and_transposed (
245+ op , args
246+ )
247+ if groups is None :
109248 return super ().call_operator (op , args , kwargs , meta )
110249
111- if groups == 1 or transposed :
250+ if groups == 1 :
112251 return super ().call_operator (op , args , kwargs , meta )
113252
114253 input_node = args [0 ]
115- if input_node . data . shape [ 1 ] == groups :
254+ if DecomposeGroupedConvPass . _is_depthwise_conv ( input_node , groups , transposed ) :
116255 # This is a depthwise convolution which is handled elsewhere
117256 return super ().call_operator (op , args , kwargs , meta )
118257
119258 weight_node = args [1 ]
120259 bias_node = args [2 ]
121260
122- input_slice_size = weight_node .data .shape [1 ]
123- output_slice_size = weight_node .data .shape [0 ] // groups
261+ input_slice_size , output_slice_size = DecomposeGroupedConvPass ._get_slice_sizes (
262+ weight_node , groups , transposed
263+ )
124264
125265 no_q_dq_meta = copy (meta )
126266 no_q_dq_meta .data = {}
127- no_q_dq_meta .data = {}
128267
129268 slice_op , conv_op , cat_op = DecomposeGroupedConvPass ._get_decomposition (op )
130269
131- input_slices = []
132- for i in range (groups ):
133- start_index = i * input_slice_size
134- stop_index = (i + 1 ) * input_slice_size
135- slice_args = (input_node , 1 , start_index , stop_index )
136-
137- input_slices .append (
138- super ().call_operator (
139- slice_op , slice_args , kwargs , no_q_dq_meta , updated = True
140- )
141- )
142-
143- filter_slices = []
144- for i in range (groups ):
145- start_index = i * output_slice_size
146- stop_index = (i + 1 ) * output_slice_size
147- slice_args = (weight_node , 0 , start_index , stop_index )
148-
149- filter_slices .append (
150- super ().call_operator (
151- slice_op , slice_args , kwargs , no_q_dq_meta , updated = True
152- )
153- )
154-
155- bias_slices = []
156- for i in range (groups ):
157- if bias_node is None :
158- bias_slices .append (None )
159- else :
160- start_index = i * output_slice_size
161- stop_index = (i + 1 ) * output_slice_size
162- slice_args = (bias_node , 0 , start_index , stop_index )
163-
164- bias_slices .append (
165- super ().call_operator (
166- slice_op , slice_args , kwargs , no_q_dq_meta , updated = True
167- )
168- )
270+ input_slices = self ._slice_inputs (
271+ slice_op , input_node , input_slice_size , groups , no_q_dq_meta , kwargs
272+ )
273+ weight_slices = self ._slice_weights (
274+ slice_op ,
275+ weight_node ,
276+ groups ,
277+ input_slice_size ,
278+ output_slice_size ,
279+ transposed ,
280+ no_q_dq_meta ,
281+ kwargs ,
282+ )
283+ bias_slices = self ._slice_biases (
284+ slice_op , bias_node , groups , output_slice_size , no_q_dq_meta , kwargs
285+ )
169286
170287 output_slices = []
171288 for i , (input_slice , filter_slice , bias_slice ) in enumerate (
172- zip (input_slices , filter_slices , bias_slices )
289+ zip (input_slices , weight_slices , bias_slices )
173290 ):
174291
175292 meta_copy = DecomposeGroupedConvPass ._get_meta_copy (
176- meta , i , output_slice_size
293+ meta ,
294+ i ,
295+ input_slice_size ,
296+ output_slice_size ,
297+ transposed ,
177298 )
178299
179- if op == exir_ops .edge .aten .convolution .default :
180- conv_args = (input_slice , filter_slice , bias_slice , * args [3 :8 ], 1 )
181- elif op == torch .ops .aten .conv2d .default :
182- conv_args = (input_slice , filter_slice , bias_slice , * args [3 :6 ], 1 )
183- else :
184- raise RuntimeError ("Invalid op for grouped conv decomposition" )
300+ conv_args = DecomposeGroupedConvPass ._build_conv_args (
301+ op , args , input_slice , filter_slice , bias_slice
302+ )
185303
186304 output_slices .append (
187305 super ().call_operator (
0 commit comments