Skip to content

Commit 918e92b

Browse files
pytorchbotssjia
andauthored
[ET-VK][ez] Fix duplicate placeholder target in create_constant_placeholder (pytorch#18031)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: pytorch#18013 by @SS-JIA ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/460/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/460/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/460/orig Differential Revision: [D95807071](https://our.internmc.facebook.com/intern/diff/D95807071/) @diff-train-skip-merge --------- Co-authored-by: ssjia <ssjia@devvm26340.ftw0.facebook.com>
1 parent 518daa8 commit 918e92b

22 files changed

Lines changed: 1911 additions & 47 deletions

backends/transforms/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ def create_constant_placeholder(
111111

112112
target = name
113113

114+
# If a placeholder with this target already exists, return it to avoid
115+
# duplicate parameter names in the generated function signature which would
116+
# cause a SyntaxError on recompile. This can happen when multiple pattern
117+
# replacements independently create placeholders for a shared weight.
118+
if name in exp_program.state_dict or name in exp_program.constants:
119+
for n in graph.nodes:
120+
if n.op == "placeholder" and n.target == name:
121+
return n
122+
114123
# Add data to state_dict/ constants
115124
match kind:
116125
case InputKind.PARAMETER:

backends/vulkan/custom_ops_lib.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,105 @@ def q8ta_conv2d_dw(
685685
lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd")
686686
conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name)
687687

688+
689+
def q8ta_conv2d_transposed(
690+
x: torch.Tensor,
691+
input_scale: float,
692+
input_zero_point: int,
693+
weights: torch.Tensor,
694+
weight_sums: torch.Tensor,
695+
weight_scales: torch.Tensor,
696+
output_scale: float,
697+
output_zero_point: int,
698+
bias: Optional[torch.Tensor],
699+
kernel_size: list,
700+
stride: list,
701+
padding: list,
702+
output_padding: list,
703+
dilation: list,
704+
groups: int,
705+
activation: str,
706+
):
707+
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
708+
x, input_scale, input_zero_point, -128, 127, x.dtype
709+
)
710+
711+
OC = weights.shape[0]
712+
IC_per_group = int(x.shape[1] / groups)
713+
K_h, K_w = kernel_size[0], kernel_size[1]
714+
715+
orig_weight_K_dim = K_h * K_w * IC_per_group
716+
if weights.shape[-1] > orig_weight_K_dim:
717+
weights = weights[:, :orig_weight_K_dim]
718+
719+
if weight_scales.shape[0] > OC:
720+
weight_scales = weight_scales[:OC]
721+
if bias is not None:
722+
bias = bias[:OC]
723+
724+
# Reshape to (OC, IC_per_group, K_h, K_w) then transpose to
725+
# (IC_per_group * groups, OC_per_group, K_h, K_w) for conv_transpose2d
726+
weights = weights.view(OC, IC_per_group, K_h, K_w)
727+
OC_per_group = OC // groups
728+
weights = (
729+
weights.view(groups, OC_per_group, IC_per_group, K_h, K_w)
730+
.permute(0, 2, 1, 3, 4)
731+
.contiguous()
732+
.view(IC_per_group * groups, OC_per_group, K_h, K_w)
733+
)
734+
735+
weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
736+
# Dequantize per OC channel. For transposed weight (IC, OC_per_group, KH, KW),
737+
# OC is at axis=1.
738+
weights = torch.ops.quantized_decomposed.dequantize_per_channel(
739+
weights,
740+
weight_scales[:OC_per_group].repeat(groups) if groups > 1 else weight_scales,
741+
weight_zeros[:OC_per_group].repeat(groups) if groups > 1 else weight_zeros,
742+
1,
743+
-127,
744+
127,
745+
torch.int8,
746+
)
747+
748+
out = torch.nn.functional.conv_transpose2d(
749+
x, weights, bias, stride, padding, output_padding, groups, dilation
750+
)
751+
752+
if activation == "relu":
753+
out = torch.nn.functional.relu(out)
754+
755+
out = torch.ops.quantized_decomposed.quantize_per_tensor(
756+
out, output_scale, output_zero_point, -128, 127, torch.int8
757+
)
758+
759+
return out
760+
761+
762+
name = "q8ta_conv2d_transposed"
763+
lib.define(
764+
f"""
765+
{name}(
766+
Tensor x,
767+
float input_scale,
768+
int input_zero_point,
769+
Tensor weights,
770+
Tensor weight_sums,
771+
Tensor weight_scales,
772+
float output_scale,
773+
int output_zero_point,
774+
Tensor? bias,
775+
SymInt[] kernel_size,
776+
SymInt[] stride,
777+
SymInt[] padding,
778+
SymInt[] output_padding,
779+
SymInt[] dilation,
780+
SymInt groups,
781+
str activation) -> Tensor
782+
"""
783+
)
784+
lib.impl(name, q8ta_conv2d_transposed, "CompositeExplicitAutograd")
785+
q8ta_conv2d_transposed_op = getattr(getattr(torch.ops, namespace), name)
786+
688787
######################
689788
## apply_rotary_emb ##
690789
######################

backends/vulkan/op_registry.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ def register_quantize_per_tensor():
482482
outputs_storage=[
483483
utils.PACKED_INT8_BUFFER,
484484
],
485+
supports_highdim=True,
485486
)
486487

