@@ -20,94 +20,49 @@ ${define_active_storage_type(STORAGE)}
2020
2121layout (std430) buffer ;
2222
23+ #include "indexing.glslh"
24+
2325${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
2426${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}
2527
26- layout (push_constant) uniform restrict Block {
27- ivec4 tin_sizes;
28- ivec3 tout_limits;
29- };
28+ ${layout_declare_ubo(B, "TextureMetadata", "in_meta")}
29+ ${layout_declare_ubo(B, "TextureMetadata", "out_meta")}
3030
3131layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
3232
33- layout (constant_id = 3 ) const int packed_dim = 0 ;
34- layout (constant_id = 4 ) const int reduce_dim = 0 ;
35- layout (constant_id = 5 ) const int group_dim = 1 ;
33+ ${layout_declare_spec_const(C, "int ", "out_layout", "CONTIG_LAYOUT_INT")}
34+ const int packed_dim = get_packed_dim(out_layout);
3635
37- // A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
38- // threads that will co-operate to compute one reduction output. There may be
39- // multiple groups computing distinct reduction outputs within one work group.
40- #define NWORKERS 4
36+ ${layout_declare_spec_const(C, "int ", "reduce_dim", "0 ")}
37+ ${layout_declare_spec_const(C, "int ", "group_dim", "1 ")}
4138
42- // Sets an upper limit on the total size of a work group based on how many
43- // elements are allocated in the shared memory array below. Each thread in the
44- // work group will write into its assigned element in the shared array.
39+ #define NWORKERS 4
4540#define MAX_NTHREADS 16
4641
4742shared vec4 shared_max[MAX_NTHREADS];
4843shared vec4 shared_sum[MAX_NTHREADS];
4944
50- #include "indexing_utils.h"
51-
5245int tid_to_smi(const ivec2 tid) {
5346 return tid.x + tid.y * NWORKERS;
5447}
5548
56- /*
57- * The shaders below compute softmax for a tensor. Softmax is an interesting mix
58- * between a reduction operator and a unary elementwise operator, defined as
59- * exp(x) / (sum of exp(x)). The general flow of the computation is:
60- *
61- * First, find the maximum element along the reduction dim. The maximum element
62- * is used to preserve numerical stability, since division of exponents is
63- * translation invariant.
64- *
65- * Next, compute the sum of exp(x - max_element) along the reduction dim.
66- *
67- * Finally, for each element along the reduction dim, we compute the output as
68- * exp(x - max_element) / sum_of_exponents.
69- *
70- * The shaders below also utilize shared memory to have multiple threads help
71- * compute the max and sum reduction operations. A total of NGROUPS x NWORKERS
72- * threads are launched. Each group works on a unique reduction "row", and
73- * within a group NWORKERS threads co-operate to compute the max and sum of one
74- * "row". Each worker in the group is responsible for computing a partial output
75- * of the "row" and uploading it to shared memory; the overall reduction output
76- * can then be determined by aggregating the partial outputs stored in shared
77- * memory.
78- *
79- * As a caveat, this shader does not currently support cases where `batch` > 1
80- * and the reduce dim happens to also be the batch concatenation dim. To support
81- * this, there will need to be additional logic to set the starting value of
82- * `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case,
83- * supporting this case is left as an exercise for when it is required.
84- *
85- * As a final note, log softmax is supported with this shader as well since via
86- * the op1 and op2 macro definitions. See the corresponding YAML file for more
87- * details.
88- */
89-
9049/*
9150 * Computes softmax where the reduction dim is orthogonal to the packed dim.
9251 * This case is simpler because each element of a texel belongs to a separate
9352 * reduction dim, meaning we don't have to perform reduction along a texel.
9453 */
9554void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
96- // shared memory index of this thread
9755 const int smi = tid_to_smi(tid);
98- // used to iterate over all shared memory in the group
9956 int group_i;
10057
10158 scan_pos[reduce_dim] = tid.x;
102- vec4 max_elements = load_texel(tin, scan_pos);
103- // This thread computes a partial maximum
104- for (int i = tid.x; i < tin_sizes[reduce_dim];
59+ vec4 max_elements = texelFetch(tin, scan_pos, 0 );
60+ for (int i = tid.x; i < in_meta.sizes[reduce_dim];
10561 i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
106- max_elements = max (max_elements, load_texel (tin, scan_pos));
62+ max_elements = max (max_elements, texelFetch (tin, scan_pos, 0 ));
10763 }
10864 shared_max[smi] = max_elements;
10965 barrier();
110- // Iterate over the partial maximums to obtain the overall maximum
11166 group_i = tid.y * NWORKERS;
11267 max_elements = shared_max[group_i++ ];
11368 for (int i = 1 ; i < NWORKERS; ++ i, group_i++ ) {
@@ -116,63 +71,44 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
11671
11772 scan_pos[reduce_dim] = tid.x;
11873 vec4 denominators = vec4 (0 );
119- // Compute partial sum
120- for (int i = tid.x; i < tin_sizes[reduce_dim];
74+ for (int i = tid.x; i < in_meta.sizes[reduce_dim];
12175 i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
122- denominators += exp (load_texel (tin, scan_pos) - max_elements);
76+ denominators += exp (texelFetch (tin, scan_pos, 0 ) - max_elements);
12377 }
12478 shared_sum[smi] = denominators;
12579 barrier();
126- // Iterate over the partial sums to obtain the overall sum
12780 group_i = tid.y * NWORKERS;
12881 denominators = shared_sum[group_i++ ];
12982 for (int i = 1 ; i < NWORKERS; ++ i, group_i++ ) {
13083 denominators += shared_sum[group_i];
13184 }
13285
133- // Determine if there are any padding elements in the final texel of the
134- // packed dimension
135- const int nspill = mod4(tin_sizes[packed_dim]);
136- // Detect if this thread is working on the final texels of the packed
137- // dimension, which may have padding elements
86+ const int nspill = mod_4(in_meta.sizes[packed_dim]);
13887 const bool is_last_texel =
139- scan_pos[packed_dim] == (tout_limits [packed_dim] - 1 );
88+ scan_pos[packed_dim] == (out_meta.limits [packed_dim] - 1 );
14089
14190 scan_pos[reduce_dim] = tid.x;
142- for (int i = tid.x; i < tin_sizes [reduce_dim];
91+ for (int i = tid.x; i < in_meta.sizes [reduce_dim];
14392 i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
144- const vec4 numerators = op1(load_texel(tin, scan_pos) - max_elements);
145- // Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
93+ const vec4 numerators = op1(texelFetch(tin, scan_pos, 0 ) - max_elements);
14694 const vec4 safe_denom = max (denominators, vec4 (1e-37 ));
14795 vec4 outtex = op2(numerators, safe_denom);
148- // Replace any NaN/Inf with 0 using IEEE 754 bit-level manipulation.
149- // This avoids isnan()/x!=x which may not work reliably on all GPU drivers:
150- // - OpIsNan may have driver bugs for certain NaN bit patterns
151- // - OpFOrdNotEqual(NaN,NaN) = false (ordered comparison semantics)
152- // NaN/Inf pattern: all exponent bits set = (bits & 0x7F800000) == 0x7F800000
15396 {
15497 uvec4 bits = floatBitsToUint(outtex);
155- // Build a mask: 0xFFFFFFFF where NaN/Inf (exponent all-ones), else 0
15698 uvec4 nan_inf_mask = uvec4 (
15799 ((bits.x & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
158100 ((bits.y & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
159101 ((bits.z & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
160102 ((bits.w & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u);
161- // Zero out bits where NaN/Inf: normal values are unchanged
162103 outtex = uintBitsToFloat(bits & ~ nan_inf_mask);
163104 }
164- // For the last texel in the packed dim, make sure that the padding elements
165- // are explicitly set to 0. Otherwise, they may influence computations later
166- // down the line.
167105 if (is_last_texel && nspill > 0 ) {
168106 [[unroll]] for (int i = nspill; i < 4 ; ++ i) {
169107 outtex[i] = 0 ;
170108 }
171109 }
172- write_texel (tout, scan_pos, outtex);
110+ imageStore (tout, scan_pos, outtex);
173111 }
174- // Flush outstanding imageStore writes so they're committed to memory and
175- // visible to subsequent GPU operations on this image.
176112 memoryBarrierImage();
177113}
178114
@@ -185,44 +121,31 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
185121 * multiple of 4) so that they do not influence the output of reduction.
186122 */
187123void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
188- // shared memory index of this thread
189124 const int smi = tid_to_smi(tid);
190- // used to iterate over all shared memory in the group
191125 int group_i;
192126
193- const int nspill = mod4(tin_sizes [packed_dim]);
194- const int reduce_len = tin_sizes [packed_dim] - nspill;
127+ const int nspill = mod_4(in_meta.sizes [packed_dim]);
128+ const int reduce_len = in_meta.sizes [packed_dim] - nspill;
195129
196130 scan_pos[reduce_dim] = tid.x;
197- // Initialize with -FLT_MAX to avoid contaminating the maximum with out-of-
198- // bounds texture reads. When NWORKERS > number of texels (e.g. reduce_len=12
199- // has 3 texels but NWORKERS=4), worker threads with no valid texels would
200- // otherwise load from an OOB index and get 0, which corrupts the max for
201- // rows where all values are negative and causes denominator underflow -> NaN.
202131 vec4 max_elements = vec4 (- 3.402823e+38 );
203132 for (int i = tid.x * 4 ; i < reduce_len;
204133 i += NWORKERS * 4 , scan_pos[reduce_dim] += NWORKERS) {
205- max_elements = max (max_elements, load_texel (tin, scan_pos));
134+ max_elements = max (max_elements, texelFetch (tin, scan_pos, 0 ));
206135 }
207- // For the last texel in the dim, if there are padding elements then each
208- // element of the texel needs to be processed individually such that the
209- // padding elements are ignored
210- if (scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1 && nspill > 0 ) {
211- const vec4 intex = load_texel(tin, scan_pos);
136+ if (scan_pos[reduce_dim] == out_meta.limits[reduce_dim] - 1 && nspill > 0 ) {
137+ const vec4 intex = texelFetch(tin, scan_pos, 0 );
212138 for (int i = 0 ; i < nspill; ++ i) {
213139 max_elements.x = max (intex[i], max_elements.x);
214140 }
215141 }
216142 shared_max[smi] = max_elements;
217143 barrier();
218- // Iterate over the partial maximums to obtain the overall maximum
219144 group_i = tid.y * NWORKERS;
220145 max_elements = shared_max[group_i++ ];
221146 for (int i = 1 ; i < NWORKERS; ++ i, group_i++ ) {
222147 max_elements = max (max_elements, shared_max[group_i]);
223148 }
224- // Each element of the texel is itself a partial maximum; iterate over the
225- // texel to find the actual maximum
226149 float max_element = max_elements.x;
227150 [[unroll]] for (int i = 1 ; i < 4 ; ++ i) {
228151 max_element = max (max_elements[i], max_element);
@@ -232,49 +155,40 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
232155 vec4 denominators = vec4 (0 );
233156 for (int i = tid.x * 4 ; i < reduce_len;
234157 i += NWORKERS * 4 , scan_pos[reduce_dim] += NWORKERS) {
235- denominators += exp (load_texel (tin, scan_pos) - max_element);
158+ denominators += exp (texelFetch (tin, scan_pos, 0 ) - max_element);
236159 }
237- // For the last texel in the dim, if there are padding elements then each
238- // element of the texel needs to be processed individually such that the
239- // padding elements are ignored
240- if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1 ) {
241- const vec4 intex = load_texel(tin, scan_pos);
160+ if (nspill > 0 && scan_pos[reduce_dim] == out_meta.limits[reduce_dim] - 1 ) {
161+ const vec4 intex = texelFetch(tin, scan_pos, 0 );
242162 for (int i = 0 ; i < nspill; ++ i) {
243163 denominators.x += exp (intex[i] - max_element);
244164 }
245165 }
246166 shared_sum[smi] = denominators;
247167 barrier();
248- // Iterate over the partial sums to obtain the overall sum
249168 group_i = tid.y * NWORKERS;
250169 denominators = shared_sum[group_i++ ];
251170 for (int i = 1 ; i < NWORKERS; ++ i, group_i++ ) {
252171 denominators += shared_sum[group_i];
253172 }
254- // Reduce over the accumulated texel to find the overall sum
255173 float denominator = 0 ;
256174 [[unroll]] for (int i = 0 ; i < 4 ; ++ i) {
257175 denominator += denominators[i];
258176 }
259- // Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
260177 const float safe_denominator = max (denominator, 1e-37 );
261178
262179 scan_pos[reduce_dim] = tid.x;
263180 for (int i = tid.x * 4 ; i < reduce_len;
264181 i += NWORKERS * 4 , scan_pos[reduce_dim] += NWORKERS) {
265- const vec4 numerators = op1(load_texel (tin, scan_pos) - max_element);
266- write_texel (tout, scan_pos, op2(numerators, safe_denominator));
182+ const vec4 numerators = op1(texelFetch (tin, scan_pos, 0 ) - max_element);
183+ imageStore (tout, scan_pos, op2(numerators, safe_denominator));
267184 }
268- // For the last texel in the dim, if there are padding elements then the
269- // padding elements need to be set to 0 explicitly, otherwise they may
270- // influence subsequent operations.
271- if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1 ) {
272- const vec4 numerator = op1(load_texel(tin, scan_pos) - max_element);
185+ if (nspill > 0 && scan_pos[reduce_dim] == out_meta.limits[reduce_dim] - 1 ) {
186+ const vec4 numerator = op1(texelFetch(tin, scan_pos, 0 ) - max_element);
273187 vec4 outtex = op2(numerator, safe_denominator);
274188 [[unroll]] for (int i = nspill; i < 4 ; ++ i) {
275189 outtex[i] = 0 ;
276190 }
277- write_texel (tout, scan_pos, outtex);
191+ imageStore (tout, scan_pos, outtex);
278192 }
279193}
280194
@@ -286,7 +200,7 @@ void main() {
286200 gl_LocalInvocationID[reduce_dim],
287201 gl_LocalInvocationID[group_dim]);
288202
289- if (any (greaterThanEqual (scan_pos, tout_limits ))) {
203+ if (any (greaterThanEqual (scan_pos, out_meta.limits ))) {
290204 return ;
291205 }
292206
0 commit comments