Skip to content

Commit 6665ddd

Browse files
Felix Schlepperf3sch
authored andcommitted
ITS: fix TypedAllocator for cuda thrust
Signed-off-by: Felix Schlepper <felix.schlepper@cern.ch>
1 parent c26672b commit 6665ddd

File tree

1 file changed

+28
-15
lines changed

1 file changed

+28
-15
lines changed

Detectors/ITSMFT/ITS/tracking/GPU/cuda/TrackingKernels.cu

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,30 +58,43 @@ namespace gpu
5858
{
5959

6060
template <typename T>
61-
class TypedAllocator : public thrust::device_allocator<T>
62-
{
63-
public:
61+
struct TypedAllocator {
6462
using value_type = T;
65-
using pointer = T*;
63+
using pointer = thrust::device_ptr<T>;
64+
using const_pointer = thrust::device_ptr<const T>;
65+
using size_type = std::size_t;
66+
using difference_type = std::ptrdiff_t;
67+
68+
TypedAllocator() noexcept : mInternalAllocator(nullptr) {}
69+
explicit TypedAllocator(ExternalAllocator* a) noexcept : mInternalAllocator(a) {}
6670

6771
template <typename U>
68-
struct rebind {
69-
using other = TypedAllocator<U>;
70-
};
72+
TypedAllocator(const TypedAllocator<U>& o) noexcept : mInternalAllocator(o.mInternalAllocator)
73+
{
74+
}
7175

72-
explicit TypedAllocator(ExternalAllocator* allocPtr)
73-
: mInternalAllocator(allocPtr) {}
76+
pointer allocate(size_type n)
77+
{
78+
void* raw = mInternalAllocator->allocate(n * sizeof(T));
79+
return thrust::device_pointer_cast(static_cast<T*>(raw));
80+
}
7481

75-
T* allocate(size_t n)
82+
void deallocate(pointer p, size_type n) noexcept
7683
{
77-
return reinterpret_cast<T*>(mInternalAllocator->allocate(n * sizeof(T)));
84+
if (!p) {
85+
return;
86+
}
87+
void* raw = thrust::raw_pointer_cast(p);
88+
mInternalAllocator->deallocate(static_cast<char*>(raw), n * sizeof(T));
7889
}
7990

80-
void deallocate(T* p, size_t n)
91+
bool operator==(TypedAllocator const& o) const noexcept
92+
{
93+
return mInternalAllocator == o.mInternalAllocator;
94+
}
95+
bool operator!=(TypedAllocator const& o) const noexcept
8196
{
82-
char* raw_ptr = reinterpret_cast<char*>(p);
83-
size_t bytes = n * sizeof(T);
84-
mInternalAllocator->deallocate(raw_ptr, bytes); // redundant as internal dealloc is no-op.
97+
return !(*this == o);
8598
}
8699

87100
private:

0 commit comments

Comments
 (0)