Skip to content

Commit 52c6dd6

Browse files
Update femtoDreamProducer.cxx
1 parent ed12615 commit 52c6dd6

File tree

1 file changed

+73
-17
lines changed

1 file changed

+73
-17
lines changed

PWGHF/HFC/TableProducer/femtoDreamProducer.cxx

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,17 @@
3737
#include "PWGCF/FemtoDream/Core/femtoDreamUtils.h"
3838

3939
#include "PWGHF/Core/HfHelper.h"
40+
#include "PWGHF/Core/HfMlResponseLcToPKPi.h"
4041
#include "PWGHF/DataModel/CandidateReconstructionTables.h"
4142
#include "PWGHF/DataModel/CandidateSelectionTables.h"
4243
#include "PWGHF/Utils/utilsBfieldCCDB.h"
4344
#include "PWGHF/Utils/utilsEvSelHf.h"
4445
#include "PWGHF/Core/CentralityEstimation.h"
46+
#include "PWGHF/Core/SelectorCuts.h"
4547

4648
using namespace o2;
4749
using namespace o2::framework;
50+
using namespace o2::analysis;
4851
using namespace o2::framework::expressions;
4952
using namespace o2::analysis::femtoDream;
5053
using namespace o2::hf_evsel;
@@ -60,6 +63,13 @@ enum Event : uint8_t {
6063
kPairSelected
6164
};
6265

