Skip to content

Commit 1c7fc84

Browse files
authored
Make DCA fitter work on GPU (#13510)
1 parent 60f0dd4 commit 1c7fc84

File tree

7 files changed

+405
-387
lines changed

7 files changed

+405
-387
lines changed

Common/DCAFitter/GPU/cuda/DCAFitterN.cu

Lines changed: 45 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -36,71 +36,84 @@ namespace o2::vertexing::device
3636
{
3737
namespace kernel
3838
{
// Diff view of kernel::printKernel. Rows preceded by `N-` are the removed
// version (hard-wired to DCAFitterN<2>); rows preceded by `N+` are the added
// version, templated on the fitter type.
// In the added version, threads with threadIdx.x == 0 print a banner that
// includes Fitter::getNProngs(), call ft->print(), and print a closing rule.
39-
GPUg() void printKernel(o2::vertexing::DCAFitterN<2>* ft)
39+
template <typename Fitter>
40+
GPUg() void printKernel(Fitter* ft)
4041
{
4142
if (threadIdx.x == 0) {
42-
printf(" =============== GPU DCA Fitter ================\n");
43+
printf(" =============== GPU DCA Fitter %d prongs ================\n", Fitter::getNProngs());
4344
ft->print();
44-
printf(" ===============================================\n");
45+
printf(" =========================================================\n");
4546
}
4647
}
4748

// Diff view of kernel::processKernel. The removed version took exactly two
// TrackParCov pointers (t1, t2) followed by the result pointer; the added
// version takes the result pointer second and then a parameter pack of track
// pointers, dereferencing each one when forwarding to ft->process().
// The value returned by ft->process() is stored through `res`.
// NOTE(review): no thread-index guard is visible here, so every thread of the
// launch calls process() on the same fitter and writes the same *res.
48-
GPUg() void processKernel(o2::vertexing::DCAFitterN<2>* ft, o2::track::TrackParCov* t1, o2::track::TrackParCov* t2, int* res)
49+
template <typename Fitter, typename... Tr>
50+
GPUg() void processKernel(Fitter* ft, int* res, Tr*... tracks)
4951
{
50-
*res = ft->process(*t1, *t2);
52+
*res = ft->process(*tracks...);
5153
}
5254
} // namespace kernel
5355

// Diff view of the host-side print() wrapper. The removed version took a
// DCAFitterN<2>* followed by the launch dimensions; the added version is
// templated on the fitter type and takes (nBlocks, nThreads, ft) with the
// fitter passed by reference.
// Added version, step by step: allocate device memory for one Fitter, copy the
// host fitter into it, launch kernel::printKernel with <<<nBlocks, nThreads>>>,
// then check for a launch error and synchronize with the device.
// NOTE(review): ft_device is not released with cudaFree anywhere in this
// function (neither in the removed nor in the added lines).
54-
void print(o2::vertexing::DCAFitterN<2>* ft,
55-
const int nBlocks,
56-
const int nThreads)
56+
/// CPU handlers
57+
template <typename Fitter>
58+
void print(const int nBlocks,
59+
const int nThreads,
60+
Fitter& ft)
5761
{
58-
DCAFitterN<2>* ft_device;
59-
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&ft_device), sizeof(o2::vertexing::DCAFitterN<2>)));
60-
gpuCheckError(cudaMemcpy(ft_device, ft, sizeof(o2::vertexing::DCAFitterN<2>), cudaMemcpyHostToDevice));
62+
Fitter* ft_device;
63+
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&ft_device), sizeof(Fitter)));
64+
gpuCheckError(cudaMemcpy(ft_device, &ft, sizeof(Fitter), cudaMemcpyHostToDevice));
6165

6266
kernel::printKernel<<<nBlocks, nThreads>>>(ft_device);
6367

6468
gpuCheckError(cudaPeekAtLastError());
6569
gpuCheckError(cudaDeviceSynchronize());
6670
}
6771

