Skip to content

Commit a9b427c

Browse files
ssjiaSS-JIA
authored andcommitted
[ET-VK] Modernize softmax and log_softmax
Modernize softmax and log_softmax to support ANY_STORAGE. Migrate both buffer and texture shaders from indexing_utils.h to indexing.glslh with BufferMetadata/TextureMetadata UBOs. Merge separate texture and buffer dispatch functions into a unified add_softmax_node using add_storage_type_suffix and graph.meta_ubo(). Pull Request resolved: pytorch#18054 ghstack-source-id: 353546688 @exported-using-ghexport Differential Revision: [D95970171](https://our.internmc.facebook.com/intern/diff/D95970171/)
1 parent 2349049 commit a9b427c

7 files changed

Lines changed: 241 additions & 157 deletions

File tree

backends/vulkan/op_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
327327
)
328328
def register_softmax_cpp_ops():
329329
return OpFeatures(
330-
inputs_storage=utils.ANY_TEXTURE,
330+
inputs_storage=utils.ANY_STORAGE,
331331
inputs_dtypes=utils.FP_T,
332332
supports_resize=True,
333333
)

backends/vulkan/runtime/graph/ops/glsl/softmax.glsl

Lines changed: 33 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -20,94 +20,49 @@ ${define_active_storage_type(STORAGE)}
2020

2121
layout(std430) buffer;
2222

23+
#include "indexing.glslh"
24+
2325
${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
2426
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}
2527

26-
layout(push_constant) uniform restrict Block {
27-
ivec4 tin_sizes;
28-
ivec3 tout_limits;
29-
};
28+
${layout_declare_ubo(B, "TextureMetadata", "in_meta")}
29+
${layout_declare_ubo(B, "TextureMetadata", "out_meta")}
3030

3131
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3232

