Skip to content

Commit 7a706ae

Browse files
committed
GPUCommonAlgorithm: Use CUB for sorting on device instead of Thrust
1 parent 30efe2e commit 7a706ae

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

GPU/Common/GPUCommonAlgorithmThrust.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,16 @@
2323
#pragma GCC diagnostic pop
2424

2525
#include "GPUCommonDef.h"
26+
#include "GPUCommonHelpers.h"
2627

27-
#ifdef __CUDACC__
28+
#ifndef __HIPCC__ // CUDA
2829
#define GPUCA_THRUST_NAMESPACE thrust::cuda
29-
#else
30+
#define GPUCA_CUB_NAMESPACE cub
31+
#include <cub/cub.cuh>
32+
#else // HIP
3033
#define GPUCA_THRUST_NAMESPACE thrust::hip
34+
#define GPUCA_CUB_NAMESPACE hipcub
35+
#include <hipcub/hipcub.hpp>
3136
#endif
3237

3338
namespace o2::gpu
@@ -89,11 +94,20 @@ template <class T, class S>
8994
GPUhi() void GPUCommonAlgorithm::sortOnDevice(auto* rec, int32_t stream, T* begin, size_t N, const S& comp)
9095
{
9196
thrust::device_ptr<T> p(begin);
97+
#if 0 // Use Thrust
9298
auto alloc = rec->getThrustVolatileDeviceAllocator();
9399
thrust::sort(GPUCA_THRUST_NAMESPACE::par(alloc).on(rec->mInternals->Streams[stream]), p, p + N, comp);
100+
#else // Use CUB
101+
size_t tempSize = 0;
102+
void* tempMem = nullptr;
103+
GPUChkErrS(GPUCA_CUB_NAMESPACE::DeviceMergeSort::SortKeys(tempMem, tempSize, begin, N, comp, rec->mInternals->Streams[stream]));
104+
tempMem = rec->AllocateVolatileDeviceMemory(tempSize);
105+
GPUChkErrS(GPUCA_CUB_NAMESPACE::DeviceMergeSort::SortKeys(tempMem, tempSize, begin, N, comp, rec->mInternals->Streams[stream]));
106+
#endif
94107
}
95108
} // namespace o2::gpu
96109

97110
#undef GPUCA_THRUST_NAMESPACE
111+
#undef GPUCA_CUB_NAMESPACE
98112

99113
#endif

GPU/GPUTracking/Base/GPUGeneralKernels.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
#endif
2828

2929
#if defined(__HIPCC__)
30-
#define GPUCA_CUB hipcub
30+
#define GPUCA_CUB_NAMESPACE hipcub
3131
#else
32-
#define GPUCA_CUB cub
32+
#define GPUCA_CUB_NAMESPACE cub
3333
#endif
3434

3535
namespace o2::gpu
@@ -54,7 +54,7 @@ class GPUKernelTemplate
5454
struct GPUSharedMemoryWarpScan64 {
5555
// Provides the shared memory resources for warp wide CUB collectives
5656
#if (defined(__CUDACC__) || defined(__HIPCC__)) && defined(GPUCA_GPUCODE) && !defined(GPUCA_GPUCODE_HOSTONLY)
57-
typedef GPUCA_CUB::WarpScan<T> WarpScan;
57+
typedef GPUCA_CUB_NAMESPACE::WarpScan<T> WarpScan;
5858
union {
5959
typename WarpScan::TempStorage cubWarpTmpMem;
6060
};
@@ -65,9 +65,9 @@ class GPUKernelTemplate
6565
struct GPUSharedMemoryScan64 {
6666
// Provides the shared memory resources for CUB collectives
6767
#if (defined(__CUDACC__) || defined(__HIPCC__)) && defined(GPUCA_GPUCODE) && !defined(GPUCA_GPUCODE_HOSTONLY)
68-
typedef GPUCA_CUB::BlockScan<T, I> BlockScan;
69-
typedef GPUCA_CUB::BlockReduce<T, I> BlockReduce;
70-
typedef GPUCA_CUB::WarpScan<T> WarpScan;
68+
typedef GPUCA_CUB_NAMESPACE::BlockScan<T, I> BlockScan;
69+
typedef GPUCA_CUB_NAMESPACE::BlockReduce<T, I> BlockReduce;
70+
typedef GPUCA_CUB_NAMESPACE::WarpScan<T> WarpScan;
7171
union {
7272
typename BlockScan::TempStorage cubTmpMem;
7373
typename BlockReduce::TempStorage cubReduceTmpMem;
@@ -110,6 +110,6 @@ class GPUitoa : public GPUKernelTemplate
110110

111111
} // namespace o2::gpu
112112

113-
#undef GPUCA_CUB
113+
#undef GPUCA_CUB_NAMESPACE
114114

115115
#endif

0 commit comments

Comments
 (0)