Skip to content

Commit 936f6d4

Browse files
ssjiaSS-JIA
authored andcommitted
[ET-VK][conv1d] Implement height-packed pointwise conv1d operator
Pull Request resolved: pytorch#18332 Implement a new conv1d pointwise (kernel_size=1) operator using height-packed layout where channels are the packed dimension (WHCN dim 1). This enables dot-product reduction over input channels: each vec4 load gives 4 consecutive channel values, yielding 4 MACs per dot() instruction. Uses tiled computation with the FP tile infrastructure from linear/matmul (FPInputTile, FPWeightTile, FPOutTile, fp_accumulate_with_fp_weight) and 4OC×4IC blocked weight packing via pack_fp_linear_weight.glsl for cache-friendly texture2d weight reads. Adaptive tile_m selection (4/2/1 rows) based on GPU occupancy. Thread mapping: X=OC4 tiles, Y=L tiles, Z=batch. Each thread computes TILE_M×TILE_N4×4 output elements. Inner loop loads input tiles and packed weight tiles, then calls fp_accumulate_with_fp_weight for tiled FMA. Supports both buffer and texture3d storage for input/output, texture2d or buffer for packed weights, fp32/fp16, and optional bias. Registered as et_vk.conv1d_pw.default (standalone custom op for testing/benchmarking). Performance on Adreno 750 (S24): - [1,256,1024]x[512,256,1] texture f16: 908 GFLOP/s - [1,512,2048]x[256,512,1] texture f16: 865 GFLOP/s - [1,128,4096]x[128,128,1] texture f16: 781 GFLOP/s - [1,256,1024]x[512,256,1] buffer f16: 491 GFLOP/s ghstack-source-id: 358903218 @exported-using-ghexport Differential Revision: [D97344092](https://our.internmc.facebook.com/intern/diff/D97344092/)
1 parent 6fccd5a commit 936f6d4

6 files changed

Lines changed: 838 additions & 0 deletions

