Skip to content

Commit ee0ca9c

Browse files
ssjia (SS-JIA)
authored and committed
[ET-VK][conv2d_dw] Extract depthwise dispatch into Conv2dDW.cpp with device-based tile selection
Pull Request resolved: #18293 Profiling showed depthwise conv2d is 5-15x slower on Mali GPUs vs Adreno due to register pressure from the 4x2 output tile (17 vec4 registers per thread). Benchmarking confirmed that reducing the tile to 1x1 (7 vec4 registers) gives 4-15x speedup on Mali with no regression on Adreno. This change extracts depthwise conv2d dispatch logic from Convolution.cpp into a new Conv2dDW.cpp (following the Conv2dPW.cpp pattern), and adds device-based tile size selection: b1x1 on Mali, b4x2 (current default) on Adreno. ghstack-source-id: 353940602 @exported-using-ghexport Differential Revision: [D97058158](https://our.internmc.facebook.com/intern/diff/D97058158/)
1 parent 542f6d4 commit ee0ca9c

7 files changed

Lines changed: 1158 additions & 152 deletions

File tree

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ conv2d_dw_output_tile:
2525
- NAME: conv2d_dw_output_tile_5x5_clamp
2626
OPERATOR: clamp(X, A, B)
2727
TILE_SIZE: 5
28+
- NAME: conv2d_dw_output_tile_3x3_b1x1
29+
BATCH_SIZE_X: 1
30+
BATCH_SIZE_Y: 1
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Convolution.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
13+
14+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
15+
16+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
17+
18+
namespace vkcompute {
19+
20+
//
21+
// Weight prepack
22+
//
23+
24+
// Packs depthwise conv2d weights into a channels-packed 2D texture of shape
// {4, align_up_4(OC) / 4, H * W} via the conv2d_dw_prepack_weights shader.
// NOTE(review): assumes the weight tensor has at least 4 dims (reads -4/-2/-1).
ValueRef prepack_dw_weights(ComputeGraph& graph, const ValueRef vref) {
  const auto src_sizes = graph.sizes_of(vref);

  // Output channels are padded up to a multiple of 4 so each texel (vec4)
  // holds a full group of channels.
  const int64_t oc_aligned = utils::align_up_4(utils::val_at(-4, src_sizes));
  const int64_t kernel_h = utils::val_at(-2, src_sizes);
  const int64_t kernel_w = utils::val_at(-1, src_sizes);

  const std::vector<int64_t> packed_sizes = {
      4, oc_aligned / 4, kernel_h * kernel_w};

  const ValueRef packed = graph.add_tensor(
      packed_sizes,
      graph.dtype_of(vref),
      utils::kTexture2D,
      utils::kChannelsPacked);

  std::string kernel_name{"conv2d_dw_prepack_weights"};
  add_dtype_suffix(kernel_name, graph.dtype_of(packed));
  add_dtype_suffix(kernel_name, graph.get_staging_dtype_for(vref));

  // The original (unpadded) sizes are pushed so the shader can bound-check
  // reads from the staging buffer.
  const auto src_sizes_pc = utils::make_ivec4(src_sizes, /*reverse = */ true);
  graph.prepack_nodes().emplace_back(new PrepackNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      graph.create_global_wg_size(packed),
      graph.create_local_wg_size(packed),
      vref,
      packed,
      {},
      // Specialization constants
      {graph.packed_dim_of(packed)},
      {graph.sizes_pc_of(packed),
       PushConstantDataInfo(&src_sizes_pc, sizeof(src_sizes_pc))}));

