Skip to content

Commit e7020e3

Browse files
[PWGHF] Implement the additional BDT model application on top of the trigger BDT used in the Lc selector (#11020)
1 parent e26c748 commit e7020e3

File tree

2 files changed

+89
-35
lines changed

2 files changed

+89
-35
lines changed

PWGHF/HFC/TableProducer/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,5 @@ o2physics_add_dpl_workflow(correlator-lc-hadrons
6666

6767
o2physics_add_dpl_workflow(femto-dream-producer
6868
SOURCES femtoDreamProducer.cxx
69-
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore O2Physics::EventFilteringUtils
69+
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore O2Physics::EventFilteringUtils O2Physics::MLCore
7070
COMPONENT_NAME Analysis)

PWGHF/HFC/TableProducer/femtoDreamProducer.cxx

Lines changed: 88 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,17 @@
3737
#include "PWGCF/FemtoDream/Core/femtoDreamUtils.h"
3838

3939
#include "PWGHF/Core/HfHelper.h"
40+
#include "PWGHF/Core/HfMlResponseLcToPKPi.h"
4041
#include "PWGHF/DataModel/CandidateReconstructionTables.h"
4142
#include "PWGHF/DataModel/CandidateSelectionTables.h"
4243
#include "PWGHF/Utils/utilsBfieldCCDB.h"
4344
#include "PWGHF/Utils/utilsEvSelHf.h"
4445
#include "PWGHF/Core/CentralityEstimation.h"
46+
#include "PWGHF/Core/SelectorCuts.h"
4547

4648
using namespace o2;
4749
using namespace o2::framework;
50+
using namespace o2::analysis;
4851
using namespace o2::framework::expressions;
4952
using namespace o2::analysis::femtoDream;
5053
using namespace o2::hf_evsel;
@@ -60,6 +63,13 @@ enum Event : uint8_t {
6063
kPairSelected
6164
};
6265

66+
// ML application modes, compared against the integer `applyMlMode` configurable
67+
enum MlMode : uint8_t {
68+
kNoMl = 0, // 0: ML disabled; init() skips the ML response setup when applyMlMode is 0
69+
kFillMlFromSelector, // 1: copy the BDT scores already stored on the candidate (mlProbLcToPKPi / mlProbLcToPiKP)
70+
kFillMlFromNewBDT // 2: evaluate the additional BDT via hfMlResponse; candidates failing both mass hypotheses are skipped
71+
};
72+
6373
struct HfFemtoDreamProducer {
6474

6575
Produces<aod::FDCollisions> outputCollision;
@@ -82,6 +92,11 @@ struct HfFemtoDreamProducer {
8292
Configurable<std::string> ccdbPathGrp{"ccdbPathGrp", "GLO/GRP/GRP", "Path of the grp file (Run 2)"};
8393
Configurable<std::string> ccdbPathGrpMag{"ccdbPathGrpMag", "GLO/Config/GRPMagField", "CCDB path of the GRPMagField object (Run 3)"};
8494

95+
Configurable<std::vector<std::string>> modelPathsCCDB{"modelPathsCCDB", std::vector<std::string>{"EventFiltering/PWGHF/BDTLc"}, "Paths of models on CCDB"};
96+
Configurable<std::vector<std::string>> onnxFileNames{"onnxFileNames", std::vector<std::string>{"ModelHandler_onnx_LcToPKPi.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path)"};
97+
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"};
98+
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
99+
85100
// Configurable<bool> isForceGRP{"isForceGRP", false, "Set true if the magnetic field configuration is not available in the usual CCDB directory (e.g. for Run 2 converted data or unanchored Monte Carlo)"};
86101

87102
Configurable<bool> isDebug{"isDebug", true, "Enable Debug tables"};
@@ -108,47 +123,56 @@ struct HfFemtoDreamProducer {
108123
Configurable<std::vector<float>> trkTPCsCls{FemtoDreamTrackSelection::getSelectionName(femtoDreamTrackSelection::kTPCsClsMax, "trk"), std::vector<float>{0.1f, 160.f}, FemtoDreamTrackSelection::getSelectionHelper(femtoDreamTrackSelection::kTPCsClsMax, "Track selection: ")};
109124
Configurable<std::vector<float>> trkITSnclsIbMin{FemtoDreamTrackSelection::getSelectionName(femtoDreamTrackSelection::kITSnClsIbMin, "trk"), std::vector<float>{-1.f, 1.f}, FemtoDreamTrackSelection::getSelectionHelper(femtoDreamTrackSelection::kITSnClsIbMin, "Track selection: ")};
110125
Configurable<std::vector<float>> trkITSnclsMin{FemtoDreamTrackSelection::getSelectionName(femtoDreamTrackSelection::kITSnClsMin, "trk"), std::vector<float>{-1.f, 2.f, 4.f}, FemtoDreamTrackSelection::getSelectionHelper(femtoDreamTrackSelection::kITSnClsMin, "Track selection: ")};
126+
// ML inference
127+
Configurable<int> applyMlMode{"applyMlMode", 1, "Occupancy estimation (None: 0, BDT model from Lc selector: 1, New BDT model on Top of Lc selector: 2)"};
128+
Configurable<std::vector<double>> binsPtMl{"binsPtMl", std::vector<double>{hf_cuts_ml::vecBinsPt}, "pT bin limits for ML application"};
129+
Configurable<std::vector<int>> cutDirMl{"cutDirMl", std::vector<int>{hf_cuts_ml::vecCutDir}, "Whether to reject score values greater or smaller than the threshold"};
130+
Configurable<LabeledArray<double>> cutsMl{"cutsMl", {hf_cuts_ml::Cuts[0], hf_cuts_ml::NBinsPt, hf_cuts_ml::NCutScores, hf_cuts_ml::labelsPt, hf_cuts_ml::labelsCutScore}, "ML selections per pT bin"};
131+
Configurable<int> nClassesMl{"nClassesMl", static_cast<int>(hf_cuts_ml::NCutScores), "Number of classes in ML model"};
132+
Configurable<std::vector<std::string>> namesInputFeatures{"namesInputFeatures", std::vector<std::string>{"feature1", "feature2"}, "Names of ML model input features"};
133+
134+
FemtoDreamTrackSelection trackCuts;
135+
136+
HfHelper hfHelper;
137+
o2::analysis::HfMlResponseLcToPKPi<float> hfMlResponse;
138+
std::vector<float> outputMlPKPi = {};
139+
std::vector<float> outputMlPiKP = {};
140+
o2::ccdb::CcdbApi ccdbApi;
141+
o2::hf_evsel::HfEventSelection hfEvSel;
142+
Service<o2::ccdb::BasicCCDBManager> ccdb; /// Accessing the CCDB
143+
o2::base::MatLayerCylSet* lut;
144+
// if (doPvRefit){ lut = o2::base::MatLayerCylSet::rectifyPtrFromFile(ccdb->get<o2::base::MatLayerCylSet>(ccdbPathLut));} //! may be useful; to be checked later
111145

146+
float magField;
147+
int runNumber;
112148
using CandidateLc = soa::Join<aod::HfCand3Prong, aod::HfSelLc>;
113149
using CandidateLcMc = soa::Join<aod::HfCand3Prong, aod::HfSelLc, aod::HfCand3ProngMcRec>;
114150

115151
using FemtoFullCollision = soa::Join<aod::Collisions, aod::EvSels, aod::Mults, aod::CentFT0Ms>::iterator;
116152
using FemtoFullCollisionMc = soa::Join<aod::Collisions, aod::EvSels, aod::Mults, aod::CentFT0Ms, aod::McCollisionLabels>::iterator;
117153
using FemtoFullMcgenCollisions = soa::Join<aod::McCollisions, o2::aod::MultsExtraMC>;
118154
using FemtoFullMcgenCollision = FemtoFullMcgenCollisions::iterator;
119-
using FemtoHFTracks = soa::Join<aod::FullTracks, aod::TracksDCA, aod::pidTPCFullPi, aod::pidTPCFullKa, aod::pidTPCFullPr, aod::pidTPCFullDe, aod::pidTOFFullPi, aod::pidTOFFullKa, aod::pidTOFFullPr, aod::pidTOFFullDe>;
155+
using FemtoHFTracks = soa::Join<aod::FullTracks, aod::TracksDCA, aod::pidTPCFullPi, aod::pidTPCFullKa, aod::pidTPCFullPr, aod::pidTPCFullDe, aod::pidTOFFullPi, aod::pidTOFFullKa, aod::pidTOFFullPr, aod::pidTOFFullDe, aod::TracksPidPi, aod::PidTpcTofFullPi, aod::TracksPidKa, aod::PidTpcTofFullKa, aod::TracksPidPr, aod::PidTpcTofFullPr>;
120156
using FemtoHFTrack = FemtoHFTracks::iterator;
121157
using FemtoHFMcTracks = soa::Join<aod::McTrackLabels, FemtoHFTracks>;
122158
using FemtoHFMcTrack = FemtoHFMcTracks::iterator;
123159

124160
using GeneratedMc = soa::Filtered<soa::Join<aod::McParticles, aod::HfCand3ProngMcGen>>;
125161

126-
FemtoDreamTrackSelection trackCuts;
127-
128162
Filter filterSelectCandidateLc = (aod::hf_sel_candidate_lc::isSelLcToPKPi >= selectionFlagLc || aod::hf_sel_candidate_lc::isSelLcToPiKP >= selectionFlagLc);
129163

130164
HistogramRegistry qaRegistry{"QAHistos", {}, OutputObjHandlingPolicy::AnalysisObject};
131-
HistogramRegistry TrackRegistry{"Tracks", {}, OutputObjHandlingPolicy::AnalysisObject};
132-
133-
HfHelper hfHelper;
134-
o2::hf_evsel::HfEventSelection hfEvSel;
135-
136-
float magField;
137-
int runNumber;
138-
139-
Service<o2::ccdb::BasicCCDBManager> ccdb; /// Accessing the CCDB
140-
o2::base::MatLayerCylSet* lut;
141-
// if (doPvRefit){ lut = o2::base::MatLayerCylSet::rectifyPtrFromFile(ccdb->get<o2::base::MatLayerCylSet>(ccdbPathLut));} //! may be it useful, will check later
165+
HistogramRegistry trackRegistry{"Tracks", {}, OutputObjHandlingPolicy::AnalysisObject};
142166

143167
void init(InitContext&)
144168
{
145-
std::array<bool, 5> processes = {doprocessDataCharmHad, doprocessMcCharmHad, doprocessDataCharmHadWithML, doprocessMcCharmHadWithML, doprocessMcCharmHadGen};
169+
std::array<bool, 3> processes = {doprocessDataCharmHad, doprocessMcCharmHad, doprocessMcCharmHadGen};
146170
if (std::accumulate(processes.begin(), processes.end(), 0) != 1) {
147171
LOGP(fatal, "One and only one process function must be enabled at a time.");
148172
}
149173

150-
int CutBits = 8 * sizeof(o2::aod::femtodreamparticle::cutContainerType);
151-
TrackRegistry.add("AnalysisQA/CutCounter", "; Bit; Counter", kTH1F, {{CutBits + 1, -0.5, CutBits + 0.5}});
174+
int cutBits = 8 * sizeof(o2::aod::femtodreamparticle::cutContainerType);
175+
trackRegistry.add("AnalysisQA/CutCounter", "; Bit; Counter", kTH1F, {{cutBits + 1, -0.5, cutBits + 0.5}});
152176

153177
// event QA histograms
154178
constexpr int kEventTypes = kPairSelected + 1;
@@ -181,7 +205,7 @@ struct HfFemtoDreamProducer {
181205
trackCuts.setSelection(trkPIDnSigmaMax, femtoDreamTrackSelection::kPIDnSigmaMax, femtoDreamSelection::kAbsUpperLimit);
182206
trackCuts.setPIDSpecies(trkPIDspecies);
183207
trackCuts.setnSigmaPIDOffset(trkPIDnSigmaOffsetTPC, trkPIDnSigmaOffsetTOF);
184-
trackCuts.init<aod::femtodreamparticle::ParticleType::kTrack, aod::femtodreamparticle::TrackType::kNoChild, aod::femtodreamparticle::cutContainerType>(&qaRegistry, &TrackRegistry);
208+
trackCuts.init<aod::femtodreamparticle::ParticleType::kTrack, aod::femtodreamparticle::TrackType::kNoChild, aod::femtodreamparticle::cutContainerType>(&qaRegistry, &trackRegistry);
185209

186210
runNumber = 0;
187211
magField = 0.0;
@@ -194,6 +218,18 @@ struct HfFemtoDreamProducer {
194218

195219
int64_t now = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
196220
ccdb->setCreatedNotAfter(now);
221+
222+
if (applyMlMode) {
223+
hfMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
224+
if (loadModelsFromCCDB) {
225+
ccdbApi.init(ccdbUrl);
226+
hfMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
227+
} else {
228+
hfMlResponse.setModelPathsLocal(onnxFileNames);
229+
}
230+
hfMlResponse.cacheInputFeaturesIndices(namesInputFeatures);
231+
hfMlResponse.init();
232+
}
197233
}
198234

199235
/// Function to retrieve the nominal magnetic field in kG (0.1T) and convert it directly to T
@@ -314,7 +350,7 @@ struct HfFemtoDreamProducer {
314350
// std::vector<int> tmpIDtrack; // this vector keeps track of the matching of the primary track table row <-> aod::track table global index
315351
bool fIsTrackFilled = false;
316352

317-
for (auto& track : tracks) {
353+
for (const auto& track : tracks) {
318354
/// if the most open selection criteria are not fulfilled there is no
319355
/// point looking further at the track
320356
if (!trackCuts.isSelectedMinimal(track)) {
@@ -395,27 +431,45 @@ struct HfFemtoDreamProducer {
395431
// Filling candidate properties
396432
rowCandCharmHad.reserve(sizeCand);
397433
bool isTrackFilled = false;
434+
bool isSelectedMlLcToPKPi = true;
435+
bool isSelectedMlLcToPiKP = true;
398436
for (const auto& candidate : candidates) {
399-
std::array<float, 3> outputMlPKPi{-1., -1., -1.};
400-
std::array<float, 3> outputMlPiKP{-1., -1., -1.};
437+
438+
auto trackPos1 = candidate.template prong0_as<TrackType>(); // positive daughter (negative for the antiparticles)
439+
auto trackNeg = candidate.template prong1_as<TrackType>(); // negative daughter (positive for the antiparticles)
440+
auto trackPos2 = candidate.template prong2_as<TrackType>(); // positive daughter (negative for the antiparticles)
441+
401442
if constexpr (useCharmMl) {
402443
/// fill with ML information
403444
/// BDT index 0: bkg score; BDT index 1: prompt score; BDT index 2: non-prompt score
404-
if (candidate.mlProbLcToPKPi().size() > 0) {
405-
outputMlPKPi.at(0) = candidate.mlProbLcToPKPi()[0]; /// bkg score
406-
outputMlPKPi.at(1) = candidate.mlProbLcToPKPi()[1]; /// prompt score
407-
outputMlPKPi.at(2) = candidate.mlProbLcToPKPi()[2]; /// non-prompt score
408-
}
409-
if (candidate.mlProbLcToPiKP().size() > 0) {
410-
outputMlPiKP.at(0) = candidate.mlProbLcToPiKP()[0]; /// bkg score
411-
outputMlPiKP.at(1) = candidate.mlProbLcToPiKP()[1]; /// prompt score
412-
outputMlPiKP.at(2) = candidate.mlProbLcToPiKP()[2]; /// non-prompt score
445+
if (applyMlMode == kFillMlFromSelector) {
446+
if (candidate.mlProbLcToPKPi().size() > 0) {
447+
outputMlPKPi.at(0) = candidate.mlProbLcToPKPi()[0]; /// bkg score
448+
outputMlPKPi.at(1) = candidate.mlProbLcToPKPi()[1]; /// prompt score
449+
outputMlPKPi.at(2) = candidate.mlProbLcToPKPi()[2]; /// non-prompt score
450+
}
451+
if (candidate.mlProbLcToPiKP().size() > 0) {
452+
outputMlPiKP.at(0) = candidate.mlProbLcToPiKP()[0]; /// bkg score
453+
outputMlPiKP.at(1) = candidate.mlProbLcToPiKP()[1]; /// prompt score
454+
outputMlPiKP.at(2) = candidate.mlProbLcToPiKP()[2]; /// non-prompt score
455+
}
456+
} else if (applyMlMode == kFillMlFromNewBDT) {
457+
isSelectedMlLcToPKPi = false;
458+
isSelectedMlLcToPiKP = false;
459+
if (candidate.mlProbLcToPKPi().size() > 0) {
460+
std::vector<float> inputFeaturesLcToPKPi = hfMlResponse.getInputFeatures(candidate, trackPos1, trackNeg, trackPos2, true);
461+
isSelectedMlLcToPKPi = hfMlResponse.isSelectedMl(inputFeaturesLcToPKPi, candidate.pt(), outputMlPKPi);
462+
}
463+
if (candidate.mlProbLcToPiKP().size() > 0) {
464+
std::vector<float> inputFeaturesLcToPiKP = hfMlResponse.getInputFeatures(candidate, trackPos1, trackNeg, trackPos2, false);
465+
isSelectedMlLcToPiKP = hfMlResponse.isSelectedMl(inputFeaturesLcToPiKP, candidate.pt(), outputMlPKPi);
466+
}
467+
if (!isSelectedMlLcToPKPi && !isSelectedMlLcToPiKP)
468+
continue;
469+
} else {
470+
LOGF(fatal, "Please check your Ml configuration!!");
413471
}
414472
}
415-
auto trackPos1 = candidate.template prong0_as<TrackType>(); // positive daughter (negative for the antiparticles)
416-
auto trackNeg = candidate.template prong1_as<TrackType>(); // negative daughter (positive for the antiparticles)
417-
auto trackPos2 = candidate.template prong2_as<TrackType>(); // positive daughter (negative for the antiparticles)
418-
419473
auto fillTable = [&](int CandFlag,
420474
int FunctionSelection,
421475
float BDTScoreBkg,

0 commit comments

Comments
 (0)