487488

@@ -499,6 +500,7 @@ def register_dequantize_per_tensor():
499500
outputs_storage=[
500501
utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
501502
],
503+
supports_highdim=True,
502504
)
503505

504506

@@ -863,6 +865,39 @@ def register_q8ta_conv2d_ops():
863865
)
864866

865867

868+
@update_features(
869+
[
870+
exir_ops.edge.et_vk.q8ta_conv2d_transposed.default,
871+
]
872+
)
873+
def register_q8ta_conv2d_transposed_op():
874+
return OpFeatures(
875+
inputs_storage=[
876+
utils.PACKED_INT8_CONV2D_BUFFER, # input
877+
utils.NO_STORAGE, # input_scale (non tensor)
878+
utils.NO_STORAGE, # input_zero_point (non tensor)
879+
utils.NO_STORAGE, # weight (prepacked)
880+
utils.NO_STORAGE, # weight_sums (prepacked)
881+
utils.NO_STORAGE, # weight_scales (prepacked)
882+
utils.NO_STORAGE, # output_scale (non tensor)
883+
utils.NO_STORAGE, # output_zero_point (non tensor)
884+
utils.NO_STORAGE, # bias (prepacked)
885+
utils.NO_STORAGE, # kernel_size (non tensor)
886+
utils.NO_STORAGE, # stride (non tensor)
887+
utils.NO_STORAGE, # padding (non tensor)
888+
utils.NO_STORAGE, # output_padding (non tensor)
889+
utils.NO_STORAGE, # dilation (non tensor)
890+
utils.NO_STORAGE, # groups (non tensor)
891+
utils.NO_STORAGE, # activation (non tensor)
892+
],
893+
outputs_storage=[
894+
utils.PACKED_INT8_CHANNELS_PACKED_BUFFER,
895+
],
896+
supports_resize=False,
897+
supports_prepacking=True,
898+
)
899+
900+
866901
# =============================================================================
867902
# Q8taLinear.cpp
868903
# =============================================================================

backends/vulkan/patterns/quantized_convolution.py

Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Optional
7+
from typing import cast, List, Optional
88

99
import executorch.backends.vulkan.utils as utils
1010

@@ -33,12 +33,27 @@ def __init__(self, conv_node: torch.fx.Node) -> None:
3333
self.match_found = False
3434
self.all_nodes = [self.anchor_node]
3535