  return packed;
}
62+
63+
//
64+
// Shader selection
65+
//
66+
67+
// Builds the depthwise conv2d shader variant name from the kernel geometry,
// the stride/dilation relationship, the clamp flag, and the target device.
// On Mali GPUs the 1x1 output tile variant is chosen for 3x3 kernels to
// reduce register pressure.
std::string pick_conv2d_dw_shader(
    ComputeGraph& graph,
    const ValueRef weight_data,
    const ValueRef out,
    const bool stride_equals_dilation,
    const bool clamp_out) {
  std::string shader_name = "conv2d_dw";
  shader_name.reserve(kShaderNameReserve);

  const auto& wsizes = graph.get_tref(weight_data)->sizes;
  const auto kernel_h = wsizes.at(2);
  const auto kernel_w = wsizes.at(3);

  // "sned" = stride-not-equals-dilation specialization.
  if (!stride_equals_dilation) {
    shader_name += "_sned";
  }

  if (kernel_h == 3 && kernel_w == 3) {
    shader_name += "_output_tile_3x3";
    // Mali: the default 4x2 tile spills registers; use the 1x1 tile variant.
    if (stride_equals_dilation && graph.device_is_mali()) {
      shader_name += "_b1x1";
    }
  } else if (kernel_h == 5 && kernel_w == 5) {
    shader_name += "_output_tile_5x5";
  }

  if (clamp_out) {
    shader_name += "_clamp";
  }
  add_dtype_suffix(shader_name, graph.dtype_of(out));

  return shader_name;
}
100+
101+
//
102+
// Workgroup size
103+
//
104+
105+
// Computes the global workgroup size for a depthwise conv2d dispatch, keyed
// off substrings in the selected shader's name:
//  - base shader: fully linearized (W*H*C, 1, 1)
//  - "_sned" output_tile: flatten W*H, no per-thread output batching
//  - stride==dilation output_tile: divide W/H by the output tile (4x2 by
//    default, 1x1 for "_b1x1" variants) before flattening
utils::uvec3 conv2d_dw_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)resize_args;
  const ValueRef out = args.at(0).refs.at(0);

  const std::string& name = shader.kernel_name;
  const utils::uvec3 extents = graph->create_global_wg_size(out);

  if (name.find("_output_tile") == std::string::npos) {
    // Base conv2d_dw shader: fully linearized dispatch.
    return {extents[0] * extents[1] * extents[2], 1, 1};
  }

  if (name.find("_sned") != std::string::npos) {
    // sned output_tile shaders: no batch division, just flatten W*H.
    return {extents[0] * extents[1], extents[2], 1};
  }

  // stride==dilation output_tile shaders: each thread covers a tile_x x
  // tile_y patch of outputs, so shrink the grid accordingly.
  const bool is_b1x1 = name.find("_b1x1") != std::string::npos;
  const uint32_t tile_x = is_b1x1 ? 1u : 4u;
  const uint32_t tile_y = is_b1x1 ? 1u : 2u;

  const uint32_t groups_x = utils::div_up(extents[0], tile_x);
  const uint32_t groups_y = utils::div_up(extents[1], tile_y);
  return {groups_x * groups_y, extents[2], 1};
}
143+
144+
// Local workgroup size for all depthwise conv2d variants: a flat 64-wide
// group, matching the 1D global dispatch produced by
// conv2d_dw_global_wg_size. All inputs are intentionally ignored.
utils::uvec3 conv2d_dw_local_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const utils::uvec3& global_workgroup_size,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)graph;
  (void)shader;
  (void)global_workgroup_size;
  (void)args;
  (void)resize_args;
  return {64u, 1u, 1u};
}
157+
158+
//
159+
// Dispatch node
160+
//
161+
162+
// Shader parameter payload for depthwise conv2d: the dilation-expanded
// kernel footprint ("overlay") and the 4-aligned input group size. Passed
// either as a push constant or via a params buffer (see add_conv2d_dw_node).
struct Conv2dDWParams final {
  utils::ivec2 overlay_region;
  int in_group_size;
};