66+
// ml modes
67+
enum MlMode : uint8_t {
68+
kNoMl = 0,
69+
kFillMlFromSelector,
70+
kFillMlFromNewBDT
71+
};
72+
6373
struct HfFemtoDreamProducer {
6474

6575
Produces<aod::FDCollisions> outputCollision;
@@ -82,6 +92,11 @@ struct HfFemtoDreamProducer {
8292
Configurable<std::string> ccdbPathGrp{"ccdbPathGrp", "GLO/GRP/GRP", "Path of the grp file (Run 2)"};
8393
Configurable<std::string> ccdbPathGrpMag{"ccdbPathGrpMag", "GLO/Config/GRPMagField", "CCDB path of the GRPMagField object (Run 3)"};
8494

95+
Configurable<std::vector<std::string>> modelPathsCCDB{"modelPathsCCDB", std::vector<std::string>{"EventFiltering/PWGHF/BDTLc"}, "Paths of models on CCDB"};
96+
Configurable<std::vector<std::string>> onnxFileNames{"onnxFileNames", std::vector<std::string>{"ModelHandler_onnx_LcToPKPi.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path)"};
97+
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"};
98+
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
99+
85100
// Configurable<bool> isForceGRP{"isForceGRP", false, "Set true if the magnetic field configuration is not available in the usual CCDB directory (e.g. for Run 2 converted data or unanchorad Monte Carlo)"};
86101

87102
Configurable<bool> isDebug{"isDebug", true, "Enable Debug tables"};
@@ -108,6 +123,13 @@ struct HfFemtoDreamProducer {
108123
Configurable<std::vector<float>> trkTPCsCls{FemtoDreamTrackSelection::getSelectionName(femtoDreamTrackSelection::kTPCsClsMax, "trk"), std::vector<float>{0.1f, 160.f}, FemtoDreamTrackSelection::getSelectionHelper(femtoDreamTrackSelection::kTPCsClsMax, "Track selection: ")};
109124
Configurable<std::vector<float>> trkITSnclsIbMin{FemtoDreamTrackSelection::getSelectionName(femtoDreamTrackSelection::kITSnClsIbMin, "trk"), std::vector<float>{-1.f, 1.f}, FemtoDreamTrackSelection::getSelectionHelper(femtoDreamTrackSelection::kITSnClsIbMin, "Track selection: ")};
110125
Configurable<std::vector<float>> trkITSnclsMin{FemtoDreamTrackSelection::getSelectionName(femtoDreamTrackSelection::kITSnClsMin, "trk"), std::vector<float>{-1.f, 2.f, 4.f}, FemtoDreamTrackSelection::getSelectionHelper(femtoDreamTrackSelection::kITSnClsMin, "Track selection: ")};
126+
// ML inference
127+
Configurable<int> applyMlMode{"applyMlMode", 1, "Occupancy estimation (None: 0, BDT model from Lc selector: 1, New BDT model on Top of Lc selector: 2)"};
128+
Configurable<std::vector<double>> binsPtMl{"binsPtMl", std::vector<double>{hf_cuts_ml::vecBinsPt}, "pT bin limits for ML application"};
129+
Configurable<std::vector<int>> cutDirMl{"cutDirMl", std::vector<int>{hf_cuts_ml::vecCutDir}, "Whether to reject score values greater or smaller than the threshold"};
130+
Configurable<LabeledArray<double>> cutsMl{"cutsMl", {hf_cuts_ml::Cuts[0], hf_cuts_ml::NBinsPt, hf_cuts_ml::NCutScores, hf_cuts_ml::labelsPt, hf_cuts_ml::labelsCutScore}, "ML selections per pT bin"};
131+
Configurable<int> nClassesMl{"nClassesMl", static_cast<int>(hf_cuts_ml::NCutScores), "Number of classes in ML model"};
132+
Configurable<std::vector<std::string>> namesInputFeatures{"namesInputFeatures", std::vector<std::string>{"feature1", "feature2"}, "Names of ML model input features"};
111133

112134
using CandidateLc = soa::Join<aod::HfCand3Prong, aod::HfSelLc>;
113135
using CandidateLcMc = soa::Join<aod::HfCand3Prong, aod::HfSelLc, aod::HfCand3ProngMcRec>;
@@ -116,7 +138,7 @@ struct HfFemtoDreamProducer {
116138
using FemtoFullCollisionMc = soa::Join<aod::Collisions, aod::EvSels, aod::Mults, aod::CentFT0Ms, aod::McCollisionLabels>::iterator;
117139
using FemtoFullMcgenCollisions = soa::Join<aod::McCollisions, o2::aod::MultsExtraMC>;
118140
using FemtoFullMcgenCollision = FemtoFullMcgenCollisions::iterator;
119-
using FemtoHFTracks = soa::Join<aod::FullTracks, aod::TracksDCA, aod::pidTPCFullPi, aod::pidTPCFullKa, aod::pidTPCFullPr, aod::pidTPCFullDe, aod::pidTOFFullPi, aod::pidTOFFullKa, aod::pidTOFFullPr, aod::pidTOFFullDe>;
141+
using FemtoHFTracks = soa::Join<aod::FullTracks, aod::TracksDCA, aod::pidTPCFullPi, aod::pidTPCFullKa, aod::pidTPCFullPr, aod::pidTPCFullDe, aod::pidTOFFullPi, aod::pidTOFFullKa, aod::pidTOFFullPr, aod::pidTOFFullDe, aod::TracksPidPi, aod::PidTpcTofFullPi, aod::TracksPidKa, aod::PidTpcTofFullKa, aod::TracksPidPr, aod::PidTpcTofFullPr>;
120142
using FemtoHFTrack = FemtoHFTracks::iterator;
121143
using FemtoHFMcTracks = soa::Join<aod::McTrackLabels, FemtoHFTracks>;
122144
using FemtoHFMcTrack = FemtoHFMcTracks::iterator;
@@ -131,6 +153,10 @@ struct HfFemtoDreamProducer {
131153
HistogramRegistry TrackRegistry{"Tracks", {}, OutputObjHandlingPolicy::AnalysisObject};
132154

133155
HfHelper hfHelper;
156+
o2::analysis::HfMlResponseLcToPKPi<float> hfMlResponse;
157+
std::vector<float> outputMlPKPi = {};
158+
std::vector<float> outputMlPiKP = {};
159+
o2::ccdb::CcdbApi ccdbApi;
134160
o2::hf_evsel::HfEventSelection hfEvSel;
135161

136162
float magField;
@@ -142,7 +168,7 @@ struct HfFemtoDreamProducer {
142168

143169
void init(InitContext&)
144170
{
145-
std::array<bool, 5> processes = {doprocessDataCharmHad, doprocessMcCharmHad, doprocessDataCharmHadWithML, doprocessMcCharmHadWithML, doprocessMcCharmHadGen};
171+
std::array<bool, 3> processes = {doprocessDataCharmHad, doprocessMcCharmHad, doprocessMcCharmHadGen};
146172
if (std::accumulate(processes.begin(), processes.end(), 0) != 1) {
147173
LOGP(fatal, "One and only one process function must be enabled at a time.");
148174
}
@@ -194,6 +220,18 @@ struct HfFemtoDreamProducer {
194220

195221
int64_t now = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
196222
ccdb->setCreatedNotAfter(now);
223+
224+
if (applyMlMode) {
225+
hfMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
226+
if (loadModelsFromCCDB) {
227+
ccdbApi.init(ccdbUrl);
228+
hfMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
229+
} else {
230+
hfMlResponse.setModelPathsLocal(onnxFileNames);
231+
}
232+
hfMlResponse.cacheInputFeaturesIndices(namesInputFeatures);
233+
hfMlResponse.init();
234+
}
197235
}
198236

199237
/// Function to retrieve the nominal magnetic field in kG (0.1T) and convert it directly to T
@@ -395,27 +433,45 @@ struct HfFemtoDreamProducer {
395433
// Filling candidate properties
396434
rowCandCharmHad.reserve(sizeCand);
397435
bool isTrackFilled = false;
436+
bool isSelectedMlLcToPKPi = true;
437+
bool isSelectedMlLcToPiKP = true;
398438
for (const auto& candidate : candidates) {
399-
std::array<float, 3> outputMlPKPi{-1., -1., -1.};
400-
std::array<float, 3> outputMlPiKP{-1., -1., -1.};
439+
440+
auto trackPos1 = candidate.template prong0_as<TrackType>(); // positive daughter (negative for the antiparticles)
441+
auto trackNeg = candidate.template prong1_as<TrackType>(); // negative daughter (positive for the antiparticles)
442+
auto trackPos2 = candidate.template prong2_as<TrackType>(); // positive daughter (negative for the antiparticles)
443+
401444
if constexpr (useCharmMl) {
402445
/// fill with ML information
403446
/// BDT index 0: bkg score; BDT index 1: prompt score; BDT index 2: non-prompt score
404-
if (candidate.mlProbLcToPKPi().size() > 0) {
405-
outputMlPKPi.at(0) = candidate.mlProbLcToPKPi()[0]; /// bkg score
406-
outputMlPKPi.at(1) = candidate.mlProbLcToPKPi()[1]; /// prompt score
407-
outputMlPKPi.at(2) = candidate.mlProbLcToPKPi()[2]; /// non-prompt score
408-
}
409-
if (candidate.mlProbLcToPiKP().size() > 0) {
410-
outputMlPiKP.at(0) = candidate.mlProbLcToPiKP()[0]; /// bkg score
411-
outputMlPiKP.at(1) = candidate.mlProbLcToPiKP()[1]; /// prompt score
412-
outputMlPiKP.at(2) = candidate.mlProbLcToPiKP()[2]; /// non-prompt score
447+
if (applyMlMode == kFillMlFromSelector) {
448+
if (candidate.mlProbLcToPKPi().size() > 0) {
449+
outputMlPKPi.at(0) = candidate.mlProbLcToPKPi()[0]; /// bkg score
450+
outputMlPKPi.at(1) = candidate.mlProbLcToPKPi()[1]; /// prompt score
451+
outputMlPKPi.at(2) = candidate.mlProbLcToPKPi()[2]; /// non-prompt score
452+
}
453+
if (candidate.mlProbLcToPiKP().size() > 0) {
454+
outputMlPiKP.at(0) = candidate.mlProbLcToPiKP()[0]; /// bkg score
455+
outputMlPiKP.at(1) = candidate.mlProbLcToPiKP()[1]; /// prompt score
456+
outputMlPiKP.at(2) = candidate.mlProbLcToPiKP()[2]; /// non-prompt score
457+
}
458+
} else if (applyMlMode == kFillMlFromNewBDT) {
459+
isSelectedMlLcToPKPi = false;
460+
isSelectedMlLcToPiKP = false;
461+
if (candidate.mlProbLcToPKPi().size() > 0) {
462+
std::vector<float> inputFeaturesLcToPKPi = hfMlResponse.getInputFeatures(candidate, trackPos1, trackNeg, trackPos2, true);
463+
isSelectedMlLcToPKPi = hfMlResponse.isSelectedMl(inputFeaturesLcToPKPi, candidate.pt(), outputMlPKPi);
464+
}
465+
if (candidate.mlProbLcToPiKP().size() > 0) {
466+
std::vector<float> inputFeaturesLcToPiKP = hfMlResponse.getInputFeatures(candidate, trackPos1, trackNeg, trackPos2, false);
467+
isSelectedMlLcToPiKP = hfMlResponse.isSelectedMl(inputFeaturesLcToPiKP, candidate.pt(), outputMlPKPi);
468+
}
469+
if (!isSelectedMlLcToPKPi && !isSelectedMlLcToPiKP)
470+
continue;
471+
} else {
472+
LOGF(fatal, "Please check your Ml configuration!!");
413473
}
414474
}
415-
auto trackPos1 = candidate.template prong0_as<TrackType>(); // positive daughter (negative for the antiparticles)
416-
auto trackNeg = candidate.template prong1_as<TrackType>(); // negative daughter (positive for the antiparticles)
417-
auto trackPos2 = candidate.template prong2_as<TrackType>(); // positive daughter (negative for the antiparticles)
418-
419475
auto fillTable = [&](int CandFlag,
420476
int FunctionSelection,
421477
float BDTScoreBkg,

0 commit comments

Comments
 (0)