Skip to content

Commit 09b7486

Browse files
committed
GPUCommonAlgorithm: Use CUB for sorting on device instead of Thrust
1 parent 3bc558b commit 09b7486

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

GPU/Common/GPUCommonAlgorithmThrust.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@
2323
#pragma GCC diagnostic pop
2424

2525
#include "GPUCommonDef.h"
26+
#include "GPUCommonHelpers.h"
2627

2728
#ifdef __CUDACC__
2829
#define GPUCA_THRUST_NAMESPACE thrust::cuda
30+
#define GPUCA_CUB_NAMESPACE cub
2931
#else
3032
#define GPUCA_THRUST_NAMESPACE thrust::hip
33+
#define GPUCA_CUB_NAMESPACE hipcub
3134
#endif
3235

3336
namespace o2::gpu
@@ -89,11 +92,20 @@ template <class T, class S>
8992
// Sorts the N elements starting at `begin` using comparator `comp`, issuing the work
// on the stream `rec->mInternals->Streams[stream]`.
// `begin` is presumably a device pointer (it is wrapped in thrust::device_ptr below) - TODO confirm.
// Two implementations are present; the Thrust one is compiled out via `#if 0`, so the CUB one is active.
GPUhi() void GPUCommonAlgorithm::sortOnDevice(auto* rec, int32_t stream, T* begin, size_t N, const S& comp)
9093
{
9194
// NOTE(review): `p` is only used by the disabled Thrust branch; it is unused in the active CUB branch.
thrust::device_ptr<T> p(begin);
95+
#if 0 // Use Thrust
9296
auto alloc = rec->getThrustVolatileDeviceAllocator();
9397
thrust::sort(GPUCA_THRUST_NAMESPACE::par(alloc).on(rec->mInternals->Streams[stream]), p, p + N, comp);
98+
#else // Use CUB
99+
size_t tempSize = 0;
100+
void* tempMem = nullptr;
101+
// First call passes tempMem == nullptr: per CUB's temp-storage convention this is expected to only
// report the required scratch size in tempSize, without sorting.
GPUChkErrS(GPUCA_CUB_NAMESPACE::DeviceMergeSort::SortKeys(tempMem, tempSize, begin, N, comp, rec->mInternals->Streams[stream]));
102+
// Obtain tempSize bytes of scratch space from the reconstruction object.
tempMem = rec->AllocateVolatileDeviceMemory(tempSize);
103+
// Second call, now with scratch space, performs the sort of `begin[0..N)` itself.
GPUChkErrS(GPUCA_CUB_NAMESPACE::DeviceMergeSort::SortKeys(tempMem, tempSize, begin, N, comp, rec->mInternals->Streams[stream]));
104+
#endif
94105
}
95106
} // namespace o2::gpu
96107

97108
#undef GPUCA_THRUST_NAMESPACE
109+
#undef GPUCA_CUB_NAMESPACE
98110

99111
#endif

GPU/GPUTracking/Base/GPUGeneralKernels.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
#endif
2828

2929
#if defined(__HIPCC__)
30-
#define GPUCA_CUB hipcub
30+
#define GPUCA_CUB_NAMESPACE hipcub
3131
#else
32-
#define GPUCA_CUB cub
32+
#define GPUCA_CUB_NAMESPACE cub
3333
#endif
3434

3535
namespace o2::gpu
@@ -54,7 +54,7 @@ class GPUKernelTemplate
5454
struct GPUSharedMemoryWarpScan64 {
5555
// Provides the shared memory resources for warp wide CUB collectives
5656
#if (defined(__CUDACC__) || defined(__HIPCC__)) && defined(GPUCA_GPUCODE) && !defined(GPUCA_GPUCODE_HOSTONLY)
57-
typedef GPUCA_CUB::WarpScan<T> WarpScan;
57+
typedef GPUCA_CUB_NAMESPACE::WarpScan<T> WarpScan;
5858
union {
5959
typename WarpScan::TempStorage cubWarpTmpMem;
6060
};
@@ -65,9 +65,9 @@ class GPUKernelTemplate
6565
struct GPUSharedMemoryScan64 {
6666
// Provides the shared memory resources for CUB collectives
6767
#if (defined(__CUDACC__) || defined(__HIPCC__)) && defined(GPUCA_GPUCODE) && !defined(GPUCA_GPUCODE_HOSTONLY)
68-
typedef GPUCA_CUB::BlockScan<T, I> BlockScan;
69-
typedef GPUCA_CUB::BlockReduce<T, I> BlockReduce;
70-
typedef GPUCA_CUB::WarpScan<T> WarpScan;
68+
typedef GPUCA_CUB_NAMESPACE::BlockScan<T, I> BlockScan;
69+
typedef GPUCA_CUB_NAMESPACE::BlockReduce<T, I> BlockReduce;
70+
typedef GPUCA_CUB_NAMESPACE::WarpScan<T> WarpScan;
7171
union {
7272
typename BlockScan::TempStorage cubTmpMem;
7373
typename BlockReduce::TempStorage cubReduceTmpMem;
@@ -110,6 +110,6 @@ class GPUitoa : public GPUKernelTemplate
110110

111111
} // namespace o2::gpu
112112

113-
#undef GPUCA_CUB
113+
#undef GPUCA_CUB_NAMESPACE
114114

115115
#endif

0 commit comments

Comments
 (0)