36+
# Determine if this is a transposed convolution
37+
self.transposed = False
38+
self.output_padding = [0, 0]
39+
if conv_node.target == exir_ops.edge.aten.convolution.default:
40+
transposed_flag = conv_node.args[6] if len(conv_node.args) > 6 else False
41+
if transposed_flag:
42+
self.transposed = True
43+
self.output_padding = (
44+
cast(List[int], conv_node.args[7]) if len(conv_node.args) > 7 else [0, 0]
45+
)
46+
3647
# Extract convolution parameters
3748
self.stride = conv_node.args[3] if len(conv_node.args) > 3 else [1, 1]
3849
self.padding = conv_node.args[4] if len(conv_node.args) > 4 else [0, 0]
3950
self.dilation = conv_node.args[5] if len(conv_node.args) > 5 else [1, 1]
4051
self.groups = conv_node.args[8] if len(conv_node.args) > 8 else 1
4152

53+
# Transposed conv only supported with dilation=[1,1]
54+
if self.transposed and cast(List[int], self.dilation) != [1, 1]:
55+
return
56+
4257
const_node, arg_chain = utils.trace_args_until_placeholder(
4358
self.anchor_node.args[1]
4459
)
@@ -60,6 +75,16 @@ def __init__(self, conv_node: torch.fx.Node) -> None:
6075
self.dequantize_weight_node = dequantize_weight_node
6176
self.all_nodes.extend(arg_chain)
6277

