Skip to content

Commit 4c0671a

Browse files
committed
GPU OpenCL: subgroup functions not defined for int8
1 parent cb97e7a commit 4c0671a

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

GPU/Common/GPUCommonAlgorithm.h

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,29 @@ GPUdi() void GPUCommonAlgorithm::swap(T& a, T& b)
338338
// Nothing to do, work_group functions available
339339
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
340340

341-
#define warp_scan_inclusive_add(v) sub_group_scan_inclusive_add(v)
342-
#define warp_broadcast(v, i) sub_group_broadcast(v, i)
341+
template <class T>
342+
GPUdi() T work_group_scan_inclusive_add_FUNC(T v)
343+
{
344+
return sub_group_scan_inclusive_add(v);
345+
}
346+
template <> // FIXME: It seems OpenCL does not support 8 and 16 bit subgroup operations
347+
GPUdi() uint8_t work_group_scan_inclusive_add_FUNC<uint8_t>(uint8_t v)
348+
{
349+
return sub_group_scan_inclusive_add((uint32_t)v);
350+
}
351+
template <class T>
352+
GPUdi() T work_group_broadcast_FUNC(T v, int32_t i)
353+
{
354+
return sub_group_broadcast(v, i);
355+
}
356+
template <>
357+
GPUdi() uint8_t work_group_broadcast_FUNC<uint8_t>(uint8_t v, int32_t i)
358+
{
359+
return sub_group_broadcast((uint32_t)v, i);
360+
}
361+
362+
#define warp_scan_inclusive_add(v) work_group_scan_inclusive_add_FUNC(v)
363+
#define warp_broadcast(v, i) work_group_broadcast_FUNC(v, i)
343364

344365
#elif (defined(__CUDACC__) || defined(__HIPCC__))
345366
// CUDA and HIP work the same way using cub, need just different header

0 commit comments

Comments
 (0)