@@ -36,71 +36,84 @@ namespace o2::vertexing::device
3636{
3737namespace kernel
3838{
/// Debug kernel: prints the state of a fitter object through its print() method.
/// Only threads with threadIdx.x == 0 produce output, so the banner appears once
/// per launched block.
/// \tparam Fitter fitter type; must provide a static getNProngs() and a print()
///                method that can be called from device code
/// \param ft pointer to the fitter, dereferenced inside the kernel
template <typename Fitter>
GPUg() void printKernel(Fitter* ft)
{
  // Every thread but the first of its block has nothing to do.
  if (threadIdx.x != 0) {
    return;
  }
  printf(" =============== GPU DCA Fitter %d prongs ================\n ", Fitter::getNProngs());
  ft->print();
  printf(" ========================================================= \n ");
}
4748
/// Runs Fitter::process on the given tracks and stores its return value in *res.
/// The fit operates on a single fitter object and a single result slot, so it is
/// executed by exactly one thread (block 0, thread 0) regardless of the launch
/// configuration; letting every launched thread call process() on the same
/// object and write the same *res would be a data race.
/// \tparam Fitter fitter type providing process(Tr&...)
/// \tparam Tr     track types, one per prong
/// \param ft      pointer to the fitter, dereferenced inside the kernel
/// \param res     output: return value of Fitter::process
/// \param tracks  pointers to the input tracks, dereferenced inside the kernel
template <typename Fitter, typename... Tr>
GPUg() void processKernel(Fitter* ft, int* res, Tr*... tracks)
{
  // Host wrappers launch this kernel with 1D grids/blocks, hence only .x is checked.
  if (blockIdx.x != 0 || threadIdx.x != 0) {
    return;
  }
  *res = ft->process(*tracks...);
}
5254} // namespace kernel
5355
/// CPU handlers

/// Prints a fitter from the GPU side: copies \p ft into a temporary device
/// buffer, launches kernel::printKernel on that copy, waits for completion and
/// releases the buffer again.
/// \tparam Fitter  fitter type, copied bytewise with cudaMemcpy
/// \param nBlocks  number of blocks used for the kernel launch
/// \param nThreads number of threads per block used for the kernel launch
/// \param ft       host-side fitter to print; it is not modified
template <typename Fitter>
void print(const int nBlocks,
           const int nThreads,
           Fitter& ft)
{
  Fitter* ft_device{nullptr};
  gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&ft_device), sizeof(Fitter)));
  gpuCheckError(cudaMemcpy(ft_device, &ft, sizeof(Fitter), cudaMemcpyHostToDevice));

  kernel::printKernel<<<nBlocks, nThreads>>>(ft_device);

  gpuCheckError(cudaPeekAtLastError());
  gpuCheckError(cudaDeviceSynchronize());

  // The device copy is only needed for the duration of the kernel: release it,
  // otherwise every call leaks sizeof(Fitter) bytes of device memory.
  gpuCheckError(cudaFree(ft_device));
}
6771
/// Runs Fitter::process on the GPU for one set of tracks.
/// The fitter and every track are copied to freshly allocated device buffers,
/// kernel::processKernel is launched on them, and afterwards the fitter and the
/// tracks are copied back into the caller's objects. All device buffers are
/// released before returning.
/// \tparam Fitter  fitter type, copied bytewise with cudaMemcpy
/// \tparam Tr      track types; exactly Fitter::getNProngs() tracks must be
///                 passed, each copied as an o2::track::TrackParCov
/// \param nBlocks  number of blocks used for the kernel launch
/// \param nThreads number of threads per block used for the kernel launch
/// \param fitter   in/out: overwritten with the device-side state after the fit
/// \param args     in/out: tracks, overwritten with their device-side state after the fit
/// \return the value written by kernel::processKernel, i.e. the return value of Fitter::process
template <typename Fitter, class... Tr>
int process(const int nBlocks,
            const int nThreads,
            Fitter& fitter,
            Tr&... args)
{
  // tracks_device has exactly getNProngs() slots: fewer tracks would hand
  // uninitialised pointers to the kernel, more would write past the array.
  static_assert(sizeof...(Tr) == static_cast<std::size_t>(Fitter::getNProngs()),
                "number of tracks must match the number of prongs of the fitter");

  Fitter* ft_device{nullptr};
  std::array<o2::track::TrackParCov*, Fitter::getNProngs()> tracks_device{};
  int result{0}, *result_device{nullptr};

  gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&ft_device), sizeof(Fitter)));
  gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&result_device), sizeof(int)));

  // Allocate and upload one device buffer per track; the comma fold visits the
  // pack in order, iArg tracks the matching slot in tracks_device.
  int iArg{0};
  ([&] {
    gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&(tracks_device[iArg])), sizeof(o2::track::TrackParCov)));
    gpuCheckError(cudaMemcpy(tracks_device[iArg], &args, sizeof(o2::track::TrackParCov), cudaMemcpyHostToDevice));
    ++iArg;
  }(),
   ...);

  gpuCheckError(cudaMemcpy(ft_device, &fitter, sizeof(Fitter), cudaMemcpyHostToDevice));

  // Expand the array of device pointers back into individual kernel arguments.
  std::apply([&](auto&&... devTracks) { kernel::processKernel<<<nBlocks, nThreads>>>(ft_device, result_device, devTracks...); }, tracks_device);

  gpuCheckError(cudaPeekAtLastError());
  gpuCheckError(cudaDeviceSynchronize());

  gpuCheckError(cudaMemcpy(&result, result_device, sizeof(int), cudaMemcpyDeviceToHost));
  gpuCheckError(cudaMemcpy(&fitter, ft_device, sizeof(Fitter), cudaMemcpyDeviceToHost));

  // Download each track and release its device buffer, in the same order as the upload.
  iArg = 0;
  ([&] {
    gpuCheckError(cudaMemcpy(&args, tracks_device[iArg], sizeof(o2::track::TrackParCov), cudaMemcpyDeviceToHost));
    gpuCheckError(cudaFree(tracks_device[iArg]));
    ++iArg;
  }(),
   ...);

  // Release the remaining device buffers; the fitter copy must be freed as well,
  // otherwise every call leaks sizeof(Fitter) bytes of device memory.
  gpuCheckError(cudaFree(ft_device));
  gpuCheckError(cudaFree(result_device));

  return result;
}
105114
// Explicit instantiations of the host entry points, grouped by fitter type.
// 2-prong fitter
template void print(const int, const int, o2::vertexing::DCAFitterN<2>&);
template int process(const int, const int, o2::vertexing::DCAFitterN<2>&, o2::track::TrackParCov&, o2::track::TrackParCov&);
// 3-prong fitter
template void print(const int, const int, o2::vertexing::DCAFitterN<3>&);
template int process(const int, const int, o2::vertexing::DCAFitterN<3>&, o2::track::TrackParCov&, o2::track::TrackParCov&, o2::track::TrackParCov&);
106119} // namespace o2::vertexing::device
0 commit comments