1717// / \author Marcel Lesch <marcel.lesch@tum.de>, TUM
1818// / \author Alexandre Bigot <alexandre.bigot@cern.ch>, Strasbourg University
1919// / \author Biao Zhang <biao.zhang@cern.ch>, CCNU
20+ // / \author Antonio Palasciano <antonio.palasciano@cern.ch>, INFN Bari
2021
22+ #include < string>
2123#if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
2224#include < onnxruntime/core/session/experimental_onnxruntime_cxx_api.h> // needed for HFFilterHelpers, to be fixed
2325#else
@@ -54,8 +56,12 @@ struct HfFilterPrepareMlSamples { // Main struct
5456
5557 // parameters for production of training samples
5658 Configurable<bool > fillSignal{" fillSignal" , true , " Flag to fill derived tables with signal for ML trainings" };
57- Configurable<bool > fillBackground{ " fillBackground " , true , " Flag to fill derived tables with background for ML trainings" };
59+ Configurable<bool > fillOnlyBackground{ " fillOnlyBackground " , true , " Flag to fill derived tables with background for ML trainings" };
5860 Configurable<float > downSampleBkgFactor{" downSampleBkgFactor" , 1 ., " Fraction of background candidates to keep for ML trainings" };
61+ Configurable<float > massSbLeftMin{" massSbLeftMin" , 1.72 , " Left Sideband Lower Minv limit 2 Prong" };
62+ Configurable<float > massSbLeftMax{" massSbLeftMax" , 1.78 , " Left Sideband Upper Minv limit 2 Prong" };
63+ Configurable<float > massSbRightMin{" massSbRightMin" , 1.94 , " Right Sideband Lower Minv limit 2 Prong" };
64+ Configurable<float > massSbRightMax{" massSbRightMax" , 1.98 , " Right Sideband Upper Minv limit 2 Prong" };
5965
6066 // CCDB configuration
6167 o2::ccdb::CcdbApi ccdbApi;
@@ -67,6 +73,9 @@ struct HfFilterPrepareMlSamples { // Main struct
6773 o2::base::Propagator::MatCorrType noMatCorr = o2::base::Propagator::MatCorrType::USEMatCorrNONE;
6874 int currentRun = 0 ; // needed to detect if the run changed and trigger update of calibrations etc.
6975
76+ // helper object
77+ HfFilterHelper helper;
78+
7079 void init (InitContext&)
7180 {
7281 ccdb->setURL (url.value );
@@ -76,14 +85,147 @@ struct HfFilterPrepareMlSamples { // Main struct
7685 ccdbApi.init (url);
7786 }
7887
88+ using BigTracksPID = soa::Join<aod::Tracks, aod::TracksExtra, aod::TracksDCA, aod::pidTPCFullPi, aod::pidTOFFullPi, aod::pidTPCFullKa, aod::pidTOFFullKa, aod::pidTPCFullPr, aod::pidTOFFullPr>;
7989 using BigTracksMCPID = soa::Join<aod::Tracks, aod::TracksExtra, aod::TracksDCA, aod::pidTPCFullPi, aod::pidTOFFullPi, aod::pidTPCFullKa, aod::pidTOFFullKa, aod::pidTPCFullPr, aod::pidTOFFullPr, aod::McTrackLabels>;
8090
81- void process (aod::Hf2Prongs const & cand2Prongs,
82- aod::Hf3Prongs const & cand3Prongs,
83- aod::McParticles const & mcParticles,
84- soa::Join<aod::Collisions, aod::McCollisionLabels> const & collisions,
85- BigTracksMCPID const &,
86- aod::BCsWithTimestamps const &)
91+ void processData2Prong (aod::Hf2Prongs const & cand2Prongs,
92+ aod::Collisions const & collisions,
93+ BigTracksPID const &,
94+ aod::BCsWithTimestamps const &)
95+ {
96+ for (const auto & cand2Prong : cand2Prongs) { // start loop over 2 prongs
97+
98+ auto thisCollId = cand2Prong.collisionId ();
99+ auto collision = collisions.rawIteratorAt (thisCollId);
100+ auto bc = collision.bc_as <aod::BCsWithTimestamps>();
101+
102+ if (currentRun != bc.runNumber ()) {
103+ o2::parameters::GRPMagField* grpo = ccdb->getForTimeStamp <o2::parameters::GRPMagField>(ccdbPathGrpMag, bc.timestamp ());
104+ o2::base::Propagator::initFieldFromGRP (grpo);
105+ currentRun = bc.runNumber ();
106+ }
107+
108+ auto trackPos = cand2Prong.prong0_as <BigTracksPID>(); // positive daughter
109+ auto trackNeg = cand2Prong.prong1_as <BigTracksPID>(); // negative daughter
110+
111+ auto trackParPos = getTrackPar (trackPos);
112+ auto trackParNeg = getTrackPar (trackNeg);
113+ o2::gpu::gpustd::array<float , 2 > dcaPos{trackPos.dcaXY (), trackPos.dcaZ ()};
114+ o2::gpu::gpustd::array<float , 2 > dcaNeg{trackNeg.dcaXY (), trackNeg.dcaZ ()};
115+ std::array<float , 3 > pVecPos{trackPos.pVector ()};
116+ std::array<float , 3 > pVecNeg{trackNeg.pVector ()};
117+ if (trackPos.collisionId () != thisCollId) {
118+ o2::base::Propagator::Instance ()->propagateToDCABxByBz ({collision.posX (), collision.posY (), collision.posZ ()}, trackParPos, 2 .f , noMatCorr, &dcaPos);
119+ getPxPyPz (trackParPos, pVecPos);
120+ }
121+ if (trackNeg.collisionId () != thisCollId) {
122+ o2::base::Propagator::Instance ()->propagateToDCABxByBz ({collision.posX (), collision.posY (), collision.posZ ()}, trackParNeg, 2 .f , noMatCorr, &dcaNeg);
123+ getPxPyPz (trackParNeg, pVecNeg);
124+ }
125+
126+ auto pVec2Prong = RecoDecay::pVec (pVecPos, pVecNeg);
127+ auto pt2Prong = RecoDecay::pt (pVec2Prong);
128+
129+ auto invMassD0 = RecoDecay::m (std::array{pVecPos, pVecNeg}, std::array{massPi, massKa});
130+ auto invMassD0bar = RecoDecay::m (std::array{pVecPos, pVecNeg}, std::array{massKa, massPi});
131+
132+ auto flag = RecoDecay::OriginType::None;
133+
134+ if (fillOnlyBackground && !(isCharmHadronMassInSbRegions (invMassD0, invMassD0bar, massSbLeftMin, massSbLeftMax) || (isCharmHadronMassInSbRegions (invMassD0, invMassD0bar, massSbRightMin, massSbRightMax))))
135+ continue ;
136+ float pseudoRndm = trackPos.pt () * 1000 . - static_cast <int64_t >(trackPos.pt () * 1000 );
137+ if (pseudoRndm < downSampleBkgFactor) {
138+ train2P (invMassD0, invMassD0bar, pt2Prong, trackParPos.getPt (), dcaPos[0 ], dcaPos[1 ], trackPos.tpcNSigmaPi (), trackPos.tpcNSigmaKa (), trackPos.tofNSigmaPi (), trackPos.tofNSigmaKa (),
139+ trackParNeg.getPt (), dcaNeg[0 ], dcaNeg[1 ], trackNeg.tpcNSigmaPi (), trackNeg.tpcNSigmaKa (), trackNeg.tofNSigmaPi (), trackNeg.tofNSigmaKa (), flag, true );
140+ }
141+ } // end loop over 2-prong candidates
142+ }
143+ PROCESS_SWITCH (HfFilterPrepareMlSamples, processData2Prong, " Store 2prong(D0) data tables" , true );
144+
145+ void processData3Prong (aod::Hf3Prongs const & cand3Prongs,
146+ aod::Collisions const & collisions,
147+ BigTracksPID const &,
148+ aod::BCsWithTimestamps const &)
149+ {
150+ for (const auto & cand3Prong : cand3Prongs) { // start loop over 2 prongs
151+
152+ auto thisCollId = cand3Prong.collisionId ();
153+ auto collision = collisions.rawIteratorAt (thisCollId);
154+ auto bc = collision.bc_as <aod::BCsWithTimestamps>();
155+
156+ if (currentRun != bc.runNumber ()) {
157+ o2::parameters::GRPMagField* grpo = ccdb->getForTimeStamp <o2::parameters::GRPMagField>(ccdbPathGrpMag, bc.timestamp ());
158+ o2::base::Propagator::initFieldFromGRP (grpo);
159+ currentRun = bc.runNumber ();
160+ }
161+
162+ auto trackFirst = cand3Prong.prong0_as <BigTracksPID>(); // first daughter
163+ auto trackSecond = cand3Prong.prong1_as <BigTracksPID>(); // second daughter
164+ auto trackThird = cand3Prong.prong2_as <BigTracksPID>(); // third daughter
165+ auto arrayDaughters = std::array{trackFirst, trackSecond, trackThird};
166+
167+ auto trackParFirst = getTrackPar (trackFirst);
168+ auto trackParSecond = getTrackPar (trackSecond);
169+ auto trackParThird = getTrackPar (trackThird);
170+ o2::gpu::gpustd::array<float , 2 > dcaFirst{trackFirst.dcaXY (), trackFirst.dcaZ ()};
171+ o2::gpu::gpustd::array<float , 2 > dcaSecond{trackSecond.dcaXY (), trackSecond.dcaZ ()};
172+ o2::gpu::gpustd::array<float , 2 > dcaThird{trackThird.dcaXY (), trackThird.dcaZ ()};
173+ std::array<float , 3 > pVecFirst{trackFirst.pVector ()};
174+ std::array<float , 3 > pVecSecond{trackSecond.pVector ()};
175+ std::array<float , 3 > pVecThird{trackThird.pVector ()};
176+ if (trackFirst.collisionId () != thisCollId) {
177+ o2::base::Propagator::Instance ()->propagateToDCABxByBz ({collision.posX (), collision.posY (), collision.posZ ()}, trackParFirst, 2 .f , noMatCorr, &dcaFirst);
178+ getPxPyPz (trackParFirst, pVecFirst);
179+ }
180+ if (trackSecond.collisionId () != thisCollId) {
181+ o2::base::Propagator::Instance ()->propagateToDCABxByBz ({collision.posX (), collision.posY (), collision.posZ ()}, trackParSecond, 2 .f , noMatCorr, &dcaSecond);
182+ getPxPyPz (trackParSecond, pVecSecond);
183+ }
184+ if (trackThird.collisionId () != thisCollId) {
185+ o2::base::Propagator::Instance ()->propagateToDCABxByBz ({collision.posX (), collision.posY (), collision.posZ ()}, trackParThird, 2 .f , noMatCorr, &dcaThird);
186+ getPxPyPz (trackParThird, pVecThird);
187+ }
188+
189+ auto pVec3Prong = RecoDecay::pVec (pVecFirst, pVecSecond, pVecThird);
190+ auto pt3Prong = RecoDecay::pt (pVec3Prong);
191+
192+ auto invMassDplus = RecoDecay::m (std::array{pVecFirst, pVecSecond, pVecThird}, std::array{massPi, massKa, massPi});
193+
194+ auto invMassDsToKKPi = RecoDecay::m (std::array{pVecFirst, pVecSecond, pVecThird}, std::array{massKa, massKa, massPi});
195+ auto invMassDsToPiKK = RecoDecay::m (std::array{pVecFirst, pVecSecond, pVecThird}, std::array{massPi, massKa, massKa});
196+
197+ auto invMassLcToPKPi = RecoDecay::m (std::array{pVecFirst, pVecSecond, pVecThird}, std::array{massProton, massKa, massPi});
198+ auto invMassLcToPiKP = RecoDecay::m (std::array{pVecFirst, pVecSecond, pVecThird}, std::array{massPi, massKa, massProton});
199+
200+ auto invMassXicToPKPi = RecoDecay::m (std::array{pVecFirst, pVecSecond, pVecThird}, std::array{massProton, massKa, massPi});
201+ auto invMassXicToPiKP = RecoDecay::m (std::array{pVecFirst, pVecSecond, pVecThird}, std::array{massPi, massKa, massProton});
202+
203+ float deltaMassKKFirst = -1 .f ;
204+ float deltaMassKKSecond = -1 .f ;
205+ if (TESTBIT (cand3Prong.hfflag (), o2::aod::hf_cand_3prong::DecayType::DsToKKPi)) {
206+ deltaMassKKFirst = std::abs (RecoDecay::m (std::array{pVecFirst, pVecSecond}, std::array{massKa, massKa}) - massPhi);
207+ deltaMassKKSecond = std::abs (RecoDecay::m (std::array{pVecThird, pVecSecond}, std::array{massKa, massKa}) - massPhi);
208+ }
209+ int8_t sign = 0 ;
210+ auto flag = RecoDecay::OriginType::None;
211+
212+ float pseudoRndm = trackFirst.pt () * 1000 . - static_cast <int64_t >(trackFirst.pt () * 1000 );
213+ if (pseudoRndm < downSampleBkgFactor) {
214+ train3P (invMassDplus, invMassDsToKKPi, invMassDsToPiKK, invMassLcToPKPi, invMassLcToPiKP, invMassXicToPKPi, invMassXicToPiKP, pt3Prong, deltaMassKKFirst, deltaMassKKSecond,
215+ trackParFirst.getPt (), dcaFirst[0 ], dcaFirst[1 ], trackFirst.tpcNSigmaPi (), trackFirst.tpcNSigmaKa (), trackFirst.tpcNSigmaPr (), trackFirst.tofNSigmaPi (), trackFirst.tofNSigmaKa (), trackFirst.tofNSigmaPr (),
216+ trackParSecond.getPt (), dcaSecond[0 ], dcaSecond[1 ], trackSecond.tpcNSigmaPi (), trackSecond.tpcNSigmaKa (), trackSecond.tpcNSigmaPr (), trackSecond.tofNSigmaPi (), trackSecond.tofNSigmaKa (), trackSecond.tofNSigmaPr (),
217+ trackParThird.getPt (), dcaThird[0 ], dcaThird[1 ], trackThird.tpcNSigmaPi (), trackThird.tpcNSigmaKa (), trackThird.tpcNSigmaPr (), trackThird.tofNSigmaPi (), trackThird.tofNSigmaKa (), trackThird.tofNSigmaPr (),
218+ flag, 0 , cand3Prong.hfflag (), 0 );
219+ }
220+ } // end loop over 3-prong candidates
221+ }
222+ PROCESS_SWITCH (HfFilterPrepareMlSamples, processData3Prong, " Store 3prong(D0)-data tables" , true );
223+
224+ void processMC2Prong (aod::Hf2Prongs const & cand2Prongs,
225+ aod::McParticles const & mcParticles,
226+ soa::Join<aod::Collisions, aod::McCollisionLabels> const & collisions,
227+ BigTracksMCPID const &,
228+ aod::BCsWithTimestamps const &)
87229 {
88230 for (const auto & cand2Prong : cand2Prongs) { // start loop over 2 prongs
89231
@@ -136,13 +278,19 @@ struct HfFilterPrepareMlSamples { // Main struct
136278 }
137279 }
138280
139- float pseudoRndm = trackPos.pt () * 1000 . - (int64_t )(trackPos.pt () * 1000 );
140- if ((fillSignal && indexRec > -1 ) || (fillBackground && indexRec < 0 && pseudoRndm < downSampleBkgFactor)) {
141- train2P (invMassD0, invMassD0bar, pt2Prong, trackParPos.getPt (), dcaPos[0 ], dcaPos[1 ], trackPos.tpcNSigmaPi (), trackPos.tpcNSigmaKa (), trackPos.tofNSigmaPi (), trackPos.tofNSigmaKa (),
142- trackParNeg.getPt (), dcaNeg[0 ], dcaNeg[1 ], trackNeg.tpcNSigmaPi (), trackNeg.tpcNSigmaKa (), trackNeg.tofNSigmaPi (), trackNeg.tofNSigmaKa (), flag, isInCorrectColl);
143- }
281+ train2P (invMassD0, invMassD0bar, pt2Prong, trackParPos.getPt (), dcaPos[0 ], dcaPos[1 ], trackPos.tpcNSigmaPi (), trackPos.tpcNSigmaKa (), trackPos.tofNSigmaPi (), trackPos.tofNSigmaKa (),
282+ trackParNeg.getPt (), dcaNeg[0 ], dcaNeg[1 ], trackNeg.tpcNSigmaPi (), trackNeg.tpcNSigmaKa (), trackNeg.tofNSigmaPi (), trackNeg.tofNSigmaKa (), flag, isInCorrectColl);
283+
144284 } // end loop over 2-prong candidates
285+ }
286+ PROCESS_SWITCH (HfFilterPrepareMlSamples, processMC2Prong, " Store 2 prong(D0) MC tables" , false );
145287
288+ void processMC3Prong (aod::Hf3Prongs const & cand3Prongs,
289+ aod::McParticles const & mcParticles,
290+ soa::Join<aod::Collisions, aod::McCollisionLabels> const & collisions,
291+ BigTracksMCPID const &,
292+ aod::BCsWithTimestamps const &)
293+ {
146294 for (const auto & cand3Prong : cand3Prongs) { // start loop over 3 prongs
147295
148296 auto thisCollId = cand3Prong.collisionId ();
@@ -243,16 +391,15 @@ struct HfFilterPrepareMlSamples { // Main struct
243391 }
244392 }
245393
246- float pseudoRndm = trackFirst.pt () * 1000 . - (int64_t )(trackFirst.pt () * 1000 );
247- if ((fillSignal && indexRec > -1 ) || (fillBackground && indexRec < 0 && pseudoRndm < downSampleBkgFactor)) {
248- train3P (invMassDplus, invMassDsToKKPi, invMassDsToPiKK, invMassLcToPKPi, invMassLcToPiKP, invMassXicToPKPi, invMassXicToPiKP, pt3Prong, deltaMassKKFirst, deltaMassKKSecond,
249- trackParFirst.getPt (), dcaFirst[0 ], dcaFirst[1 ], trackFirst.tpcNSigmaPi (), trackFirst.tpcNSigmaKa (), trackFirst.tpcNSigmaPr (), trackFirst.tofNSigmaPi (), trackFirst.tofNSigmaKa (), trackFirst.tofNSigmaPr (),
250- trackParSecond.getPt (), dcaSecond[0 ], dcaSecond[1 ], trackSecond.tpcNSigmaPi (), trackSecond.tpcNSigmaKa (), trackSecond.tpcNSigmaPr (), trackSecond.tofNSigmaPi (), trackSecond.tofNSigmaKa (), trackSecond.tofNSigmaPr (),
251- trackParThird.getPt (), dcaThird[0 ], dcaThird[1 ], trackThird.tpcNSigmaPi (), trackThird.tpcNSigmaKa (), trackThird.tpcNSigmaPr (), trackThird.tofNSigmaPi (), trackThird.tofNSigmaKa (), trackThird.tofNSigmaPr (),
252- flag, channel, cand3Prong.hfflag (), isInCorrectColl);
253- }
394+ train3P (invMassDplus, invMassDsToKKPi, invMassDsToPiKK, invMassLcToPKPi, invMassLcToPiKP, invMassXicToPKPi, invMassXicToPiKP, pt3Prong, deltaMassKKFirst, deltaMassKKSecond,
395+ trackParFirst.getPt (), dcaFirst[0 ], dcaFirst[1 ], trackFirst.tpcNSigmaPi (), trackFirst.tpcNSigmaKa (), trackFirst.tpcNSigmaPr (), trackFirst.tofNSigmaPi (), trackFirst.tofNSigmaKa (), trackFirst.tofNSigmaPr (),
396+ trackParSecond.getPt (), dcaSecond[0 ], dcaSecond[1 ], trackSecond.tpcNSigmaPi (), trackSecond.tpcNSigmaKa (), trackSecond.tpcNSigmaPr (), trackSecond.tofNSigmaPi (), trackSecond.tofNSigmaKa (), trackSecond.tofNSigmaPr (),
397+ trackParThird.getPt (), dcaThird[0 ], dcaThird[1 ], trackThird.tpcNSigmaPi (), trackThird.tpcNSigmaKa (), trackThird.tpcNSigmaPr (), trackThird.tofNSigmaPi (), trackThird.tofNSigmaKa (), trackThird.tofNSigmaPr (),
398+ flag, channel, cand3Prong.hfflag (), isInCorrectColl);
399+
254400 } // end loop over 3-prong candidates
255401 }
402+ PROCESS_SWITCH (HfFilterPrepareMlSamples, processMC3Prong, " Store 3 prong MC tables" , false );
256403};
257404
258405WorkflowSpec defineDataProcessing (ConfigContext const & cfg)
0 commit comments