33-
layout(constant_id = 3) const int packed_dim = 0;
34-
layout(constant_id = 4) const int reduce_dim = 0;
35-
layout(constant_id = 5) const int group_dim = 1;
33+
${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")}
34+
const int packed_dim = get_packed_dim(out_layout);
3635

37-
// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
38-
// threads that will co-operate to compute one reduction output. There may be
39-
// multiple groups computing distinct reduction outputs within one work group.
40-
#define NWORKERS 4
36+
${layout_declare_spec_const(C, "int", "reduce_dim", "0")}
37+
${layout_declare_spec_const(C, "int", "group_dim", "1")}
4138

42-
// Sets an upper limit on the total size of a work group based on how many
43-
// elements are allocated in the shared memory array below. Each thread in the
44-
// work group will write into its assigned element in the shared array.
39+
#define NWORKERS 4
4540
#define MAX_NTHREADS 16
4641

4742
shared vec4 shared_max[MAX_NTHREADS];
4843
shared vec4 shared_sum[MAX_NTHREADS];
4944

50-
#include "indexing_utils.h"
51-
5245
int tid_to_smi(const ivec2 tid) {
5346
return tid.x + tid.y * NWORKERS;
5447
}
5548

56-
/*
57-
* The shaders below compute softmax for a tensor. Softmax is an interesting mix
58-
* between a reduction operator and a unary elementwise operator, defined as
59-
* exp(x) / (sum of exp(x)). The general flow of the computation is:
60-
*
61-
* First, find the maximum element along the reduction dim. The maximum element
62-
* is used to preserve numerical stability, since division of exponents is
63-
* translation invariant.
64-
*
65-
* Next, compute the sum of exp(x - max_element) along the reduction dim.
66-
*
67-
* Finally, for each element along the reduction dim, we compute the output as
68-
* exp(x - max_element) / sum_of_exponents.
69-
*
70-
* The shaders below also utilize shared memory to have multiple threads help
71-
* compute the max and sum reduction operations. A total of NGROUPS x NWORKERS
72-
* threads are launched. Each group works on a unique reduction "row", and
73-
* within a group NWORKERS threads co-operate to compute the max and sum of one
74-
* "row". Each worker in the group is responsible for computing a partial output
75-
* of the "row" and uploading it to shared memory; the overall reduction output
76-
* can then be determined by aggregating the partial outputs stored in shared
77-
* memory.
78-
*
79-
* As a caveat, this shader does not currently support cases where `batch` > 1
80-
* and the reduce dim happens to also be the batch concatenation dim. To support
81-
* this, there will need to be additional logic to set the starting value of
82-
* `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case,
83-
* supporting this case is left as an exercise for when it is required.
84-
*
85-
* As a final note, log softmax is supported with this shader as well since via
86-
* the op1 and op2 macro definitions. See the corresponding YAML file for more
87-
* details.
88-
*/
89-
9049
/*
9150
* Computes softmax where the reduction dim is orthogonal to the packed dim.
9251
* This case is simpler because each element of a texel belongs to a separate
9352
* reduction dim, meaning we don't have to perform reduction along a texel.
9453
*/
9554
void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
96-
// shared memory index of this thread
9755
const int smi = tid_to_smi(tid);
98-
// used to iterate over all shared memory in the group
9956
int group_i;
10057

10158
scan_pos[reduce_dim] = tid.x;
102-
vec4 max_elements = load_texel(tin, scan_pos);
103-
// This thread computes a partial maximum
104-
for (int i = tid.x; i < tin_sizes[reduce_dim];
59+
vec4 max_elements = texelFetch(tin, scan_pos, 0);
60+
for (int i = tid.x; i < in_meta.sizes[reduce_dim];
10561
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
106-
max_elements = max(max_elements, load_texel(tin, scan_pos));
62+
max_elements = max(max_elements, texelFetch(tin, scan_pos, 0));
10763
}
10864
shared_max[smi] = max_elements;
10965
barrier();
110-
// Iterate over the partial maximums to obtain the overall maximum
11166
group_i = tid.y * NWORKERS;
11267
max_elements = shared_max[group_i++];
11368
for (int i = 1; i < NWORKERS; ++i, group_i++) {
@@ -116,63 +71,44 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
11671

11772
scan_pos[reduce_dim] = tid.x;
11873
vec4 denominators = vec4(0);
119-
// Compute partial sum
120-
for (int i = tid.x; i < tin_sizes[reduce_dim];
74+
for (int i = tid.x; i < in_meta.sizes[reduce_dim];
12175
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
122-
denominators += exp(load_texel(tin, scan_pos) - max_elements);
76+
denominators += exp(texelFetch(tin, scan_pos, 0) - max_elements);
12377
}
12478
shared_sum[smi] = denominators;
12579
barrier();
126-
// Iterate over the partial sums to obtain the overall sum
12780
group_i = tid.y * NWORKERS;
12881
denominators = shared_sum[group_i++];
12982
for (int i = 1; i < NWORKERS; ++i, group_i++) {
13083
denominators += shared_sum[group_i];
13184
}
13285

133-
// Determine if there are any padding elements in the final texel of the
134-
// packed dimension
135-
const int nspill = mod4(tin_sizes[packed_dim]);
136-
// Detect if this thread is working on the final texels of the packed
137-
// dimension, which may have padding elements
86+
const int nspill = mod_4(in_meta.sizes[packed_dim]);
13887
const bool is_last_texel =
139-
scan_pos[packed_dim] == (tout_limits[packed_dim] - 1);
88+
scan_pos[packed_dim] == (out_meta.limits[packed_dim] - 1);
14089

14190
scan_pos[reduce_dim] = tid.x;
142-
for (int i = tid.x; i < tin_sizes[reduce_dim];
91+
for (int i = tid.x; i < in_meta.sizes[reduce_dim];
14392
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
144-
const vec4 numerators = op1(load_texel(tin, scan_pos) - max_elements);
145-
// Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
93+
const vec4 numerators = op1(texelFetch(tin, scan_pos, 0) - max_elements);
14694
const vec4 safe_denom = max(denominators, vec4(1e-37));
14795
vec4 outtex = op2(numerators, safe_denom);
148-
// Replace any NaN/Inf with 0 using IEEE 754 bit-level manipulation.
149-
// This avoids isnan()/x!=x which may not work reliably on all GPU drivers:
150-
// - OpIsNan may have driver bugs for certain NaN bit patterns
151-
// - OpFOrdNotEqual(NaN,NaN) = false (ordered comparison semantics)
152-
// NaN/Inf pattern: all exponent bits set = (bits & 0x7F800000) == 0x7F800000
15396
{
15497
uvec4 bits = floatBitsToUint(outtex);
155-
// Build a mask: 0xFFFFFFFF where NaN/Inf (exponent all-ones), else 0
15698
uvec4 nan_inf_mask = uvec4(
15799
((bits.x & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
158100
((bits.y & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
159101
((bits.z & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
160102
((bits.w & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u);
161-
// Zero out bits where NaN/Inf: normal values are unchanged
162103
outtex = uintBitsToFloat(bits & ~nan_inf_mask);
163104
}
164-
// For the last texel in the packed dim, make sure that the padding elements
165-
// are explicitly set to 0. Otherwise, they may influence computations later
166-
// down the line.
167105
if (is_last_texel && nspill > 0) {
168106
[[unroll]] for (int i = nspill; i < 4; ++i) {
169107
outtex[i] = 0;
170108
}
171109
}
172-
write_texel(tout, scan_pos, outtex);
110+
imageStore(tout, scan_pos, outtex);
173111
}
174-
// Flush outstanding imageStore writes so they're committed to memory and
175-
// visible to subsequent GPU operations on this image.
176112
memoryBarrierImage();
177113
}
178114

@@ -185,44 +121,31 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
185121
* multiple of 4) so that they do not influence the output of reduction.
186122
*/
187123
void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
188-
// shared memory index of this thread
189124
const int smi = tid_to_smi(tid);
190-
// used to iterate over all shared memory in the group
191125
int group_i;
192126

193-
const int nspill = mod4(tin_sizes[packed_dim]);
194-
const int reduce_len = tin_sizes[packed_dim] - nspill;
127+
const int nspill = mod_4(in_meta.sizes[packed_dim]);
128+
const int reduce_len = in_meta.sizes[packed_dim] - nspill;
195129

196130
scan_pos[reduce_dim] = tid.x;
197-
// Initialize with -FLT_MAX to avoid contaminating the maximum with out-of-
198-
// bounds texture reads. When NWORKERS > number of texels (e.g. reduce_len=12
199-
// has 3 texels but NWORKERS=4), worker threads with no valid texels would
200-
// otherwise load from an OOB index and get 0, which corrupts the max for
201-
// rows where all values are negative and causes denominator underflow -> NaN.
202131
vec4 max_elements = vec4(-3.402823e+38);
203132
for (int i = tid.x * 4; i < reduce_len;
204133
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
205-
max_elements = max(max_elements, load_texel(tin, scan_pos));
134+
max_elements = max(max_elements, texelFetch(tin, scan_pos, 0));
206135
}
207-
// For the last texel in the dim, if there are padding elements then each
208-
// element of the texel needs to be processed individually such that the
209-
// padding elements are ignored
210-
if (scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1 && nspill > 0) {
211-
const vec4 intex = load_texel(tin, scan_pos);
136+
if (scan_pos[reduce_dim] == out_meta.limits[reduce_dim] - 1 && nspill > 0) {
137+
const vec4 intex = texelFetch(tin, scan_pos, 0);
212138
for (int i = 0; i < nspill; ++i) {
213139
max_elements.x = max(intex[i], max_elements.x);
214140
}
215141
}
216142
shared_max[smi] = max_elements;
217143
barrier();
218-
// Iterate over the partial maximums to obtain the overall maximum
219144
group_i = tid.y * NWORKERS;
220145
max_elements = shared_max[group_i++];
221146
for (int i = 1; i < NWORKERS; ++i, group_i++) {
222147
max_elements = max(max_elements, shared_max[group_i]);
223148
}
224-
// Each element of the texel is itself a partial maximum; iterate over the
225-
// texel to find the actual maximum
226149
float max_element = max_elements.x;
227150
[[unroll]] for (int i = 1; i < 4; ++i) {
228151
max_element = max(max_elements[i], max_element);
@@ -232,49 +155,40 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
232155
vec4 denominators = vec4(0);
233156
for (int i = tid.x * 4; i < reduce_len;
234157
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
235-
denominators += exp(load_texel(tin, scan_pos) - max_element);
158+
denominators += exp(texelFetch(tin, scan_pos, 0) - max_element);
236159
}
237-
// For the last texel in the dim, if there are padding elements then each
238-
// element of the texel needs to be processed individually such that the
239-
// padding elements are ignored
240-
if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) {
241-
const vec4 intex = load_texel(tin, scan_pos);
160+
if (nspill > 0 && scan_pos[reduce_dim] == out_meta.limits[reduce_dim] - 1) {
161+
const vec4 intex = texelFetch(tin, scan_pos, 0);
242162
for (int i = 0; i < nspill; ++i) {
243163
denominators.x += exp(intex[i] - max_element);
244164
}
245165
}
246166
shared_sum[smi] = denominators;
247167
barrier();
248-
// Iterate over the partial sums to obtain the overall sum
249168
group_i = tid.y * NWORKERS;
250169
denominators = shared_sum[group_i++];
251170
for (int i = 1; i < NWORKERS; ++i, group_i++) {
252171
denominators += shared_sum[group_i];
253172
}
254-
// Reduce over the accumulated texel to find the overall sum
255173
float denominator = 0;
256174
[[unroll]] for (int i = 0; i < 4; ++i) {
257175
denominator += denominators[i];
258176
}
259-
// Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
260177
const float safe_denominator = max(denominator, 1e-37);
261178

262179
scan_pos[reduce_dim] = tid.x;
263180
for (int i = tid.x * 4; i < reduce_len;
264181
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
265-
const vec4 numerators = op1(load_texel(tin, scan_pos) - max_element);
266-
write_texel(tout, scan_pos, op2(numerators, safe_denominator));
182+
const vec4 numerators = op1(texelFetch(tin, scan_pos, 0) - max_element);
183+
imageStore(tout, scan_pos, op2(numerators, safe_denominator));
267184
}
268-
// For the last texel in the dim, if there are padding elements then the
269-
// padding elements need to be set to 0 explicitly, otherwise they may
270-
// influence subsequent operations.
271-
if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) {
272-
const vec4 numerator = op1(load_texel(tin, scan_pos) - max_element);
185+
if (nspill > 0 && scan_pos[reduce_dim] == out_meta.limits[reduce_dim] - 1) {
186+
const vec4 numerator = op1(texelFetch(tin, scan_pos, 0) - max_element);
273187
vec4 outtex = op2(numerator, safe_denominator);
274188
[[unroll]] for (int i = nspill; i < 4; ++i) {
275189
outtex[i] = 0;
276190
}
277-
write_texel(tout, scan_pos, outtex);
191+
imageStore(tout, scan_pos, outtex);
278192
}
279193
}
280194

@@ -286,7 +200,7 @@ void main() {
286200
gl_LocalInvocationID[reduce_dim],
287201
gl_LocalInvocationID[group_dim]);
288202

289-
if (any(greaterThanEqual(scan_pos, tout_limits))) {
203+
if (any(greaterThanEqual(scan_pos, out_meta.limits))) {
290204
return;
291205
}
292206

backends/vulkan/runtime/graph/ops/glsl/softmax.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ softmax:
1515
- VALUE: half
1616
- VALUE: float
1717
shader_variants:
18-
- NAME: softmax
19-
- NAME: log_softmax
18+
- NAME: softmax_texture3d
19+
- NAME: log_softmax_texture3d
2020
OPERATOR1: X
2121
OPERATOR2: X - log(Y)

0 commit comments

Comments
 (0)