3737#include " PWGCF/FemtoDream/Core/femtoDreamUtils.h"
3838
3939#include " PWGHF/Core/HfHelper.h"
40+ #include " PWGHF/Core/HfMlResponseLcToPKPi.h"
4041#include " PWGHF/DataModel/CandidateReconstructionTables.h"
4142#include " PWGHF/DataModel/CandidateSelectionTables.h"
4243#include " PWGHF/Utils/utilsBfieldCCDB.h"
4344#include " PWGHF/Utils/utilsEvSelHf.h"
4445#include " PWGHF/Core/CentralityEstimation.h"
46+ #include " PWGHF/Core/SelectorCuts.h"
4547
4648using namespace o2 ;
4749using namespace o2 ::framework;
50+ using namespace o2 ::analysis;
4851using namespace o2 ::framework::expressions;
4952using namespace o2 ::analysis::femtoDream;
5053using namespace o2 ::hf_evsel;
@@ -60,6 +63,13 @@ enum Event : uint8_t {
6063 kPairSelected
6164};
6265
66+ // ml modes
67+ enum MlMode : uint8_t {
68+ kNoMl = 0 ,
69+ kFillMlFromSelector ,
70+ kFillMlFromNewBDT
71+ };
72+
6373struct HfFemtoDreamProducer {
6474
6575 Produces<aod::FDCollisions> outputCollision;
@@ -82,6 +92,11 @@ struct HfFemtoDreamProducer {
8292 Configurable<std::string> ccdbPathGrp{" ccdbPathGrp" , " GLO/GRP/GRP" , " Path of the grp file (Run 2)" };
8393 Configurable<std::string> ccdbPathGrpMag{" ccdbPathGrpMag" , " GLO/Config/GRPMagField" , " CCDB path of the GRPMagField object (Run 3)" };
8494
95+ Configurable<std::vector<std::string>> modelPathsCCDB{" modelPathsCCDB" , std::vector<std::string>{" EventFiltering/PWGHF/BDTLc" }, " Paths of models on CCDB" };
96+ Configurable<std::vector<std::string>> onnxFileNames{" onnxFileNames" , std::vector<std::string>{" ModelHandler_onnx_LcToPKPi.onnx" }, " ONNX file names for each pT bin (if not from CCDB full path)" };
97+ Configurable<int64_t > timestampCCDB{" timestampCCDB" , -1 , " timestamp of the ONNX file for ML model used to query in CCDB" };
98+ Configurable<bool > loadModelsFromCCDB{" loadModelsFromCCDB" , false , " Flag to enable or disable the loading of models from CCDB" };
99+
85100 // Configurable<bool> isForceGRP{"isForceGRP", false, "Set true if the magnetic field configuration is not available in the usual CCDB directory (e.g. for Run 2 converted data or unanchorad Monte Carlo)"};
86101
87102 Configurable<bool > isDebug{" isDebug" , true , " Enable Debug tables" };
@@ -108,47 +123,56 @@ struct HfFemtoDreamProducer {
108123 Configurable<std::vector<float >> trkTPCsCls{FemtoDreamTrackSelection::getSelectionName (femtoDreamTrackSelection::kTPCsClsMax , " trk" ), std::vector<float >{0 .1f , 160 .f }, FemtoDreamTrackSelection::getSelectionHelper (femtoDreamTrackSelection::kTPCsClsMax , " Track selection: " )};
109124 Configurable<std::vector<float >> trkITSnclsIbMin{FemtoDreamTrackSelection::getSelectionName (femtoDreamTrackSelection::kITSnClsIbMin , " trk" ), std::vector<float >{-1 .f , 1 .f }, FemtoDreamTrackSelection::getSelectionHelper (femtoDreamTrackSelection::kITSnClsIbMin , " Track selection: " )};
110125 Configurable<std::vector<float >> trkITSnclsMin{FemtoDreamTrackSelection::getSelectionName (femtoDreamTrackSelection::kITSnClsMin , " trk" ), std::vector<float >{-1 .f , 2 .f , 4 .f }, FemtoDreamTrackSelection::getSelectionHelper (femtoDreamTrackSelection::kITSnClsMin , " Track selection: " )};
126+ // ML inference
127+ Configurable<int > applyMlMode{" applyMlMode" , 1 , " Occupancy estimation (None: 0, BDT model from Lc selector: 1, New BDT model on Top of Lc selector: 2)" };
128+ Configurable<std::vector<double >> binsPtMl{" binsPtMl" , std::vector<double >{hf_cuts_ml::vecBinsPt}, " pT bin limits for ML application" };
129+ Configurable<std::vector<int >> cutDirMl{" cutDirMl" , std::vector<int >{hf_cuts_ml::vecCutDir}, " Whether to reject score values greater or smaller than the threshold" };
130+ Configurable<LabeledArray<double >> cutsMl{" cutsMl" , {hf_cuts_ml::Cuts[0 ], hf_cuts_ml::NBinsPt, hf_cuts_ml::NCutScores, hf_cuts_ml::labelsPt, hf_cuts_ml::labelsCutScore}, " ML selections per pT bin" };
131+ Configurable<int > nClassesMl{" nClassesMl" , static_cast <int >(hf_cuts_ml::NCutScores), " Number of classes in ML model" };
132+ Configurable<std::vector<std::string>> namesInputFeatures{" namesInputFeatures" , std::vector<std::string>{" feature1" , " feature2" }, " Names of ML model input features" };
133+
134+ FemtoDreamTrackSelection trackCuts;
135+
136+ HfHelper hfHelper;
137+ o2::analysis::HfMlResponseLcToPKPi<float > hfMlResponse;
138+ std::vector<float > outputMlPKPi = {};
139+ std::vector<float > outputMlPiKP = {};
140+ o2::ccdb::CcdbApi ccdbApi;
141+ o2::hf_evsel::HfEventSelection hfEvSel;
142+ Service<o2::ccdb::BasicCCDBManager> ccdb; // / Accessing the CCDB
143+ o2::base::MatLayerCylSet* lut;
144+ // if (doPvRefit){ lut = o2::base::MatLayerCylSet::rectifyPtrFromFile(ccdb->get<o2::base::MatLayerCylSet>(ccdbPathLut));} //! may be it useful, will check later
111145
146+ float magField;
147+ int runNumber;
112148 using CandidateLc = soa::Join<aod::HfCand3Prong, aod::HfSelLc>;
113149 using CandidateLcMc = soa::Join<aod::HfCand3Prong, aod::HfSelLc, aod::HfCand3ProngMcRec>;
114150
115151 using FemtoFullCollision = soa::Join<aod::Collisions, aod::EvSels, aod::Mults, aod::CentFT0Ms>::iterator;
116152 using FemtoFullCollisionMc = soa::Join<aod::Collisions, aod::EvSels, aod::Mults, aod::CentFT0Ms, aod::McCollisionLabels>::iterator;
117153 using FemtoFullMcgenCollisions = soa::Join<aod::McCollisions, o2::aod::MultsExtraMC>;
118154 using FemtoFullMcgenCollision = FemtoFullMcgenCollisions::iterator;
119- using FemtoHFTracks = soa::Join<aod::FullTracks, aod::TracksDCA, aod::pidTPCFullPi, aod::pidTPCFullKa, aod::pidTPCFullPr, aod::pidTPCFullDe, aod::pidTOFFullPi, aod::pidTOFFullKa, aod::pidTOFFullPr, aod::pidTOFFullDe>;
155+ using FemtoHFTracks = soa::Join<aod::FullTracks, aod::TracksDCA, aod::pidTPCFullPi, aod::pidTPCFullKa, aod::pidTPCFullPr, aod::pidTPCFullDe, aod::pidTOFFullPi, aod::pidTOFFullKa, aod::pidTOFFullPr, aod::pidTOFFullDe, aod::TracksPidPi, aod::PidTpcTofFullPi, aod::TracksPidKa, aod::PidTpcTofFullKa, aod::TracksPidPr, aod::PidTpcTofFullPr >;
120156 using FemtoHFTrack = FemtoHFTracks::iterator;
121157 using FemtoHFMcTracks = soa::Join<aod::McTrackLabels, FemtoHFTracks>;
122158 using FemtoHFMcTrack = FemtoHFMcTracks::iterator;
123159
124160 using GeneratedMc = soa::Filtered<soa::Join<aod::McParticles, aod::HfCand3ProngMcGen>>;
125161
126- FemtoDreamTrackSelection trackCuts;
127-
128162 Filter filterSelectCandidateLc = (aod::hf_sel_candidate_lc::isSelLcToPKPi >= selectionFlagLc || aod::hf_sel_candidate_lc::isSelLcToPiKP >= selectionFlagLc);
129163
130164 HistogramRegistry qaRegistry{" QAHistos" , {}, OutputObjHandlingPolicy::AnalysisObject};
131- HistogramRegistry TrackRegistry{" Tracks" , {}, OutputObjHandlingPolicy::AnalysisObject};
132-
133- HfHelper hfHelper;
134- o2::hf_evsel::HfEventSelection hfEvSel;
135-
136- float magField;
137- int runNumber;
138-
139- Service<o2::ccdb::BasicCCDBManager> ccdb; // / Accessing the CCDB
140- o2::base::MatLayerCylSet* lut;
141- // if (doPvRefit){ lut = o2::base::MatLayerCylSet::rectifyPtrFromFile(ccdb->get<o2::base::MatLayerCylSet>(ccdbPathLut));} //! may be it useful, will check later
165+ HistogramRegistry trackRegistry{" Tracks" , {}, OutputObjHandlingPolicy::AnalysisObject};
142166
143167 void init (InitContext&)
144168 {
145- std::array<bool , 5 > processes = {doprocessDataCharmHad, doprocessMcCharmHad, doprocessDataCharmHadWithML, doprocessMcCharmHadWithML , doprocessMcCharmHadGen};
169+ std::array<bool , 3 > processes = {doprocessDataCharmHad, doprocessMcCharmHad, doprocessMcCharmHadGen};
146170 if (std::accumulate (processes.begin (), processes.end (), 0 ) != 1 ) {
147171 LOGP (fatal, " One and only one process function must be enabled at a time." );
148172 }
149173
150- int CutBits = 8 * sizeof (o2::aod::femtodreamparticle::cutContainerType);
151- TrackRegistry .add (" AnalysisQA/CutCounter" , " ; Bit; Counter" , kTH1F , {{CutBits + 1 , -0.5 , CutBits + 0.5 }});
174+ int cutBits = 8 * sizeof (o2::aod::femtodreamparticle::cutContainerType);
175+ trackRegistry .add (" AnalysisQA/CutCounter" , " ; Bit; Counter" , kTH1F , {{cutBits + 1 , -0.5 , cutBits + 0.5 }});
152176
153177 // event QA histograms
154178 constexpr int kEventTypes = kPairSelected + 1 ;
@@ -181,7 +205,7 @@ struct HfFemtoDreamProducer {
181205 trackCuts.setSelection (trkPIDnSigmaMax, femtoDreamTrackSelection::kPIDnSigmaMax , femtoDreamSelection::kAbsUpperLimit );
182206 trackCuts.setPIDSpecies (trkPIDspecies);
183207 trackCuts.setnSigmaPIDOffset (trkPIDnSigmaOffsetTPC, trkPIDnSigmaOffsetTOF);
184- trackCuts.init <aod::femtodreamparticle::ParticleType::kTrack , aod::femtodreamparticle::TrackType::kNoChild , aod::femtodreamparticle::cutContainerType>(&qaRegistry, &TrackRegistry );
208+ trackCuts.init <aod::femtodreamparticle::ParticleType::kTrack , aod::femtodreamparticle::TrackType::kNoChild , aod::femtodreamparticle::cutContainerType>(&qaRegistry, &trackRegistry );
185209
186210 runNumber = 0 ;
187211 magField = 0.0 ;
@@ -194,6 +218,18 @@ struct HfFemtoDreamProducer {
194218
195219 int64_t now = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now ().time_since_epoch ()).count ();
196220 ccdb->setCreatedNotAfter (now);
221+
222+ if (applyMlMode) {
223+ hfMlResponse.configure (binsPtMl, cutsMl, cutDirMl, nClassesMl);
224+ if (loadModelsFromCCDB) {
225+ ccdbApi.init (ccdbUrl);
226+ hfMlResponse.setModelPathsCCDB (onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
227+ } else {
228+ hfMlResponse.setModelPathsLocal (onnxFileNames);
229+ }
230+ hfMlResponse.cacheInputFeaturesIndices (namesInputFeatures);
231+ hfMlResponse.init ();
232+ }
197233 }
198234
199235 // / Function to retrieve the nominal magnetic field in kG (0.1T) and convert it directly to T
@@ -314,7 +350,7 @@ struct HfFemtoDreamProducer {
314350 // std::vector<int> tmpIDtrack; // this vector keeps track of the matching of the primary track table row <-> aod::track table global index
315351 bool fIsTrackFilled = false ;
316352
317- for (auto & track : tracks) {
353+ for (const auto & track : tracks) {
318354 // / if the most open selection criteria are not fulfilled there is no
319355 // / point looking further at the track
320356 if (!trackCuts.isSelectedMinimal (track)) {
@@ -395,27 +431,45 @@ struct HfFemtoDreamProducer {
395431 // Filling candidate properties
396432 rowCandCharmHad.reserve (sizeCand);
397433 bool isTrackFilled = false ;
434+ bool isSelectedMlLcToPKPi = true ;
435+ bool isSelectedMlLcToPiKP = true ;
398436 for (const auto & candidate : candidates) {
399- std::array<float , 3 > outputMlPKPi{-1 ., -1 ., -1 .};
400- std::array<float , 3 > outputMlPiKP{-1 ., -1 ., -1 .};
437+
438+ auto trackPos1 = candidate.template prong0_as <TrackType>(); // positive daughter (negative for the antiparticles)
439+ auto trackNeg = candidate.template prong1_as <TrackType>(); // negative daughter (positive for the antiparticles)
440+ auto trackPos2 = candidate.template prong2_as <TrackType>(); // positive daughter (negative for the antiparticles)
441+
401442 if constexpr (useCharmMl) {
402443 // / fill with ML information
403444 // / BDT index 0: bkg score; BDT index 1: prompt score; BDT index 2: non-prompt score
404- if (candidate.mlProbLcToPKPi ().size () > 0 ) {
405- outputMlPKPi.at (0 ) = candidate.mlProbLcToPKPi ()[0 ]; // / bkg score
406- outputMlPKPi.at (1 ) = candidate.mlProbLcToPKPi ()[1 ]; // / prompt score
407- outputMlPKPi.at (2 ) = candidate.mlProbLcToPKPi ()[2 ]; // / non-prompt score
408- }
409- if (candidate.mlProbLcToPiKP ().size () > 0 ) {
410- outputMlPiKP.at (0 ) = candidate.mlProbLcToPiKP ()[0 ]; // / bkg score
411- outputMlPiKP.at (1 ) = candidate.mlProbLcToPiKP ()[1 ]; // / prompt score
412- outputMlPiKP.at (2 ) = candidate.mlProbLcToPiKP ()[2 ]; // / non-prompt score
445+ if (applyMlMode == kFillMlFromSelector ) {
446+ if (candidate.mlProbLcToPKPi ().size () > 0 ) {
447+ outputMlPKPi.at (0 ) = candidate.mlProbLcToPKPi ()[0 ]; // / bkg score
448+ outputMlPKPi.at (1 ) = candidate.mlProbLcToPKPi ()[1 ]; // / prompt score
449+ outputMlPKPi.at (2 ) = candidate.mlProbLcToPKPi ()[2 ]; // / non-prompt score
450+ }
451+ if (candidate.mlProbLcToPiKP ().size () > 0 ) {
452+ outputMlPiKP.at (0 ) = candidate.mlProbLcToPiKP ()[0 ]; // / bkg score
453+ outputMlPiKP.at (1 ) = candidate.mlProbLcToPiKP ()[1 ]; // / prompt score
454+ outputMlPiKP.at (2 ) = candidate.mlProbLcToPiKP ()[2 ]; // / non-prompt score
455+ }
456+ } else if (applyMlMode == kFillMlFromNewBDT ) {
457+ isSelectedMlLcToPKPi = false ;
458+ isSelectedMlLcToPiKP = false ;
459+ if (candidate.mlProbLcToPKPi ().size () > 0 ) {
460+ std::vector<float > inputFeaturesLcToPKPi = hfMlResponse.getInputFeatures (candidate, trackPos1, trackNeg, trackPos2, true );
461+ isSelectedMlLcToPKPi = hfMlResponse.isSelectedMl (inputFeaturesLcToPKPi, candidate.pt (), outputMlPKPi);
462+ }
463+ if (candidate.mlProbLcToPiKP ().size () > 0 ) {
464+ std::vector<float > inputFeaturesLcToPiKP = hfMlResponse.getInputFeatures (candidate, trackPos1, trackNeg, trackPos2, false );
465+ isSelectedMlLcToPiKP = hfMlResponse.isSelectedMl (inputFeaturesLcToPiKP, candidate.pt (), outputMlPKPi);
466+ }
467+ if (!isSelectedMlLcToPKPi && !isSelectedMlLcToPiKP)
468+ continue ;
469+ } else {
470+ LOGF (fatal, " Please check your Ml configuration!!" );
413471 }
414472 }
415- auto trackPos1 = candidate.template prong0_as <TrackType>(); // positive daughter (negative for the antiparticles)
416- auto trackNeg = candidate.template prong1_as <TrackType>(); // negative daughter (positive for the antiparticles)
417- auto trackPos2 = candidate.template prong2_as <TrackType>(); // positive daughter (negative for the antiparticles)
418-
419473 auto fillTable = [&](int CandFlag,
420474 int FunctionSelection,
421475 float BDTScoreBkg,
0 commit comments