44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from typing import Optional
7+ from typing import cast , List , Optional
88
99import executorch .backends .vulkan .utils as utils
1010
@@ -33,12 +33,27 @@ def __init__(self, conv_node: torch.fx.Node) -> None:
3333 self .match_found = False
3434 self .all_nodes = [self .anchor_node ]
3535
36+ # Determine if this is a transposed convolution
37+ self .transposed = False
38+ self .output_padding = [0 , 0 ]
39+ if conv_node .target == exir_ops .edge .aten .convolution .default :
40+ transposed_flag = conv_node .args [6 ] if len (conv_node .args ) > 6 else False
41+ if transposed_flag :
42+ self .transposed = True
43+ self .output_padding = (
44+ cast (List [int ], conv_node .args [7 ]) if len (conv_node .args ) > 7 else [0 , 0 ]
45+ )
46+
3647 # Extract convolution parameters
3748 self .stride = conv_node .args [3 ] if len (conv_node .args ) > 3 else [1 , 1 ]
3849 self .padding = conv_node .args [4 ] if len (conv_node .args ) > 4 else [0 , 0 ]
3950 self .dilation = conv_node .args [5 ] if len (conv_node .args ) > 5 else [1 , 1 ]
4051 self .groups = conv_node .args [8 ] if len (conv_node .args ) > 8 else 1
4152
53+ # Transposed conv only supported with dilation=[1,1]
54+ if self .transposed and cast (List [int ], self .dilation ) != [1 , 1 ]:
55+ return
56+
4257 const_node , arg_chain = utils .trace_args_until_placeholder (
4358 self .anchor_node .args [1 ]
4459 )
@@ -60,6 +75,16 @@ def __init__(self, conv_node: torch.fx.Node) -> None:
6075 self .dequantize_weight_node = dequantize_weight_node
6176 self .all_nodes .extend (arg_chain )
6277
78+ # For transposed conv, verify per-channel quantization is on the OC dimension.
79+ # Transposed weight shape is (IC, OC_per_group, KH, KW), so per-OC quantization
80+ # should be on axis=1. If axis=0, that's per-IC which is not supported.
81+ if self .transposed and utils .is_dequant_per_channel_node (
82+ self .dequantize_weight_node
83+ ):
84+ quant_axis = self .dequantize_weight_node .args [3 ]
85+ if quant_axis != 1 :
86+ return
87+
6388 # Identify weight quantization parameter nodes
6489 self .weight_scales_node , arg_chain = utils .trace_args_until_placeholder (
6590 self .dequantize_weight_node .args [1 ]
@@ -177,9 +202,30 @@ def make_q8ta_conv2d_custom_op(
177202 bias_tensor = get_param_tensor (ep , match .bias_node )
178203 assert bias_tensor is not None
179204
180- OC , IC_per_group , H , W = weight_tensor .shape
205+ if match .transposed :
206+ # Transposed conv weight shape: (IC, OC_per_group, H, W)
207+ IC , OC_per_group , H , W = weight_tensor .shape
208+ OC = OC_per_group * match .groups
209+ IC_per_group = IC // match .groups
210+ # Reshape to (OC, H*W*IC_per_group) matrix format for Im2Col-based
211+ # transposed convolution.
212+ # (IC, OC_per_group, H, W) ->
213+ # (groups, IC_per_group, OC_per_group, H, W) ->
214+ # (groups, OC_per_group, H, W, IC_per_group) ->
215+ # (OC, H*W*IC_per_group)
216+ weight_tensor = (
217+ weight_tensor .reshape (match .groups , IC_per_group , OC_per_group , H , W )
218+ .permute (0 , 2 , 3 , 4 , 1 )
219+ .contiguous ()
220+ .reshape (OC , H * W * IC_per_group )
221+ .contiguous ()
222+ )
223+ else :
224+ OC , IC_per_group , H , W = weight_tensor .shape
181225
182- is_depthwise_conv = IC_per_group == 1 and match .groups == OC
226+ is_depthwise_conv = (
227+ not match .transposed and IC_per_group == 1 and match .groups == OC
228+ )
183229
184230 if is_depthwise_conv :
185231 assert OC % 4 == 0 , "depthwise conv requires that OC is divisible by 4"
@@ -188,7 +234,7 @@ def make_q8ta_conv2d_custom_op(
188234 weight_tensor = (
189235 weight_tensor .permute (2 , 3 , 1 , 0 ).contiguous ().view (H , W , OC ).contiguous ()
190236 )
191- else :
237+ elif not match . transposed :
192238 # Reshape weight tensor from (OC, IC_per_group, H, W) to (OC, H * W * IC_per_group)
193239 # (i.e. matrix format). This prepares the weights for Im2Col-based convolution.
194240 weight_tensor = (
@@ -257,32 +303,41 @@ def make_q8ta_conv2d_custom_op(
257303 )
258304
259305 with graph_module .graph .inserting_before (match .output_node ):
260- op_target = exir_ops .edge .et_vk .q8ta_conv2d .default
261- if is_depthwise_conv :
306+ if match .transposed :
307+ op_target = exir_ops .edge .et_vk .q8ta_conv2d_transposed .default
308+ elif is_depthwise_conv :
262309 op_target = exir_ops .edge .et_vk .q8ta_conv2d_dw .default
263310 elif is_pointwise_conv :
264311 op_target = exir_ops .edge .et_vk .q8ta_conv2d_pw .default
312+ else :
313+ op_target = exir_ops .edge .et_vk .q8ta_conv2d .default
314+
315+ op_args = (
316+ match .quantize_input_node ,
317+ match .input_scales_node ,
318+ match .input_zeros_node ,
319+ match .weight_node ,
320+ weight_sums_node ,
321+ match .weight_scales_node ,
322+ match .output_scales_node ,
323+ match .output_zeros_node ,
324+ match .bias_node ,
325+ [H , W ],
326+ match .stride ,
327+ match .padding ,
328+ )
329+ if match .transposed :
330+ op_args = op_args + (match .output_padding ,)
331+ op_args = op_args + (
332+ match .dilation ,
333+ match .groups ,
334+ "relu" if match .relu_node is not None else "none" ,
335+ )
265336
266337 qconv_node = graph_module .graph .create_node (
267338 "call_function" ,
268339 op_target ,
269- args = (
270- match .quantize_input_node ,
271- match .input_scales_node ,
272- match .input_zeros_node ,
273- match .weight_node ,
274- weight_sums_node ,
275- match .weight_scales_node ,
276- match .output_scales_node ,
277- match .output_zeros_node ,
278- match .bias_node , # Add bias after weight_scales
279- [H , W ], # Pass kernel size information before stride
280- match .stride ,
281- match .padding ,
282- match .dilation ,
283- match .groups ,
284- "relu" if match .relu_node is not None else "none" ,
285- ),
340+ args = op_args ,
286341 )
287342
288343 qconv_node .meta ["val" ] = match .output_node .meta ["val" ]
0 commit comments