// Diff view of the host-side process() wrapper. The removed version took a
// DCAFitterN<2>* and exactly two TrackParCov references; the added version is
// templated on the fitter type and on a pack of track arguments, and takes
// (nBlocks, nThreads, fitter, args...).
// Added version, step by step:
//   1. allocate device memory for the fitter and for the int result;
//   2. for each pack element, allocate one TrackParCov on the device and copy
//      the host track into it;
//   3. copy the fitter to the device and launch kernel::processKernel with the
//      device track pointers expanded from tracks_device via std::apply;
//   4. check the launch, synchronize, then copy the result, the fitter and
//      every track back to the host, freeing each device track buffer;
//   5. free the result buffer and return the result.
// NOTE(review): the removed version called cudaFree(ft_device); no cudaFree of
// ft_device appears among the added lines, so that allocation looks leaked.
// NOTE(review): tracks_device has Fitter::getNProngs() slots but is indexed once
// per pack element; nothing visible here checks sizeof...(Tr) against that size.
// NOTE(review): the copies use sizeof(o2::track::TrackParCov) regardless of Tr;
// presumably every Tr is TrackParCov, as in the explicit instantiations - confirm.
68-
int process(o2::vertexing::DCAFitterN<2>* fitter,
69-
o2::track::TrackParCov& track1,
70-
o2::track::TrackParCov& track2,
71-
const int nBlocks,
72-
const int nThreads)
72+
template <typename Fitter, class... Tr>
73+
int process(const int nBlocks,
74+
const int nThreads,
75+
Fitter& fitter,
76+
Tr&... args)
7377
{
74-
DCAFitterN<2>* ft_device;
75-
o2::track::TrackParCov* t1_device;
76-
o2::track::TrackParCov* t2_device;
78+
Fitter* ft_device;
79+
std::array<o2::track::TrackParCov*, Fitter::getNProngs()> tracks_device;
7780
int result, *result_device;
7881

79-
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&ft_device), sizeof(o2::vertexing::DCAFitterN<2>)));
80-
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&t1_device), sizeof(o2::track::TrackParCov)));
81-
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&t2_device), sizeof(o2::track::TrackParCov)));
82+
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&ft_device), sizeof(Fitter)));
8283
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&result_device), sizeof(int)));
8384

84-
gpuCheckError(cudaMemcpy(ft_device, fitter, sizeof(o2::vertexing::DCAFitterN<2>), cudaMemcpyHostToDevice));
85-
gpuCheckError(cudaMemcpy(t1_device, &track1, sizeof(o2::track::TrackParCov), cudaMemcpyHostToDevice));
86-
gpuCheckError(cudaMemcpy(t2_device, &track2, sizeof(o2::track::TrackParCov), cudaMemcpyHostToDevice));
// The lambda below is invoked once per element of the `args` pack (comma fold);
// iArg selects the tracks_device slot and is advanced on every invocation.
85+
int iArg{0};
86+
([&] {
87+
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&(tracks_device[iArg])), sizeof(o2::track::TrackParCov)));
88+
gpuCheckError(cudaMemcpy(tracks_device[iArg], &args, sizeof(o2::track::TrackParCov), cudaMemcpyHostToDevice));
89+
++iArg;
90+
}(),
91+
...);
8792

88-
kernel::processKernel<<<nBlocks, nThreads>>>(ft_device, t1_device, t2_device, result_device);
93+
gpuCheckError(cudaMemcpy(ft_device, &fitter, sizeof(Fitter), cudaMemcpyHostToDevice));
94+
95+
std::apply([&](auto&&... args) { kernel::processKernel<<<nBlocks, nThreads>>>(ft_device, result_device, args...); }, tracks_device);
8996

9097
gpuCheckError(cudaPeekAtLastError());
9198
gpuCheckError(cudaDeviceSynchronize());
9299

93100
gpuCheckError(cudaMemcpy(&result, result_device, sizeof(int), cudaMemcpyDeviceToHost));
94-
gpuCheckError(cudaMemcpy(fitter, ft_device, sizeof(o2::vertexing::DCAFitterN<2>), cudaMemcpyDeviceToHost));
95-
gpuCheckError(cudaMemcpy(&track1, t1_device, sizeof(o2::track::TrackParCov), cudaMemcpyDeviceToHost));
96-
gpuCheckError(cudaMemcpy(&track2, t2_device, sizeof(o2::track::TrackParCov), cudaMemcpyDeviceToHost));
97-
gpuCheckError(cudaFree(ft_device));
98-
gpuCheckError(cudaFree(t1_device));
99-
gpuCheckError(cudaFree(t2_device));
101+
gpuCheckError(cudaMemcpy(&fitter, ft_device, sizeof(Fitter), cudaMemcpyDeviceToHost));
// Second pass over the pack: copy each device track back into its host
// argument and free that device buffer; iArg is reset so slots match pass one.
102+
iArg = 0;
103+
([&] {
104+
gpuCheckError(cudaMemcpy(&args, tracks_device[iArg], sizeof(o2::track::TrackParCov), cudaMemcpyDeviceToHost));
105+
gpuCheckError(cudaFree(tracks_device[iArg]));
106+
++iArg;
107+
}(),
108+
...);
100109

101110
gpuCheckError(cudaFree(result_device));
102111

103112
return result;
104113
}
105114

// Explicit instantiations of the host wrappers (all four rows are added lines):
// process() for DCAFitterN<2> with two TrackParCov tracks and for DCAFitterN<3>
// with three, and print() for DCAFitterN<2> and DCAFitterN<3>.
115+
template int process(const int, const int, o2::vertexing::DCAFitterN<2>&, o2::track::TrackParCov&, o2::track::TrackParCov&);
116+
template int process(const int, const int, o2::vertexing::DCAFitterN<3>&, o2::track::TrackParCov&, o2::track::TrackParCov&, o2::track::TrackParCov&);
117+
template void print(const int, const int, o2::vertexing::DCAFitterN<2>&);
118+
template void print(const int, const int, o2::vertexing::DCAFitterN<3>&);
106119
} // namespace o2::vertexing::device

0 commit comments

Comments
 (0)