@@ -871,37 +871,53 @@ GPUg() void printCellSeeds(CellSeed* seed, int nCells, const unsigned int tId =
871871}
872872
873873template <typename T>
874- GPUhi () void cubExclusiveScanInPlace (T* in_out, int num_items, cudaStream_t stream = nullptr )
874+ GPUhi () void cubExclusiveScanInPlace (T* in_out, int num_items, cudaStream_t stream = nullptr , ExternalAllocator* alloc = nullptr )
875875{
876876 void * d_temp_storage = nullptr ;
877877 size_t temp_storage_bytes = 0 ;
878878 GPUChkErrS (cub::DeviceScan::ExclusiveSum (d_temp_storage, temp_storage_bytes, in_out, in_out, num_items, stream));
879- GPUChkErrS (cudaMallocAsync (&d_temp_storage, temp_storage_bytes, stream));
879+ if (alloc) {
880+ d_temp_storage = alloc->allocate (temp_storage_bytes);
881+ } else {
882+ GPUChkErrS (cudaMallocAsync (&d_temp_storage, temp_storage_bytes, stream));
883+ }
880884 GPUChkErrS (cub::DeviceScan::ExclusiveSum (d_temp_storage, temp_storage_bytes, in_out, in_out, num_items, stream));
881- GPUChkErrS (cudaFreeAsync (d_temp_storage, stream));
885+ if (alloc) {
886+ alloc->deallocate (reinterpret_cast <char *>(d_temp_storage), temp_storage_bytes);
887+ } else {
888+ GPUChkErrS (cudaFreeAsync (d_temp_storage, stream));
889+ }
882890}
883891
884892template <typename Vector>
885- GPUhi () void cubExclusiveScanInPlace (Vector& in_out, int num_items, cudaStream_t stream = nullptr )
893+ GPUhi () void cubExclusiveScanInPlace (Vector& in_out, int num_items, cudaStream_t stream = nullptr , ExternalAllocator* alloc = nullptr )
886894{
887- cubExclusiveScanInPlace (thrust::raw_pointer_cast (in_out.data ()), num_items, stream);
895+ cubExclusiveScanInPlace (thrust::raw_pointer_cast (in_out.data ()), num_items, stream, alloc );
888896}
889897
890898template <typename T>
891- GPUhi () void cubInclusiveScanInPlace (T* in_out, int num_items, cudaStream_t stream = nullptr )
899+ GPUhi () void cubInclusiveScanInPlace (T* in_out, int num_items, cudaStream_t stream = nullptr , ExternalAllocator* alloc = nullptr )
892900{
893901 void * d_temp_storage = nullptr ;
894902 size_t temp_storage_bytes = 0 ;
895903 GPUChkErrS (cub::DeviceScan::InclusiveSum (d_temp_storage, temp_storage_bytes, in_out, in_out, num_items, stream));
896- GPUChkErrS (cudaMallocAsync (&d_temp_storage, temp_storage_bytes, stream));
904+ if (alloc) {
905+ d_temp_storage = alloc->allocate (temp_storage_bytes);
906+ } else {
907+ GPUChkErrS (cudaMallocAsync (&d_temp_storage, temp_storage_bytes, stream));
908+ }
897909 GPUChkErrS (cub::DeviceScan::InclusiveSum (d_temp_storage, temp_storage_bytes, in_out, in_out, num_items, stream));
898- GPUChkErrS (cudaFreeAsync (d_temp_storage, stream));
910+ if (alloc) {
911+ alloc->deallocate (reinterpret_cast <char *>(d_temp_storage), temp_storage_bytes);
912+ } else {
913+ GPUChkErrS (cudaFreeAsync (d_temp_storage, stream));
914+ }
899915}
900916
901917template <typename Vector>
902- GPUhi () void cubInclusiveScanInPlace (Vector& in_out, int num_items, cudaStream_t stream = nullptr )
918+ GPUhi () void cubInclusiveScanInPlace (Vector& in_out, int num_items, cudaStream_t stream = nullptr , ExternalAllocator* alloc = nullptr )
903919{
904- cubInclusiveScanInPlace (thrust::raw_pointer_cast (in_out.data ()), num_items, stream);
920+ cubInclusiveScanInPlace (thrust::raw_pointer_cast (in_out.data ()), num_items, stream, alloc );
905921}
906922} // namespace gpu
907923
@@ -932,6 +948,7 @@ void countTrackletsInROFsHandler(const IndexTableUtils* utils,
932948 bounded_vector<float >& resolutions,
933949 std::vector<float >& radii,
934950 bounded_vector<float >& mulScatAng,
951+ o2::its::ExternalAllocator* alloc,
935952 const int nBlocks,
936953 const int nThreads,
937954 gpu::Streams& streams)
@@ -964,7 +981,7 @@ void countTrackletsInROFsHandler(const IndexTableUtils* utils,
964981 resolutions[iLayer],
965982 radii[iLayer + 1 ] - radii[iLayer],
966983 mulScatAng[iLayer]);
967- gpu::cubExclusiveScanInPlace (trackletsLUTsHost[iLayer], nClusters[iLayer] + 1 , streams[iLayer].get ());
984+ gpu::cubExclusiveScanInPlace (trackletsLUTsHost[iLayer], nClusters[iLayer] + 1 , streams[iLayer].get (), alloc );
968985 }
969986}
970987
@@ -998,6 +1015,7 @@ void computeTrackletsInROFsHandler(const IndexTableUtils* utils,
9981015 bounded_vector<float >& resolutions,
9991016 std::vector<float >& radii,
10001017 bounded_vector<float >& mulScatAng,
1018+ o2::its::ExternalAllocator* alloc,
10011019 const int nBlocks,
10021020 const int nThreads,
10031021 gpu::Streams& streams)
@@ -1043,7 +1061,7 @@ void computeTrackletsInROFsHandler(const IndexTableUtils* utils,
10431061 spanTracklets[iLayer],
10441062 trackletsLUTsHost[iLayer],
10451063 nTracklets[iLayer]);
1046- gpu::cubExclusiveScanInPlace (trackletsLUTsHost[iLayer], nClusters[iLayer] + 1 , streams[iLayer].get ());
1064+ gpu::cubExclusiveScanInPlace (trackletsLUTsHost[iLayer], nClusters[iLayer] + 1 , streams[iLayer].get (), alloc );
10471065 }
10481066 }
10491067}
@@ -1064,6 +1082,7 @@ void countCellsHandler(
10641082 const float maxChi2ClusterAttachment,
10651083 const float cellDeltaTanLambdaSigma,
10661084 const float nSigmaCut,
1085+ o2::its::ExternalAllocator* alloc,
10671086 const int nBlocks,
10681087 const int nThreads,
10691088 gpu::Streams& streams)
@@ -1083,7 +1102,7 @@ void countCellsHandler(
10831102 maxChi2ClusterAttachment, // const float
10841103 cellDeltaTanLambdaSigma, // const float
10851104 nSigmaCut); // const float
1086- gpu::cubExclusiveScanInPlace (cellsLUTsHost, nTracklets + 1 , streams[layer].get ());
1105+ gpu::cubExclusiveScanInPlace (cellsLUTsHost, nTracklets + 1 , streams[layer].get (), alloc );
10871106}
10881107
10891108void computeCellsHandler (
@@ -1136,6 +1155,7 @@ void countCellNeighboursHandler(CellSeed** cellsLayersDevice,
11361155 const unsigned int nCells,
11371156 const unsigned int nCellsNext,
11381157 const int maxCellNeighbours,
1158+ o2::its::ExternalAllocator* alloc,
11391159 const int nBlocks,
11401160 const int nThreads,
11411161 gpu::Stream& stream)
@@ -1153,8 +1173,8 @@ void countCellNeighboursHandler(CellSeed** cellsLayersDevice,
11531173 layerIndex,
11541174 nCells,
11551175 maxCellNeighbours);
1156- gpu::cubInclusiveScanInPlace (neighboursLUT, nCellsNext, stream.get ());
1157- gpu::cubExclusiveScanInPlace (neighboursIndexTable, nCells + 1 , stream.get ());
1176+ gpu::cubInclusiveScanInPlace (neighboursLUT, nCellsNext, stream.get (), alloc );
1177+ gpu::cubExclusiveScanInPlace (neighboursIndexTable, nCells + 1 , stream.get (), alloc );
11581178}
11591179
11601180void computeCellNeighboursHandler (CellSeed** cellsLayersDevice,
@@ -1219,19 +1239,18 @@ void processNeighboursHandler(const int startLayer,
12191239 gsl::span<int *> neighboursDeviceLUTs,
12201240 const TrackingFrameInfo** foundTrackingFrameInfo,
12211241 bounded_vector<CellSeed>& seedsHost,
1222- o2::its::ExternalAllocator* allocator,
12231242 const float bz,
12241243 const float maxChi2ClusterAttachment,
12251244 const float maxChi2NDF,
12261245 const o2::base::Propagator* propagator,
12271246 const o2::base::PropagatorF::MatCorrType matCorrType,
1247+ o2::its::ExternalAllocator* alloc,
12281248 const int nBlocks,
12291249 const int nThreads)
12301250{
1231- auto allocInt = gpu::TypedAllocator<int >(allocator);
1232- auto allocCellSeed = gpu::TypedAllocator<CellSeed>(allocator);
1233- thrust::device_vector<int , gpu::TypedAllocator<int >> foundSeedsTable (nCells[startLayer] + 1 , 0 , allocInt); // Shortcut: device_vector skips central memory management, we are relying on the contingency.
1234- // TODO: fix this.
1251+ auto allocInt = gpu::TypedAllocator<int >(alloc);
1252+ auto allocCellSeed = gpu::TypedAllocator<CellSeed>(alloc);
1253+ thrust::device_vector<int , gpu::TypedAllocator<int >> foundSeedsTable (nCells[startLayer] + 1 , 0 , allocInt);
12351254
12361255 gpu::processNeighboursKernel<true ><<<nBlocks, nThreads>>> (
12371256 startLayer,
@@ -1251,7 +1270,7 @@ void processNeighboursHandler(const int startLayer,
12511270 maxChi2ClusterAttachment,
12521271 propagator,
12531272 matCorrType);
1254- gpu::cubExclusiveScanInPlace (foundSeedsTable, nCells[startLayer] + 1 );
1273+ gpu::cubExclusiveScanInPlace (foundSeedsTable, nCells[startLayer] + 1 , gpu::Stream::DefaultStream, alloc );
12551274
12561275 thrust::device_vector<int , gpu::TypedAllocator<int >> updatedCellId (foundSeedsTable.back (), 0 , allocInt);
12571276 thrust::device_vector<CellSeed, gpu::TypedAllocator<CellSeed>> updatedCellSeed (foundSeedsTable.back (), allocCellSeed);
@@ -1306,7 +1325,7 @@ void processNeighboursHandler(const int startLayer,
13061325 maxChi2ClusterAttachment,
13071326 propagator,
13081327 matCorrType);
1309- gpu::cubExclusiveScanInPlace (foundSeedsTable, foundSeedsTable.size ());
1328+ gpu::cubExclusiveScanInPlace (foundSeedsTable, foundSeedsTable.size (), gpu::Stream::DefaultStream, alloc );
13101329
13111330 auto foundSeeds{foundSeedsTable.back ()};
13121331 updatedCellId.resize (foundSeeds);
@@ -1402,6 +1421,7 @@ template void countTrackletsInROFsHandler<7>(const IndexTableUtils* utils,
14021421 bounded_vector<float >& resolutions,
14031422 std::vector<float >& radii,
14041423 bounded_vector<float >& mulScatAng,
1424+ o2::its::ExternalAllocator* alloc,
14051425 const int nBlocks,
14061426 const int nThreads,
14071427 gpu::Streams& streams);
@@ -1435,6 +1455,7 @@ template void computeTrackletsInROFsHandler<7>(const IndexTableUtils* utils,
14351455 bounded_vector<float >& resolutions,
14361456 std::vector<float >& radii,
14371457 bounded_vector<float >& mulScatAng,
1458+ o2::its::ExternalAllocator* alloc,
14381459 const int nBlocks,
14391460 const int nThreads,
14401461 gpu::Streams& streams);
@@ -1449,12 +1470,12 @@ template void processNeighboursHandler<7>(const int startLayer,
14491470 gsl::span<int *> neighboursDeviceLUTs,
14501471 const TrackingFrameInfo** foundTrackingFrameInfo,
14511472 bounded_vector<CellSeed>& seedsHost,
1452- o2::its::ExternalAllocator*,
14531473 const float bz,
14541474 const float maxChi2ClusterAttachment,
14551475 const float maxChi2NDF,
14561476 const o2::base::Propagator* propagator,
14571477 const o2::base::PropagatorF::MatCorrType matCorrType,
1478+ o2::its::ExternalAllocator* alloc,
14581479 const int nBlocks,
14591480 const int nThreads);
14601481} // namespace o2::its
0 commit comments