File tree

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
13+
#define T ${texel_load_component_type(DTYPE, STORAGE)}
14+
15+
$if STORAGE == "buffer":
16+
#define OUTPUT_BUFFER
17+
#define INPUT_BUFFER
18+
#define SCALAR_BUFFER
19+
$if WEIGHT_STORAGE == "buffer":
20+
#define WEIGHT_BUFFER
21+
$if HAS_BIAS:
22+
#define HAS_BIAS
23+
$if STORAGE == "buffer" and HAS_BIAS:
24+
#define BIAS_BUFFER
25+
26+
#define TILE_M4 ${TILE_M4}
27+
#define TILE_K4 ${TILE_K4}
28+
#define TILE_N4 ${TILE_N4}
29+
30+
#define TILE_M ${TILE_M}
31+
#define TILE_K ${TILE_K4 * 4}
32+
#define TILE_N ${TILE_N4 * 4}
33+
34+
${define_required_extensions(STORAGE, DTYPE)}
35+
$if WEIGHT_STORAGE != STORAGE:
36+
${define_required_extensions(WEIGHT_STORAGE, DTYPE)}
37+
38+
layout(std430) buffer;
39+
40+
#include "common.glslh"
41+
42+
$if STORAGE == "buffer":
43+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=True)}
44+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=True)}
45+
$else:
46+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
47+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
48+
${layout_declare_tensor(B, "r", "t_weight_packed", DTYPE, WEIGHT_STORAGE, is_scalar_array=False)}
49+
$if HAS_BIAS:
50+
$if STORAGE == "buffer":
51+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=True)}
52+
$else:
53+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=False)}
54+
55+
// in_sizes: {L, C_in, N, 1} in WHCN order
56+
${layout_declare_ubo(B, "ivec4", "in_sizes")}
57+
// out_sizes: {L, C_out, N, 1} in WHCN order
58+
${layout_declare_ubo(B, "ivec4", "out_sizes")}
59+
$if HAS_BIAS:
60+
${layout_declare_ubo(B, "ivec4", "bias_sizes")}
61+
62+
layout(push_constant) uniform restrict Block {
63+
int weight_B;
64+
float output_min;
65+
float output_max;
66+
};
67+
68+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
69+
70+
#include "linear_fp_input_tile.glslh"
71+
#include "linear_fp_weight_tile.glslh"
72+
#include "linear_fp_output_tile.glslh"
73+
#include "linear_fp_packed_weight_tile_load.glslh"
74+
#include "linear_fp_output_tile_fp_compute.glslh"
75+
76+
// Conv1d pointwise is matrix multiplication with swapped texture coordinates.
77+
// Linear: input ivec3(k4, m, b), output ivec3(n4, m, b) [width-packed]
78+
// Conv1d: input ivec3(m, k4, b), output ivec3(m, n4, b) [height-packed]
79+
//
80+
// For buffer storage, height-packed tensors have packed_dim_block_size=1 (no
81+
// vec4 grouping). Data is stored as contiguous scalars with strides based on
82+
// logical sizes, so scalar indexing is required: (b * M + m) * C + c.
83+
// For texture storage, 4 channels are packed per texel as usual.
84+
85+
#ifndef SCALAR_BUFFER
86+
VEC4_T load_input_x4(
87+
const int k4,
88+
const int m,
89+
const int b,
90+
const int K4,
91+
const int M) {
92+
#ifdef INPUT_BUFFER
93+
return t_in[(b * M + m) * K4 + k4];
94+
#else
95+
return texelFetch(t_in, ivec3(m, k4, b), 0);
96+
#endif
97+
}
98+
99+
void load_input_tile_with_checks(
100+
out FPInputTile tile,
101+
const int k4_start,
102+
const int m_start,
103+
const int b,
104+
const int K4,
105+
const int M) {
106+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
107+
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
108+
if (k4_start + k4 < K4 && m_start + m < M) {
109+
tile.data[m][k4] =
110+
load_input_x4(k4_start + k4, m_start + m, b, K4, M);
111+
} else {
112+
tile.data[m][k4] = VEC4_T(0.0);
113+
}
114+
}
115+
}
116+
}
117+
118+
void store_output_x4(
119+
const VEC4_T texel,
120+
const int n4,
121+
const int m,
122+
const int b,
123+
const int N4,
124+
const int M) {
125+
#ifdef OUTPUT_BUFFER
126+
t_out[(b * M + m) * N4 + n4] = texel;
127+
#else
128+
imageStore(t_out, ivec3(m, n4, b), texel);
129+
#endif
130+
}
131+
132+
void store_output_tile_with_checks(
133+
const FPOutTile out_tile,
134+
const int n4_start,
135+
const int m_start,
136+
const int b,
137+
const int N4,
138+
const int M) {
139+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
140+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
141+
if (m_start + m < M && n4_start + n4 < N4) {
142+
store_output_x4(
143+
out_tile.data[m][n4], n4_start + n4, m_start + m, b, N4, M);
144+
}
145+
}
146+
}
147+
}
148+
#endif // !SCALAR_BUFFER
149+
150+
#ifdef SCALAR_BUFFER
151+
void load_input_tile_scalar(
152+
out FPInputTile tile,
153+
const int k4_start,
154+
const int m_start,
155+
const int b,
156+
const int K4,
157+
const int K,
158+
const int M) {
159+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
160+
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
161+
if (k4_start + k4 < K4 && m_start + m < M) {
162+
const int base = (b * M + m_start + m) * K + mul_4(k4_start + k4);
163+
T s0 = t_in[base];
164+
T s1 = (mul_4(k4_start + k4) + 1 < K) ? t_in[base + 1] : T(0);
165+
T s2 = (mul_4(k4_start + k4) + 2 < K) ? t_in[base + 2] : T(0);
166+
T s3 = (mul_4(k4_start + k4) + 3 < K) ? t_in[base + 3] : T(0);
167+
tile.data[m][k4] = VEC4_T(s0, s1, s2, s3);
168+
} else {
169+
tile.data[m][k4] = VEC4_T(0.0);
170+
}
171+
}
172+
}
173+
}
174+
175+
void store_output_tile_scalar(
176+
const FPOutTile out_tile,
177+
const int n4_start,
178+
const int m_start,
179+
const int b,
180+
const int N4,
181+
const int N,
182+
const int M) {
183+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
184+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
185+
if (m_start + m < M && n4_start + n4 < N4) {
186+
const int base = (b * M + m_start + m) * N + mul_4(n4_start + n4);
187+
const VEC4_T val = out_tile.data[m][n4];
188+
t_out[base] = val.x;
189+
if (mul_4(n4_start + n4) + 1 < N) t_out[base + 1] = val.y;
190+
if (mul_4(n4_start + n4) + 2 < N) t_out[base + 2] = val.z;
191+
if (mul_4(n4_start + n4) + 3 < N) t_out[base + 3] = val.w;
192+
}
193+
}
194+
}
195+
}
196+
#endif // SCALAR_BUFFER
197+
198+
void main() {
199+
// Thread mapping: X=OC4 (N4), Y=L/tile_m (M tiles), Z=batch
200+
const int tile_idx_n = int(gl_GlobalInvocationID.x);
201+
const int tile_idx_m = int(gl_GlobalInvocationID.y);
202+
203+
const int n4_start = tile_idx_n * TILE_N4;
204+
const int m_start = tile_idx_m * TILE_M;
205+
206+
// in_sizes: {L, C_in, N, 1} in WHCN
207+
const int K = in_sizes.y; // C_in
208+
const int M = in_sizes.x; // L
209+
const int K4 = div_up_4(K);
210+
// out_sizes: {L, C_out, N, 1} in WHCN
211+
const int N_out = out_sizes.y; // C_out
212+
const int N4 = div_up_4(N_out);
213+
214+
if (n4_start >= N4 || m_start >= M) {
215+
return;
216+
}
217+
218+
FPOutTile out_tile;
219+
initialize(out_tile);
220+
221+
FPInputTile in_tile;
222+
FPWeightTile w_tile;
223+
224+
const int b = int(gl_GlobalInvocationID.z);
225+
226+
for (int k4 = 0; k4 < K4; k4++) {
227+
#ifdef SCALAR_BUFFER
228+
load_input_tile_scalar(in_tile, k4, m_start, b, K4, K, M);
229+
#else
230+
load_input_tile_with_checks(in_tile, k4, m_start, b, K4, M);
231+
#endif
232+
load_packed_weight_tile_with_checks(w_tile, n4_start, k4, 0, N4, K4);
233+
fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile);
234+
}
235+
236+
#ifdef HAS_BIAS
237+
// Load bias (per output channel) and apply
238+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
239+
VEC4_T bias_val = VEC4_T(0.0);
240+
if (n4_start + n4 < N4) {
241+
#ifdef BIAS_BUFFER
242+
// Bias is a 1D tensor [C_out], width-packed.
243+
// For buffer storage, width-packed has packed_dim_block_size=1, so data
244+
// is stored as contiguous scalars. Read 4 with bounds checking.
245+
const int bias_base = mul_4(n4_start + n4);
246+
T b0 = t_bias[bias_base];
247+
T b1 = (bias_base + 1 < N_out) ? t_bias[bias_base + 1] : T(0);
248+
T b2 = (bias_base + 2 < N_out) ? t_bias[bias_base + 2] : T(0);
249+
T b3 = (bias_base + 3 < N_out) ? t_bias[bias_base + 3] : T(0);
250+
bias_val = VEC4_T(b0, b1, b2, b3);
251+
#else
252+
bias_val = texelFetch(t_bias, ivec3(n4_start + n4, 0, 0), 0);
253+
#endif
254+
}
255+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
256+
out_tile.data[m][n4] = out_tile.data[m][n4] + bias_val;
257+
}
258+
}
259+
#endif
260+
261+
// Apply activation clamp
262+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
263+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
264+
out_tile.data[m][n4] =
265+
clamp(out_tile.data[m][n4], VEC4_T(output_min), VEC4_T(output_max));
266+
}
267+
}
268+
269+
#ifdef SCALAR_BUFFER
270+
store_output_tile_scalar(out_tile, n4_start, m_start, b, N4, N_out, M);
271+
#else
272+
store_output_tile_with_checks(out_tile, n4_start, m_start, b, N4, M);
273+
#endif
274+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
conv1d_pw:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: texture3d
11+
WEIGHT_STORAGE: texture2d
12+
HAS_BIAS: false
13+
TILE_M4: 1
14+
TILE_K4: 1
15+
TILE_N4: 1
16+
TILE_M: 4
17+
generate_variant_forall:
18+
combination:
19+
parameter_names: [STORAGE, WEIGHT_STORAGE]
20+
combos:
21+
- parameter_values: [texture3d, texture2d]
22+
- parameter_values: [texture3d, buffer]
23+
- parameter_values: [buffer, texture2d]
24+
- parameter_values: [buffer, buffer]
25+
DTYPE:
26+
- VALUE: float
27+
- VALUE: half
28+
shader_variants:
29+
- NAME: conv1d_pw
30+
- NAME: conv1d_pw_bias
31+
HAS_BIAS: true

0 commit comments

Comments
 (0)