// Output clamp bounds; consumed by the *_clamp shader variants. For
// non-clamping variants the values are passed but presumably unused by the
// shader — TODO confirm against the GLSL.
struct OutputParams final {
  float out_min;
  float out_max;
};
171+
172+
// Records a DynamicDispatchNode that runs the selected depthwise conv2d
// shader. output_tile variants receive their parameters as push constants;
// the base shader receives them as UBO param buffers instead. The push
// constant entries below are order-sensitive: they must match the layout
// declared in the corresponding GLSL shader.
void add_conv2d_dw_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef arg_weight,
    const ValueRef arg_bias,
    const ValueRef weight_data,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef dilation,
    const ValueRef out,
    const std::string& kernel_name,
    const Kernel2dParams& kernel_params,
    const Conv2dDWParams& extra_params,
    const OutputParams& out_params) {
  vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);

  vkapi::ParamsBindList param_buffers;
  std::vector<PushConstantDataInfo> push_constants;

  // Variant detection mirrors conv2d_dw_global_wg_size: the shader name
  // encodes which parameter-passing scheme the kernel expects.
  const bool uses_output_tile =
      kernel_name.find("_output_tile") != std::string::npos;

  if (uses_output_tile) {
    // Pack (kernel size, stride) and (padding, dilation) pairs into ivec4s
    // so each push constant entry is a single 16-byte vector.
    const utils::ivec4 kernel_param_size_stride = {
        kernel_params.kernel_size[0],
        kernel_params.kernel_size[1],
        kernel_params.stride[0],
        kernel_params.stride[1]};

    const utils::ivec4 kernel_param_pad_dial = {
        kernel_params.padding[0],
        kernel_params.padding[1],
        kernel_params.dilation[0],
        kernel_params.dilation[1]};

    push_constants = {
        graph.logical_limits_pc_of(out),
        graph.sizes_pc_of(in),
        PushConstantDataInfo(
            &kernel_param_size_stride, sizeof(kernel_param_size_stride)),
        PushConstantDataInfo(
            &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)),
        // extra_params is smaller than ivec4; the third argument pads it to
        // ivec4 size to keep the shader-side layout aligned.
        PushConstantDataInfo(
            &extra_params, sizeof(extra_params), sizeof(utils::ivec4)),
        PushConstantDataInfo(&out_params, sizeof(out_params)),
    };
  } else {
    // Base shader path: same logical parameters, delivered as UBOs.
    param_buffers = {
        graph.logical_limits_ubo(out),
        graph.sizes_ubo(in),
        graph.create_params_buffer(kernel_params),
        graph.create_params_buffer(extra_params),
        graph.create_params_buffer(out_params),
    };
  }

  // transposed is always false for depthwise, output_padding unused
  ValueRef transposed_ref = graph.add_scalar(false);
  ValueRef output_padding = graph.add_none();

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      shader,
      conv2d_dw_global_wg_size,
      conv2d_dw_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
      // Shader params buffers
      param_buffers,
      // Push Constants
      push_constants,
      // Specialization Constants
      {},
      // Resize Args (consumed by resize_conv2d_node on dynamic shape updates)
      {weight_data, stride, padding, dilation, transposed_ref, output_padding},
      // Resizing Logic
      resize_conv2d_node));
}
250+
251+
//
252+
// High level operator impl
253+
//
254+
255+
// High-level depthwise conv2d: prepacks weights/biases, validates the
// arguments, derives the kernel/tile parameters, selects the shader
// variant, and records the dispatch node.
void conv2d_dw_impl(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef weight_data,
    const ValueRef bias,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef dilation,
    const ValueRef out,
    const bool clamp_out,
    const float out_min_val,
    const float out_max_val) {
  const ValueRef packed_weight = prepack_dw_weights(graph, weight_data);
  const ValueRef packed_bias = prepack_biases(
      graph,
      bias,
      weight_data,
      /* transposed = */ false,
      /* storage_type = */ utils::kTexture2D,
      /* memory_layout = */ utils::kWidthPacked);

  if (graph.sizes_of(in).at(0) > 1) {
    VK_THROW("conv2d: input batch size > 1 is not supported yet!");
  }

  check_conv_args(graph, in, out);

  const Kernel2dParams kernel_params = create_kernel2d_params(
      graph,
      weight_data,
      /*kernel_size_only = */ false,
      stride,
      padding,
      dilation);

  const bool stride_equals_dilation =
      kernel_params.stride[0] == kernel_params.dilation[0] &&
      kernel_params.stride[1] == kernel_params.dilation[1];

  // Effective kernel footprint once dilation gaps are accounted for.
  const auto overlay = utils::make_ivec2({
      kernel_params.kernel_size[0] +
          (kernel_params.kernel_size[0] - 1) * (kernel_params.dilation[0] - 1),
      kernel_params.kernel_size[1] +
          (kernel_params.kernel_size[1] - 1) * (kernel_params.dilation[1] - 1),
  });

  const auto weight_sizes = graph.sizes_of(weight_data);
  const int32_t group_size =
      utils::safe_downcast<int32_t>(utils::align_up_4(weight_sizes.at(1)));

  const Conv2dDWParams dw_params = {overlay, group_size};
  const OutputParams clamp_params = {out_min_val, out_max_val};

  const std::string shader_name = pick_conv2d_dw_shader(
      graph, weight_data, out, stride_equals_dilation, clamp_out);

  add_conv2d_dw_node(
      graph,
      in,
      packed_weight,
      packed_bias,
      weight_data,
      stride,
      padding,
      dilation,
      out,
      shader_name,
      kernel_params,
      dw_params,
      clamp_params);
}
326+
327+
} // namespace vkcompute

0 commit comments

Comments
 (0)