3333#include < Framework/InitContext.h>
3434#include < Framework/runDataProcessing.h>
3535
36+ #include < algorithm>
3637#include < cstdint>
3738#include < cstdlib>
39+ #include < vector>
3840
3941using namespace o2 ;
4042using namespace o2 ::framework;
@@ -84,6 +86,9 @@ DECLARE_SOA_COLUMN(V0CtLambda, v0CtLambda, float);
8486DECLARE_SOA_COLUMN (FlagMc, flagMc, int8_t );
8587DECLARE_SOA_COLUMN (OriginMcRec, originMcRec, int8_t );
8688DECLARE_SOA_COLUMN (OriginMcGen, originMcGen, int8_t );
89+ DECLARE_SOA_COLUMN (MlScoreFirstClass, mlScoreFirstClass, float );
90+ DECLARE_SOA_COLUMN (MlScoreSecondClass, mlScoreSecondClass, float );
91+ DECLARE_SOA_COLUMN (MlScoreThirdClass, mlScoreThirdClass, float );
8792// Events
8893DECLARE_SOA_COLUMN (IsEventReject, isEventReject, int );
8994DECLARE_SOA_COLUMN (RunNumber, runNumber, int );
@@ -126,7 +131,10 @@ DECLARE_SOA_TABLE(HfCandCascLites, "AOD", "HFCANDCASCLITE",
126131 full::Y,
127132 full::E,
128133 full::FlagMc,
129- full::OriginMcRec);
134+ full::OriginMcRec,
135+ full::MlScoreFirstClass,
136+ full::MlScoreSecondClass,
137+ full::MlScoreThirdClass);
130138
131139DECLARE_SOA_TABLE (HfCandCascFulls, " AOD" , " HFCANDCASCFULL" ,
132140 collision::BCId,
@@ -196,7 +204,10 @@ DECLARE_SOA_TABLE(HfCandCascFulls, "AOD", "HFCANDCASCFULL",
196204 full::Y,
197205 full::E,
198206 full::FlagMc,
199- full::OriginMcRec);
207+ full::OriginMcRec,
208+ full::MlScoreFirstClass,
209+ full::MlScoreSecondClass,
210+ full::MlScoreThirdClass);
200211
201212DECLARE_SOA_TABLE (HfCandCascFullEs, " AOD" , " HFCANDCASCFULLE" ,
202213 collision::BCId,
@@ -228,23 +239,56 @@ struct HfTreeCreatorLcToK0sP {
228239 Configurable<float > ptMaxForDownSample{" ptMaxForDownSample" , 24 ., " Maximum pt for the application of the downsampling factor" };
229240 Configurable<bool > fillOnlySignal{" fillOnlySignal" , false , " Flag to fill derived tables with signal for ML trainings" };
230241 Configurable<bool > fillOnlyBackground{" fillOnlyBackground" , false , " Flag to fill derived tables with background for ML trainings" };
242+ Configurable<bool > applyMl{" applyMl" , false , " Whether ML was used in candidateSelectorLc" };
243+
244+ constexpr static float UndefValueFloat = -999 .f;
231245
232246 HfHelper hfHelper;
233247
234248 Filter filterSelectCandidates = aod::hf_sel_candidate_lc_to_k0s_p::isSelLcToK0sP >= 1 ;
235249 using TracksWPid = soa::Join<aod::Tracks, aod::TracksPidPr>;
236250 using SelectedCandidatesMc = soa::Filtered<soa::Join<aod::HfCandCascade, aod::HfCandCascadeMcRec, aod::HfSelLcToK0sP>>;
237251
238- Partition<SelectedCandidatesMc> recSig = nabs(aod::hf_cand_casc::flagMcMatchRec) != int8_t (0 );
239- Partition<SelectedCandidatesMc> recBkg = nabs(aod::hf_cand_casc::flagMcMatchRec) == int8_t (0 );
240-
241252 void init (InitContext const &)
242253 {
243254 }
244255
256+ // / \brief function to get ML score values for the current candidate and assign them to input parameters
257+ // / \param candidate candidate instance
258+ // / \param candidateMlScore instance of handler of vectors with ML scores associated with the current candidate
259+ // / \param mlScoreFirstClass ML score for belonging to the first class
260+ // / \param mlScoreSecondClass ML score for belonging to the second class
261+ // / \param mlScoreThirdClass ML score for belonging to the third class
262+ void assignMlScores (aod::HfMlLcToK0sP::iterator const & candidateMlScore, float & mlScoreFirstClass, float & mlScoreSecondClass, float & mlScoreThirdClass)
263+ {
264+ std::vector<float > mlScores;
265+ std::copy (candidateMlScore.mlProbLcToK0sP ().begin (), candidateMlScore.mlProbLcToK0sP ().end (), std::back_inserter (mlScores));
266+
267+ constexpr int IndexFirstClass{0 };
268+ constexpr int IndexSecondClass{1 };
269+ constexpr int IndexThirdClass{2 };
270+ if (mlScores.size () == 0 ) {
271+ return ; // when candidateSelectorLcK0sP rejects a candidate by "usual", non-ML cut, the ml score vector remains empty
272+ }
273+ mlScoreFirstClass = mlScores.at (IndexFirstClass);
274+ mlScoreSecondClass = mlScores.at (IndexSecondClass);
275+ if (mlScores.size () > IndexThirdClass) {
276+ mlScoreThirdClass = mlScores.at (IndexThirdClass);
277+ }
278+ }
279+
245280 template <typename T, typename U>
246- void fillCandidate (const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec)
281+ void fillCandidate (const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec, aod::HfMlLcToK0sP::iterator const & candidateMlScore )
247282 {
283+
284+ float mlScoreFirstClass{UndefValueFloat};
285+ float mlScoreSecondClass{UndefValueFloat};
286+ float mlScoreThirdClass{UndefValueFloat};
287+
288+ if (applyMl) {
289+ assignMlScores (candidateMlScore, mlScoreFirstClass, mlScoreSecondClass, mlScoreThirdClass);
290+ }
291+
248292 if (fillCandidateLiteTable) {
249293 rowCandidateLite (
250294 candidate.chi2PCA (),
@@ -283,7 +327,10 @@ struct HfTreeCreatorLcToK0sP {
283327 hfHelper.yLc (candidate),
284328 hfHelper.eLc (candidate),
285329 flagMc,
286- originMcRec);
330+ originMcRec,
331+ mlScoreFirstClass,
332+ mlScoreSecondClass,
333+ mlScoreThirdClass);
287334 } else {
288335 rowCandidateFull (
289336 bach.collision ().bcId (),
@@ -353,7 +400,10 @@ struct HfTreeCreatorLcToK0sP {
353400 hfHelper.yLc (candidate),
354401 hfHelper.eLc (candidate),
355402 flagMc,
356- originMcRec);
403+ originMcRec,
404+ mlScoreFirstClass,
405+ mlScoreSecondClass,
406+ mlScoreThirdClass);
357407 }
358408 }
359409 template <typename T>
@@ -370,6 +420,7 @@ struct HfTreeCreatorLcToK0sP {
370420 void processMc (aod::Collisions const & collisions,
371421 aod::McCollisions const &,
372422 SelectedCandidatesMc const & candidates,
423+ aod::HfMlLcToK0sP const & candidateMlScores,
373424 soa::Join<aod::McParticles, aod::HfCandCascadeMcGen> const & particles,
374425 TracksWPid const &)
375426 {
@@ -380,42 +431,25 @@ struct HfTreeCreatorLcToK0sP {
380431 fillEvent (collision);
381432 }
382433
383- if (fillOnlySignal) {
384- if (fillCandidateLiteTable) {
385- rowCandidateLite.reserve (recSig.size ());
386- } else {
387- rowCandidateFull.reserve (recSig.size ());
388- }
389- for (const auto & candidate : recSig) {
390- auto bach = candidate.prong0_as <TracksWPid>(); // bachelor
391- fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec ());
392- }
393- } else if (fillOnlyBackground) {
394- if (fillCandidateLiteTable) {
395- rowCandidateLite.reserve (recBkg.size ());
396- } else {
397- rowCandidateFull.reserve (recBkg.size ());
398- }
399- for (const auto & candidate : recBkg) {
400- if (downSampleBkgFactor < 1 .) {
401- float pseudoRndm = candidate.ptProng0 () * 1000 . - static_cast <int64_t >(candidate.ptProng0 () * 1000 );
402- if (candidate.pt () < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) {
403- continue ;
404- }
405- }
406- auto bach = candidate.prong0_as <TracksWPid>(); // bachelor
407- fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec ());
408- }
434+ if (fillCandidateLiteTable) {
435+ rowCandidateLite.reserve (candidates.size ());
409436 } else {
410- // Filling candidate properties
411- if (fillCandidateLiteTable) {
412- rowCandidateLite.reserve (candidates.size ());
437+ rowCandidateFull.reserve (candidates.size ());
438+ }
439+
440+ int iCand{0 };
441+ for (const auto & candidate : candidates) {
442+ auto candidateMlScore = candidateMlScores.rawIteratorAt (iCand);
443+ ++iCand;
444+ auto bach = candidate.prong0_as <TracksWPid>(); // bachelor
445+ const int flag = candidate.flagMcMatchRec ();
446+
447+ if (fillOnlySignal && flag != 0 ) {
448+ fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec (), candidateMlScore);
449+ } else if (fillOnlyBackground && flag == 0 ) {
450+ fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec (), candidateMlScore);
413451 } else {
414- rowCandidateFull.reserve (candidates.size ());
415- }
416- for (const auto & candidate : candidates) {
417- auto bach = candidate.prong0_as <TracksWPid>(); // bachelor
418- fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec ());
452+ fillCandidate (candidate, bach, candidate.flagMcMatchRec (), candidate.originMcRec (), candidateMlScore);
419453 }
420454 }
421455
@@ -439,6 +473,7 @@ struct HfTreeCreatorLcToK0sP {
439473
440474 void processData (aod::Collisions const & collisions,
441475 soa::Join<aod::HfCandCascade, aod::HfSelLcToK0sP> const & candidates,
476+ aod::HfMlLcToK0sP const & candidateMlScores,
442477 TracksWPid const &)
443478 {
444479
@@ -454,11 +489,15 @@ struct HfTreeCreatorLcToK0sP {
454489 } else {
455490 rowCandidateFull.reserve (candidates.size ());
456491 }
492+
493+ int iCand{0 };
457494 for (const auto & candidate : candidates) {
495+ auto candidateMlScore = candidateMlScores.rawIteratorAt (iCand);
496+ ++iCand;
458497 auto bach = candidate.prong0_as <TracksWPid>(); // bachelor
459498 double pseudoRndm = bach.pt () * 1000 . - static_cast <int16_t >(bach.pt () * 1000 );
460499 if (candidate.isSelLcToK0sP () >= 1 && pseudoRndm < downSampleBkgFactor) {
461- fillCandidate (candidate, bach, 0 , 0 );
500+ fillCandidate (candidate, bach, 0 , 0 , candidateMlScore );
462501 }
463502 }
464503 }
0 commit comments