78+
# For transposed conv, verify per-channel quantization is on the OC dimension.
79+
# Transposed weight shape is (IC, OC_per_group, KH, KW), so per-OC quantization
80+
# should be on axis=1. If axis=0, that's per-IC which is not supported.
81+
if self.transposed and utils.is_dequant_per_channel_node(
82+
self.dequantize_weight_node
83+
):
84+
quant_axis = self.dequantize_weight_node.args[3]
85+
if quant_axis != 1:
86+
return
87+
6388
# Identify weight quantization parameter nodes
6489
self.weight_scales_node, arg_chain = utils.trace_args_until_placeholder(
6590
self.dequantize_weight_node.args[1]
@@ -177,9 +202,30 @@ def make_q8ta_conv2d_custom_op(
177202
bias_tensor = get_param_tensor(ep, match.bias_node)
178203
assert bias_tensor is not None
179204

180-
OC, IC_per_group, H, W = weight_tensor.shape
205+
if match.transposed:
206+
# Transposed conv weight shape: (IC, OC_per_group, H, W)
207+
IC, OC_per_group, H, W = weight_tensor.shape
208+
OC = OC_per_group * match.groups
209+
IC_per_group = IC // match.groups
210+
# Reshape to (OC, H*W*IC_per_group) matrix format for Im2Col-based
211+
# transposed convolution.
212+
# (IC, OC_per_group, H, W) ->
213+
# (groups, IC_per_group, OC_per_group, H, W) ->
214+
# (groups, OC_per_group, H, W, IC_per_group) ->
215+
# (OC, H*W*IC_per_group)
216+
weight_tensor = (
217+
weight_tensor.reshape(match.groups, IC_per_group, OC_per_group, H, W)
218+
.permute(0, 2, 3, 4, 1)
219+
.contiguous()
220+
.reshape(OC, H * W * IC_per_group)
221+
.contiguous()
222+
)
223+
else:
224+
OC, IC_per_group, H, W = weight_tensor.shape
181225

182-
is_depthwise_conv = IC_per_group == 1 and match.groups == OC
226+
is_depthwise_conv = (
227+
not match.transposed and IC_per_group == 1 and match.groups == OC
228+
)
183229

184230
if is_depthwise_conv:
185231
assert OC % 4 == 0, "depthwise conv requires that OC is divisible by 4"
@@ -188,7 +234,7 @@ def make_q8ta_conv2d_custom_op(
188234
weight_tensor = (
189235
weight_tensor.permute(2, 3, 1, 0).contiguous().view(H, W, OC).contiguous()
190236
)
191-
else:
237+
elif not match.transposed:
192238
# Reshape weight tensor from (OC, IC_per_group, H, W) to (OC, H * W * IC_per_group)
193239
# (i.e. matrix format). This prepares the weights for Im2Col-based convolution.
194240
weight_tensor = (
@@ -257,32 +303,41 @@ def make_q8ta_conv2d_custom_op(
257303
)
258304

259305
with graph_module.graph.inserting_before(match.output_node):
260-
op_target = exir_ops.edge.et_vk.q8ta_conv2d.default
261-
if is_depthwise_conv:
306+
if match.transposed:
307+
op_target = exir_ops.edge.et_vk.q8ta_conv2d_transposed.default
308+
elif is_depthwise_conv:
262309
op_target = exir_ops.edge.et_vk.q8ta_conv2d_dw.default
263310
elif is_pointwise_conv:
264311
op_target = exir_ops.edge.et_vk.q8ta_conv2d_pw.default
312+
else:
313+
op_target = exir_ops.edge.et_vk.q8ta_conv2d.default
314+
315+
op_args = (
316+
match.quantize_input_node,
317+
match.input_scales_node,
318+
match.input_zeros_node,
319+
match.weight_node,
320+
weight_sums_node,
321+
match.weight_scales_node,
322+
match.output_scales_node,
323+
match.output_zeros_node,
324+
match.bias_node,
325+
[H, W],
326+
match.stride,
327+
match.padding,
328+
)
329+
if match.transposed:
330+
op_args = op_args + (match.output_padding,)
331+
op_args = op_args + (
332+
match.dilation,
333+
match.groups,
334+
"relu" if match.relu_node is not None else "none",
335+
)
265336

266337
qconv_node = graph_module.graph.create_node(
267338
"call_function",
268339
op_target,
269-
args=(
270-
match.quantize_input_node,
271-
match.input_scales_node,
272-
match.input_zeros_node,
273-
match.weight_node,
274-
weight_sums_node,
275-
match.weight_scales_node,
276-
match.output_scales_node,
277-
match.output_zeros_node,
278-
match.bias_node, # Add bias after weight_scales
279-
[H, W], # Pass kernel size information before stride
280-
match.stride,
281-
match.padding,
282-
match.dilation,
283-
match.groups,
284-
"relu" if match.relu_node is not None else "none",
285-
),
340+
args=op_args,
286341
)
287342

288343
qconv_node.meta["val"] = match.output_node.meta["val"]

backends/vulkan/patterns/quantized_linear.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901
9090
# Identify output node
9191
self.output_node = self.anchor_node
9292

93+
# bmm with batch dim > 1 is not supported
94+
is_bmm = self.anchor_node.target == exir_ops.edge.aten.bmm.default
95+
if is_bmm and self.output_node.meta["val"].shape[0] != 1:
96+
return
97+
9398
# Identify primary input node of the anchor. Due to decomposition of aten.linear
9499
# there may be a view_copy node between the original input tensor to the linear
95100
# op and the actual linear op node.
@@ -174,12 +179,21 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901
174179

175180
# Check if the output is also quantized (q → dq → linear → q pattern)
176181
# Also handle fused linear+relu (q → dq → linear → relu → q pattern)
182+
# Due to decomposition of aten.linear for 3D+ inputs, there may be a
183+
# view_copy between the mm output and the quantize node.
177184
self.quantize_output_node = None
178185
self.output_scales_node = None
179186
self.output_zeros_node = None
180187
self.relu_node = None
188+
self.output_view_copy_node = None
181189
if len(self.output_node.users) == 1:
182190
cur_node = list(self.output_node.users)[0]
191+
# Skip potential view_copy between linear and output quantize
192+
if utils.is_view_copy_node(cur_node) and len(cur_node.users) == 1:
193+
self.output_view_copy_node = cur_node
194+
self.all_nodes.append(self.output_view_copy_node)
195+
self.output_node = self.output_view_copy_node
196+
cur_node = list(cur_node.users)[0]
183197
if cur_node.target == exir_ops.edge.aten.relu.default:
184198
self.relu_node = cur_node
185199
if len(cur_node.users) == 1:
@@ -259,6 +273,7 @@ def is_input_dynamic_perchannel_quantized(self) -> bool:
259273
exir_ops.edge.aten.linear.default,
260274
exir_ops.edge.aten.mm.default,
261275
exir_ops.edge.aten.addmm.default,
276+
exir_ops.edge.aten.bmm.default,
262277
}
263278

264279

0 commit comments

Comments
 (0)