Skip to content

Commit 1ee58ed

Browse files
ssjiaSS-JIA
authored and committed
[ez][ET-VK][q8ta_conv2d_pw] Halve accumulator to lift Adreno occupancy
Pull Request resolved: pytorch#19396

The pointwise quantized conv shader allocated ivec4 out_accum[4][2] = 32 int32 accumulators per thread, which on Adreno 740 pinned 28 full-precision registers per thread and capped ALU fiber occupancy at 37%. AOC reported 26.7% exposed long-latency stalls, evidence that occupancy was too low to hide texture and SSBO latency.

Halve the accumulator to 16 ints by reducing TILE_N4 from 2 to 1 (each thread now covers 4 widths × 4 output channels = a single 4×4 output block). The compensating dispatch change is in pick_q8ta_conv2d_pw_global_wg_size: global_wg.x doubles, since each thread covers half as many output channel blocks as before.

Each thread still loads 1 input ivec4 (4 widths) per K-iter, preserving the natural int8x4 packing alignment, so arithmetic intensity drops only 25% (2.67 → 2.0 MAC/B), in contrast to the variant where TILE_M is halved, which drops arithmetic intensity by 50%.

ghstack-source-id: 379519735
@exported-using-ghexport

Differential Revision: [D103770023](https://our.internmc.facebook.com/intern/diff/D103770023/)
1 parent 24da2f6 commit 1ee58ed

2 files changed

Lines changed: 14 additions & 14 deletions

File tree

backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -22,16 +22,18 @@ $if USE_INT8_DOT_PRODUCT_EXT == 1:
2222

2323
${define_active_storage_type("buffer")}
2424

25+
// Each thread computes a TILE_M (width) x TILE_N (output channel) output block,
26+
// using an int32 accumulator tile.
2527
// corresponds to input/output width dim
2628
#define TILE_M4 1
2729
// corresponds to input channels dim
2830
#define TILE_K4 1
2931
// corresponds to output channels dim
30-
#define TILE_N4 2
32+
#define TILE_N4 1
3133

3234
#define TILE_M 4
3335
#define TILE_K 4
34-
#define TILE_N 8
36+
#define TILE_N 4
3537

3638
layout(std430) buffer;
3739

@@ -86,9 +88,9 @@ int compute_outp_buffer_idx(
8688
}
8789

8890
void main() {
89-
// Thread mapping: each thread handles TILE_M (4) widths × TILE_N (8) output channels
90-
// gl_GlobalInvocationID.x output channel blocks (TILE_N4 = 2 blocks of 4 channels)
91-
// gl_GlobalInvocationID.y width blocks (TILE_M4 = 1 block of 4 widths)
91+
// Thread mapping: each thread handles TILE_M widths x TILE_N output channels.
92+
// gl_GlobalInvocationID.x -> output channel blocks.
93+
// gl_GlobalInvocationID.y -> width blocks.
9294
// gl_GlobalInvocationID.z → batch (or height * batch combined)
9395
const int oc_block_idx = int(gl_GlobalInvocationID.x) * TILE_N4;
9496
const int ow_block_idx = int(gl_GlobalInvocationID.y) * TILE_M4;
@@ -137,11 +139,11 @@ void main() {
137139

138140
// Main accumulation loop over K dimension
139141
for (int k4 = 0; k4 < K4_per_group; k4++) {
140-
// Load packed int8 input tile (TILE_M4=1, TILE_K4=1)
142+
// Load the packed int8 input tile for the current width and K sub-block.
141143
// Each int contains 4 packed int8s (one per width position in the tile)
142144
ivec4 int8_input_tile = t_packed_int8_input[input_idx];
143145

144-
// Load int8 weight tile (TILE_K4=1, TILE_N4=2)
146+
// Load the int8 weight tile for the current K and output-channel sub-block.
145147
ivec4 int8_weight_tile[TILE_N4];
146148
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
147149
int8_weight_tile[n4] = texelFetch(

backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,19 +33,17 @@ utils::uvec3 pick_q8ta_conv2d_pw_global_wg_size(
3333
const uint32_t H = graph->size_at<uint32_t>(-2, output);
3434
const uint32_t C = graph->size_at<uint32_t>(-3, output);
3535

36-
// The 4W4C shader processes tiles of:
37-
// - TILE_N4=2 groups of 4 output channels (8 channels per thread)
38-
// - TILE_M4=1 groups of 4 widths (4 widths per thread)
39-
// - 1 height per thread
40-
constexpr uint32_t TILE_N4 = 2;
36+
// Each thread covers a 4-width x 4-channel output block.
37+
// Tile constants must match TILE_M4 / TILE_N4 in q8ta_conv2d_pw.glsl.
38+
constexpr uint32_t TILE_N4 = 1;
4139
constexpr uint32_t TILE_M4 = 1;
4240

4341
const uint32_t C4 = utils::div_up_4(C);
4442
const uint32_t W4 = utils::div_up_4(W);
4543

4644
// Global workgroup size:
47-
// x = output channels / (TILE_N4 * 4) = C4 / TILE_N4
48-
// y = width / (TILE_M4 * 4) = W4 / TILE_M4
45+
// x = output channels / (TILE_N4 * 4) = C4 / TILE_N4 = C4
46+
// y = width / (TILE_M4 * 4) = W4 / TILE_M4 = W4
4947
// z = height
5048
return {utils::div_up(C4, TILE_N4), utils::div_up(W4, TILE_M4), H};
5149
}

0 commit comments

Comments
 (0)