3333#include < Framework/InitContext.h>
3434#include < Framework/runDataProcessing.h>
3535
36+ #include < algorithm>
3637#include < cstdint>
3738#include < cstdlib>
39+ #include < vector>
3840
3941using namespace o2 ;
4042using namespace o2 ::framework;
@@ -68,8 +70,8 @@ DECLARE_SOA_COLUMN(DecayLength, decayLength, float);
6870DECLARE_SOA_COLUMN (DecayLengthXY, decayLengthXY, float );
6971DECLARE_SOA_COLUMN (DecayLengthNormalised, decayLengthNormalised, float );
7072DECLARE_SOA_COLUMN (DecayLengthXYNormalised, decayLengthXYNormalised, float );
71- DECLARE_SOA_COLUMN (CPA , cpa, float );
72- DECLARE_SOA_COLUMN (CPAXY , cpaXY, float );
73+ DECLARE_SOA_COLUMN (Cpa , cpa, float );
74+ DECLARE_SOA_COLUMN (CpaXY , cpaXY, float );
7375DECLARE_SOA_COLUMN (Ct, ct, float );
7476DECLARE_SOA_COLUMN (PtV0Pos, ptV0Pos, float );
7577DECLARE_SOA_COLUMN (PtV0Neg, ptV0Neg, float );
@@ -84,6 +86,9 @@ DECLARE_SOA_COLUMN(V0CtLambda, v0CtLambda, float);
8486DECLARE_SOA_COLUMN (FlagMc, flagMc, int8_t );
8587DECLARE_SOA_COLUMN (OriginMcRec, originMcRec, int8_t );
8688DECLARE_SOA_COLUMN (OriginMcGen, originMcGen, int8_t );
89+ DECLARE_SOA_COLUMN (MlScoreFirstClass, mlScoreFirstClass, float );
90+ DECLARE_SOA_COLUMN (MlScoreSecondClass, mlScoreSecondClass, float );
91+ DECLARE_SOA_COLUMN (MlScoreThirdClass, mlScoreThirdClass, float );
8792// Events
8893DECLARE_SOA_COLUMN (IsEventReject, isEventReject, int );
8994DECLARE_SOA_COLUMN (RunNumber, runNumber, int );
@@ -118,15 +123,18 @@ DECLARE_SOA_TABLE(HfCandCascLites, "AOD", "HFCANDCASCLITE",
118123 full::NSigmaTOFPr0,
119124 full::M,
120125 full::Pt,
121- full::CPA ,
122- full::CPAXY ,
126+ full::Cpa ,
127+ full::CpaXY ,
123128 full::Ct,
124129 full::Eta,
125130 full::Phi,
126131 full::Y,
127132 full::E,
128133 full::FlagMc,
129- full::OriginMcRec);
134+ full::OriginMcRec,
135+ full::MlScoreFirstClass,
136+ full::MlScoreSecondClass,
137+ full::MlScoreThirdClass);
130138
131139DECLARE_SOA_TABLE (HfCandCascFulls, " AOD" , " HFCANDCASCFULL" ,
132140 collision::BCId,
@@ -188,15 +196,18 @@ DECLARE_SOA_TABLE(HfCandCascFulls, "AOD", "HFCANDCASCFULL",
188196 full::M,
189197 full::Pt,
190198 full::P,
191- full::CPA ,
192- full::CPAXY ,
199+ full::Cpa ,
200+ full::CpaXY ,
193201 full::Ct,
194202 full::Eta,
195203 full::Phi,
196204 full::Y,
197205 full::E,
198206 full::FlagMc,
199- full::OriginMcRec);
207+ full::OriginMcRec,
208+ full::MlScoreFirstClass,
209+ full::MlScoreSecondClass,
210+ full::MlScoreThirdClass);
200211
201212DECLARE_SOA_TABLE (HfCandCascFullEs, " AOD" , " HFCANDCASCFULLE" ,
202213 collision::BCId,
@@ -228,23 +239,56 @@ struct HfTreeCreatorLcToK0sP {
228239 Configurable<float > ptMaxForDownSample{" ptMaxForDownSample" , 24 ., " Maximum pt for the application of the downsampling factor" };
229240 Configurable<bool > fillOnlySignal{" fillOnlySignal" , false , " Flag to fill derived tables with signal for ML trainings" };
230241 Configurable<bool > fillOnlyBackground{" fillOnlyBackground" , false , " Flag to fill derived tables with background for ML trainings" };
242+ Configurable<bool > applyMl{" applyMl" , false , " Whether ML was used in candidateSelectorLc" };
243+
244+ constexpr static float UndefValueFloat = -999 .f;
231245
232246 HfHelper hfHelper;
233247
234- Filter filterSelectCandidates = aod::hf_sel_candidate_lc_to_k0s_p::isSelLcToK0sP >= 1 ;
235248 using TracksWPid = soa::Join<aod::Tracks, aod::TracksPidPr>;
236249 using SelectedCandidatesMc = soa::Filtered<soa::Join<aod::HfCandCascade, aod::HfCandCascadeMcRec, aod::HfSelLcToK0sP>>;
237-
238- Partition<SelectedCandidatesMc> recSig = nabs(aod::hf_cand_casc::flagMcMatchRec) != int8_t (0 );
239- Partition<SelectedCandidatesMc> recBkg = nabs(aod::hf_cand_casc::flagMcMatchRec) == int8_t (0 );
250+ Filter filterSelectCandidates = aod::hf_sel_candidate_lc_to_k0s_p::isSelLcToK0sP >= 1 ;
240251
241252 void init (InitContext const &)
242253 {
243254 }
244255
256+ // / \brief function to get ML score values for the current candidate and assign them to input parameters
257+ // / \param candidate candidate instance
258+ // / \param candidateMlScore instance of handler of vectors with ML scores associated with the current candidate
259+ // / \param mlScoreFirstClass ML score for belonging to the first class
260+ // / \param mlScoreSecondClass ML score for belonging to the second class
261+ // / \param mlScoreThirdClass ML score for belonging to the third class
262+ void assignMlScores (aod::HfMlLcToK0sP::iterator const & candidateMlScore, float & mlScoreFirstClass, float & mlScoreSecondClass, float & mlScoreThirdClass)
263+ {
264+ std::vector<float > mlScores;
265+ std::copy (candidateMlScore.mlProbLcToK0sP ().begin (), candidateMlScore.mlProbLcToK0sP ().end (), std::back_inserter (mlScores));
266+
267+ constexpr int IndexFirstClass{0 };
268+ constexpr int IndexSecondClass{1 };
269+ constexpr int IndexThirdClass{2 };
270+ if (mlScores.size () == 0 ) {
271+ return ; // when candidateSelectorLcK0sP rejects a candidate by "usual", non-ML cut, the ml score vector remains empty
272+ }
273+ mlScoreFirstClass = mlScores.at (IndexFirstClass);
274+ mlScoreSecondClass = mlScores.at (IndexSecondClass);
275+ if (mlScores.size () > IndexThirdClass) {
276+ mlScoreThirdClass = mlScores.at (IndexThirdClass);
277+ }
278+ }
279+
245280 template <typename T, typename U>
246- void fillCandidate (const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec)
281+ void fillCandidate (const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec, aod::HfMlLcToK0sP::iterator const & candidateMlScore )
247282 {
283+
284+ float mlScoreFirstClass{UndefValueFloat};
285+ float mlScoreSecondClass{UndefValueFloat};
286+ float mlScoreThirdClass{UndefValueFloat};
287+
288+ if (applyMl) {
289+ assignMlScores (candidateMlScore, mlScoreFirstClass, mlScoreSecondClass, mlScoreThirdClass);
290+ }
291+
248292 if (fillCandidateLiteTable) {
249293 rowCandidateLite (
250294 candidate.chi2PCA (),
@@ -283,7 +327,10 @@ struct HfTreeCreatorLcToK0sP {
283327 hfHelper.yLc (candidate),
284328 hfHelper.eLc (candidate),
285329 flagMc,
286- originMcRec);
330+ originMcRec,
331+ mlScoreFirstClass,
332+ mlScoreSecondClass,
333+ mlScoreThirdClass);
287334 } else {
288335 rowCandidateFull (
289336 bach.collision ().bcId (),
@@ -353,7 +400,10 @@ struct HfTreeCreatorLcToK0sP {
353400 hfHelper.yLc (candidate),
354401 hfHelper.eLc (candidate),
355402 flagMc,
356- originMcRec);
403+ originMcRec,
404+ mlScoreFirstClass,
405+ mlScoreSecondClass,
406+ mlScoreThirdClass);
357407 }
358408 }
359409 template <typename T>
@@ -370,52 +420,41 @@ struct HfTreeCreatorLcToK0sP {
370420 void processMc (aod::Collisions const & collisions,
371421 aod::McCollisions const &,
372422 SelectedCandidatesMc const & candidates,
423+ aod::HfMlLcToK0sP const & candidateMlScores,
373424 soa::Join<aod::McParticles, aod::HfCandCascadeMcGen> const & particles,
374425 TracksWPid const &)
375426 {
376427
428+ if (applyMl && candidateMlScores.size () == 0 ) {
429+ LOG (fatal) << " ML enabled but table with the ML scores is empty! Please check your configurables." ;
430+ return ;
431+ }
432+
377433 // Filling event properties
378434 rowCandidateFullEvents.reserve (collisions.size ());
379435 for (const auto & collision : collisions) {
380436 fillEvent (collision);
381437 }
382438
383- if (fillOnlySignal) {
384- if (fillCandidateLiteTable) {
385- rowCandidateLite.reserve (recSig.size ());
386- } else {
387- rowCandidateFull.reserve (recSig.size ());
388- }
389- for (const auto & candidate : recSig) {
390- auto bach = candidate.prong0_as <TracksWPid>(); // bachelor
391- fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec ());
392- }
393- } else if (fillOnlyBackground) {
394- if (fillCandidateLiteTable) {
395- rowCandidateLite.reserve (recBkg.size ());
396- } else {
397- rowCandidateFull.reserve (recBkg.size ());
398- }
399- for (const auto & candidate : recBkg) {
400- if (downSampleBkgFactor < 1 .) {
401- float pseudoRndm = candidate.ptProng0 () * 1000 . - static_cast <int64_t >(candidate.ptProng0 () * 1000 );
402- if (candidate.pt () < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) {
403- continue ;
404- }
405- }
406- auto bach = candidate.prong0_as <TracksWPid>(); // bachelor
407- fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec ());
408- }
439+ if (fillCandidateLiteTable) {
440+ rowCandidateLite.reserve (candidates.size ());
409441 } else {
410- // Filling candidate properties
411- if (fillCandidateLiteTable) {
412- rowCandidateLite.reserve (candidates.size ());
442+ rowCandidateFull.reserve (candidates.size ());
443+ }
444+
445+ int iCand{0 };
446+ for (const auto & candidate : candidates) {
447+ auto candidateMlScore = candidateMlScores.rawIteratorAt (iCand);
448+ ++iCand;
449+ auto bach = candidate.prong0_as <TracksWPid>(); // bachelor
450+ const int flag = candidate.flagMcMatchRec ();
451+
452+ if (fillOnlySignal && flag != 0 ) {
453+ fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec (), candidateMlScore);
454+ } else if (fillOnlyBackground && flag == 0 ) {
455+ fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec (), candidateMlScore);
413456 } else {
414- rowCandidateFull.reserve (candidates.size ());
415- }
416- for (const auto & candidate : candidates) {
417- auto bach = candidate.prong0_as <TracksWPid>(); // bachelor
418- fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec ());
457+ fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec (), candidateMlScore);
419458 }
420459 }
421460
@@ -439,9 +478,15 @@ struct HfTreeCreatorLcToK0sP {
439478
440479 void processData (aod::Collisions const & collisions,
441480 soa::Join<aod::HfCandCascade, aod::HfSelLcToK0sP> const & candidates,
481+ aod::HfMlLcToK0sP const & candidateMlScores,
442482 TracksWPid const &)
443483 {
444484
485+ if (applyMl && candidateMlScores.size () == 0 ) {
486+ LOG (fatal) << " ML enabled but table with the ML scores is empty! Please check your configurables." ;
487+ return ;
488+ }
489+
445490 // Filling event properties
446491 rowCandidateFullEvents.reserve (collisions.size ());
447492 for (const auto & collision : collisions) {
@@ -454,11 +499,15 @@ struct HfTreeCreatorLcToK0sP {
454499 } else {
455500 rowCandidateFull.reserve (candidates.size ());
456501 }
502+
503+ int iCand{0 };
457504 for (const auto & candidate : candidates) {
505+ auto candidateMlScore = candidateMlScores.rawIteratorAt (iCand);
506+ ++iCand;
458507 auto bach = candidate.prong0_as <TracksWPid>(); // bachelor
459508 double pseudoRndm = bach.pt () * 1000 . - static_cast <int16_t >(bach.pt () * 1000 );
460509 if (candidate.isSelLcToK0sP () >= 1 && pseudoRndm < downSampleBkgFactor) {
461- fillCandidate (candidate, bach, 0 , 0 );
510+ fillCandidate (candidate, bach, 0 , 0 , candidateMlScore );
462511 }
463512 }
464513 }
0 commit comments