Skip to content

Commit 10279d5

Browse files
authored
Add MLselection in candidateSelectorXic0ToXiPiKf.cxx
1 parent 4a4ce3d commit 10279d5

File tree

1 file changed

+57
-1
lines changed

1 file changed

+57
-1
lines changed

PWGHF/TableProducer/candidateSelectorXic0ToXiPiKf.cxx

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
/// \file candidateSelectorXic0ToXiPiKf.cxx
1313
/// \brief Xic0 → Xi Pi selection task
1414
/// \author Ran Tu <ran.tu@cern.ch>, Fudan University
15+
/// \author Tao Fang <tao.fang@cern.ch>, Central China Normal University
16+
17+
#include <string>
18+
#include <vector>
1519

1620
#include "CommonConstants/PhysicsConstants.h"
1721
#include "Framework/AnalysisTask.h"
@@ -20,6 +24,9 @@
2024
#include "Common/Core/TrackSelection.h"
2125
#include "Common/Core/TrackSelectorPID.h"
2226

27+
#include "PWGHF/Core/HfHelper.h"
28+
#include "PWGHF/Core/HfMlResponseXicToXiPikf.h"
29+
2330
#include "PWGHF/DataModel/CandidateReconstructionTables.h"
2431
#include "PWGHF/DataModel/CandidateSelectionTables.h"
2532
#include "PWGHF/Utils/utilsAnalysis.h"
@@ -39,6 +46,7 @@ enum PidInfoStored {
3946
/// Struct for applying Xic0 -> Xi pi selection cuts
4047
struct HfCandidateSelectorXic0ToXiPiKf {
4148
Produces<aod::HfSelToXiPiKf> hfSelToXiPi;
49+
Produces<aod::HfMlToXiPikf> hfMlToXiPi;
4250

4351
// LF analysis selections
4452
Configurable<double> radiusCascMin{"radiusCascMin", 0.5, "Min cascade radius"};
@@ -115,7 +123,26 @@ struct HfCandidateSelectorXic0ToXiPiKf {
115123
Configurable<int> nClustersItsMin{"nClustersItsMin", 3, "Minimum number of ITS clusters requirement for pi <- charm baryon"};
116124
Configurable<int> nClustersItsInnBarrMin{"nClustersItsInnBarrMin", 1, "Minimum number of ITS clusters in inner barrel requirement for pi <- charm baryon"};
117125
Configurable<float> itsChi2PerClusterMax{"itsChi2PerClusterMax", 36, "Maximum value of chi2 fit over ITS clusters for pi <- charm baryon"};
118-
126+
127+
// ML inference
128+
Configurable<bool> applyMl{"applyMl", true, "Flag to apply ML selections"};
129+
Configurable<std::vector<double>> binsPtMl{"binsPtMl", std::vector<double>{hf_cuts_ml::vecBinsPt}, "pT bin limits for ML application"};
130+
Configurable<std::vector<int>> cutDirMl{"cutDirMl", std::vector<int>{hf_cuts_ml::vecCutDir}, "Whether to reject score values greater or smaller than the threshold"};
131+
Configurable<LabeledArray<double>> cutsMl{"cutsMl", {hf_cuts_ml::Cuts[0], hf_cuts_ml::NBinsPt, hf_cuts_ml::NCutScores, hf_cuts_ml::labelsPt, hf_cuts_ml::labelsCutScore}, "ML selections per pT bin"};
132+
Configurable<int> nClassesMl{"nClassesMl", static_cast<int>(hf_cuts_ml::NCutScores), "Number of classes in ML model"};
133+
Configurable<std::vector<std::string>> namesInputFeatures{"namesInputFeatures", std::vector<std::string>{"feature1", "feature2"}, "Names of ML model input features"};
134+
135+
// CCDB configuration
136+
Configurable<std::string> ccdbUrl{"ccdbUrl", "http://alice-ccdb.cern.ch", "url of the ccdb repository"};
137+
Configurable<std::vector<std::string>> modelPathsCCDB{"modelPathsCCDB", std::vector<std::string>{"EventFiltering/PWGHF/BDTXic"}, "Paths of models on CCDB"};
138+
Configurable<std::vector<std::string>> onnxFileNames{"onnxFileNames", std::vector<std::string>{"ModelHandler_onnx_XicToXipikf.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path)"};
139+
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"};
140+
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
141+
142+
o2::analysis::HfMlResponseXicToXiPikf<float> hfMlResponse;
143+
std::vector<float> outputMlXicToXiPi = {};
144+
o2::ccdb::CcdbApi ccdbApi;
145+
119146
TrackSelectorPr selectorProton;
120147
TrackSelectorPi selectorPion;
121148

@@ -176,6 +203,19 @@ struct HfCandidateSelectorXic0ToXiPiKf {
176203
registry.add("hSelMassCharmBaryon", "hSelMassCharmBaryon;status;entries", {HistType::kTH1D, {axisSel}});
177204
registry.add("hSelDcaXYToPvV0Daughters", "hSelDcaXYToPvV0Daughters;status;entries", {HistType::kTH1D, {axisSel}});
178205
registry.add("hSelDcaXYToPvPiFromCasc", "hSelDcaXYToPvPiFromCasc;status;entries", {HistType::kTH1D, {axisSel}});
206+
207+
if (applyMl) {
208+
hfMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
209+
if (loadModelsFromCCDB) {
210+
ccdbApi.init(ccdbUrl);
211+
hfMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
212+
} else {
213+
hfMlResponse.setModelPathsLocal(onnxFileNames);
214+
}
215+
hfMlResponse.cacheInputFeaturesIndices(namesInputFeatures);
216+
hfMlResponse.init();
217+
}
218+
179219
}
180220

181221
void process(aod::HfCandToXiPiKf const& candidates,
@@ -185,6 +225,10 @@ struct HfCandidateSelectorXic0ToXiPiKf {
185225

186226
// looping over charm baryon candidates
187227
for (const auto& candidate : candidates) {
228+
229+
//auto ptCandXic = candidate.kfptXic();
230+
auto ptCand = RecoDecay::sqrtSumOfSquares(candidate.pxCharmBaryon(), candidate.pyCharmBaryon());
231+
//pxCharmBaryon
188232

189233
bool resultSelections = true; // True if the candidate passes all the selections, False otherwise
190234

@@ -503,6 +547,18 @@ struct HfCandidateSelectorXic0ToXiPiKf {
503547
registry.fill(HIST("hSelMassCharmBaryon"), 0);
504548
}
505549

550+
// ML selections
551+
if (applyMl) {
552+
bool isSelectedMlXic = false;
553+
std::vector<float> inputFeaturesXic = hfMlResponse.getInputFeatures(candidate, trackPiFromLam, trackPiFromCasc, trackPiFromCharm);
554+
isSelectedMlXic = hfMlResponse.isSelectedMl(inputFeaturesXic, ptCand, outputMlXicToXiPi);
555+
hfMlToXiPi(outputMlXicToXiPi);
556+
557+
if (!isSelectedMlXic) {
558+
continue;
559+
}
560+
}
561+
506562
hfSelToXiPi(statusPidCharmBaryon, statusPidCascade, statusPidLambda, statusInvMassCharmBaryon, statusInvMassCascade, statusInvMassLambda, resultSelections, infoTpcStored, infoTofStored,
507563
trackPiFromCharm.tpcNSigmaPi(), trackPiFromCasc.tpcNSigmaPi(), trackPiFromLam.tpcNSigmaPi(), trackPrFromLam.tpcNSigmaPr(),
508564
trackPiFromCharm.tofNSigmaPi(), trackPiFromCasc.tofNSigmaPi(), trackPiFromLam.tofNSigmaPi(), trackPrFromLam.tofNSigmaPr());

0 commit comments

Comments
 (0)