Skip to content

Commit 6c908cc

Browse files
AlexBigOAlexandre Bigot
andauthored
PWGHF: Implement B0 ML and templatize task (#4661)
* Implement B0 ML and templatize task * Solve MegaLinter issue * Solve MegaLinter without running clang format * Add missing lines to fill B0 ML scores histograms in processDataWithB0Ml * Apply Fabrizio's comments --------- Co-authored-by: Alexandre Bigot <abigot@sbgat402.in2p3.fr>
1 parent 730dd3b commit 6c908cc

5 files changed

Lines changed: 495 additions & 199 deletions

File tree

PWGHF/Core/HfMlResponseB0ToDPi.h

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
12+
/// \file HfMlResponsB0ToDPi.h
13+
/// \brief Class to compute the ML response for B0 → D∓ π± analysis selections
14+
/// \author Alexandre Bigot <alexandre.bigot@cern.ch>, IPHC Strasbourg
15+
16+
#ifndef PWGHF_CORE_HFMLRESPONSEB0TODPI_H_
17+
#define PWGHF_CORE_HFMLRESPONSEB0TODPI_H_
18+
19+
#include <map>
20+
#include <string>
21+
#include <vector>
22+
23+
#include "PWGHF/Core/HfMlResponse.h"
24+
25+
// Fill the map of available input features
26+
// the key is the feature's name (std::string)
27+
// the value is the corresponding value in EnumInputFeatures
28+
#define FILL_MAP_B0(FEATURE) \
29+
{ \
30+
#FEATURE, static_cast < uint8_t>(InputFeaturesB0ToDPi::FEATURE) \
31+
}
32+
33+
// Check if the index of mCachedIndices (index associated to a FEATURE)
34+
// matches the entry in EnumInputFeatures associated to this FEATURE
35+
// if so, the inputFeatures vector is filled with the FEATURE's value
36+
// by calling the corresponding GETTER from OBJECT
37+
#define CHECK_AND_FILL_VEC_B0_FULL(OBJECT, FEATURE, GETTER) \
38+
case static_cast<uint8_t>(InputFeaturesB0ToDPi::FEATURE): { \
39+
inputFeatures.emplace_back(OBJECT.GETTER()); \
40+
break; \
41+
}
42+
43+
// Check if the index of mCachedIndices (index associated to a FEATURE)
44+
// matches the entry in EnumInputFeatures associated to this FEATURE
45+
// if so, the inputFeatures vector is filled with the FEATURE's value
46+
// by calling the GETTER function taking OBJECT in argument
47+
#define CHECK_AND_FILL_VEC_B0_FUNC(OBJECT, FEATURE, GETTER) \
48+
case static_cast<uint8_t>(InputFeaturesB0ToDPi::FEATURE): { \
49+
inputFeatures.emplace_back(GETTER(OBJECT)); \
50+
break; \
51+
}
52+
53+
// Specific case of CHECK_AND_FILL_VEC_B0_FULL(OBJECT, FEATURE, GETTER)
54+
// where OBJECT is named candidate and FEATURE = GETTER
55+
#define CHECK_AND_FILL_VEC_B0(GETTER) \
56+
case static_cast<uint8_t>(InputFeaturesB0ToDPi::GETTER): { \
57+
inputFeatures.emplace_back(candidate.GETTER()); \
58+
break; \
59+
}
60+
61+
namespace o2::pid_tpc_tof_utils
62+
{
63+
template <typename T1>
64+
float getTpcTofNSigmaPi1(const T1& prong1)
65+
{
66+
float defaultNSigma = -999.f; // -999.f is the default value set in TPCPIDResponse.h and PIDTOF.h
67+
68+
bool hasTpc = prong1.hasTPC();
69+
bool hasTof = prong1.hasTOF();
70+
71+
if (hasTpc && hasTof) {
72+
float tpcNSigma = prong1.tpcNSigmaPi();
73+
float tofNSigma = prong1.tofNSigmaPi();
74+
return sqrt(.5f * tpcNSigma * tpcNSigma + .5f * tofNSigma * tofNSigma);
75+
}
76+
if (hasTpc) {
77+
return abs(prong1.tpcNSigmaPi());
78+
}
79+
if (hasTof) {
80+
return abs(prong1.tofNSigmaPi());
81+
}
82+
return defaultNSigma;
83+
}
84+
} // namespace o2::pid_tpc_tof_utils
85+
86+
namespace o2::analysis
87+
{
88+
89+
enum class InputFeaturesB0ToDPi : uint8_t {
90+
ptProng0 = 0,
91+
ptProng1,
92+
impactParameter0,
93+
impactParameter1,
94+
impactParameterProduct,
95+
chi2PCA,
96+
decayLength,
97+
decayLengthXY,
98+
decayLengthNormalised,
99+
decayLengthXYNormalised,
100+
cpa,
101+
cpaXY,
102+
maxNormalisedDeltaIP,
103+
prong0MlScoreBkg,
104+
prong0MlScorePrompt,
105+
prong0MlScoreNonprompt,
106+
tpcNSigmaPi1,
107+
tofNSigmaPi1,
108+
tpcTofNSigmaPi1
109+
};
110+
111+
template <typename TypeOutputScore = float>
112+
class HfMlResponseB0ToDPi : public HfMlResponse<TypeOutputScore>
113+
{
114+
public:
115+
/// Default constructor
116+
HfMlResponseB0ToDPi() = default;
117+
/// Default destructor
118+
virtual ~HfMlResponseB0ToDPi() = default;
119+
120+
/// Method to get the input features vector needed for ML inference
121+
/// \param candidate is the B0 candidate
122+
/// \param prong1 is the candidate's prong1
123+
/// \return inputFeatures vector
124+
template <bool withDmesMl, typename T1, typename T2>
125+
std::vector<float> getInputFeatures(T1 const& candidate,
126+
T2 const& prong1)
127+
{
128+
std::vector<float> inputFeatures;
129+
130+
for (const auto& idx : MlResponse<TypeOutputScore>::mCachedIndices) {
131+
if constexpr (withDmesMl) {
132+
switch (idx) {
133+
CHECK_AND_FILL_VEC_B0(ptProng0);
134+
CHECK_AND_FILL_VEC_B0(ptProng1);
135+
CHECK_AND_FILL_VEC_B0(impactParameter0);
136+
CHECK_AND_FILL_VEC_B0(impactParameter1);
137+
CHECK_AND_FILL_VEC_B0(impactParameterProduct);
138+
CHECK_AND_FILL_VEC_B0(chi2PCA);
139+
CHECK_AND_FILL_VEC_B0(decayLength);
140+
CHECK_AND_FILL_VEC_B0(decayLengthXY);
141+
CHECK_AND_FILL_VEC_B0(decayLengthNormalised);
142+
CHECK_AND_FILL_VEC_B0(decayLengthXYNormalised);
143+
CHECK_AND_FILL_VEC_B0(cpa);
144+
CHECK_AND_FILL_VEC_B0(cpaXY);
145+
CHECK_AND_FILL_VEC_B0(maxNormalisedDeltaIP);
146+
CHECK_AND_FILL_VEC_B0(prong0MlScoreBkg);
147+
CHECK_AND_FILL_VEC_B0(prong0MlScorePrompt);
148+
CHECK_AND_FILL_VEC_B0(prong0MlScoreNonprompt);
149+
// TPC PID variable
150+
CHECK_AND_FILL_VEC_B0_FULL(prong1, tpcNSigmaPi1, tpcNSigmaPi);
151+
// TOF PID variable
152+
CHECK_AND_FILL_VEC_B0_FULL(prong1, tofNSigmaPi1, tofNSigmaPi);
153+
// Combined PID variables
154+
CHECK_AND_FILL_VEC_B0_FUNC(prong1, tpcTofNSigmaPi1, o2::pid_tpc_tof_utils::getTpcTofNSigmaPi1);
155+
}
156+
} else {
157+
switch (idx) {
158+
CHECK_AND_FILL_VEC_B0(ptProng0);
159+
CHECK_AND_FILL_VEC_B0(ptProng1);
160+
CHECK_AND_FILL_VEC_B0(impactParameter0);
161+
CHECK_AND_FILL_VEC_B0(impactParameter1);
162+
CHECK_AND_FILL_VEC_B0(impactParameterProduct);
163+
CHECK_AND_FILL_VEC_B0(chi2PCA);
164+
CHECK_AND_FILL_VEC_B0(decayLength);
165+
CHECK_AND_FILL_VEC_B0(decayLengthXY);
166+
CHECK_AND_FILL_VEC_B0(decayLengthNormalised);
167+
CHECK_AND_FILL_VEC_B0(decayLengthXYNormalised);
168+
CHECK_AND_FILL_VEC_B0(cpa);
169+
CHECK_AND_FILL_VEC_B0(cpaXY);
170+
CHECK_AND_FILL_VEC_B0(maxNormalisedDeltaIP);
171+
// TPC PID variable
172+
CHECK_AND_FILL_VEC_B0_FULL(prong1, tpcNSigmaPi1, tpcNSigmaPi);
173+
// TOF PID variable
174+
CHECK_AND_FILL_VEC_B0_FULL(prong1, tofNSigmaPi1, tofNSigmaPi);
175+
// Combined PID variables
176+
CHECK_AND_FILL_VEC_B0_FUNC(prong1, tpcTofNSigmaPi1, o2::pid_tpc_tof_utils::getTpcTofNSigmaPi1);
177+
}
178+
}
179+
}
180+
181+
return inputFeatures;
182+
}
183+
184+
protected:
185+
/// Method to fill the map of available input features
186+
void setAvailableInputFeatures()
187+
{
188+
MlResponse<TypeOutputScore>::mAvailableInputFeatures = {
189+
FILL_MAP_B0(ptProng0),
190+
FILL_MAP_B0(ptProng1),
191+
FILL_MAP_B0(impactParameter0),
192+
FILL_MAP_B0(impactParameter1),
193+
FILL_MAP_B0(impactParameterProduct),
194+
FILL_MAP_B0(chi2PCA),
195+
FILL_MAP_B0(decayLength),
196+
FILL_MAP_B0(decayLengthXY),
197+
FILL_MAP_B0(decayLengthNormalised),
198+
FILL_MAP_B0(decayLengthXYNormalised),
199+
FILL_MAP_B0(cpa),
200+
FILL_MAP_B0(cpaXY),
201+
FILL_MAP_B0(maxNormalisedDeltaIP),
202+
FILL_MAP_B0(prong0MlScoreBkg),
203+
FILL_MAP_B0(prong0MlScorePrompt),
204+
FILL_MAP_B0(prong0MlScoreNonprompt),
205+
// TPC PID variable
206+
FILL_MAP_B0(tpcNSigmaPi1),
207+
// TOF PID variable
208+
FILL_MAP_B0(tofNSigmaPi1),
209+
// Combined PID variable
210+
FILL_MAP_B0(tpcTofNSigmaPi1)};
211+
}
212+
};
213+
214+
} // namespace o2::analysis
215+
216+
#undef FILL_MAP_B0
217+
#undef CHECK_AND_FILL_VEC_B0_FULL
218+
#undef CHECK_AND_FILL_VEC_B0_FUNC
219+
#undef CHECK_AND_FILL_VEC_B0
220+
221+
#endif // PWGHF_CORE_HFMLRESPONSEB0TODPI_H_

PWGHF/D2H/TableProducer/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ o2physics_add_dpl_workflow(candidate-creator-bplus-reduced
2525

2626
o2physics_add_dpl_workflow(candidate-selector-b0-to-d-pi-reduced
2727
SOURCES candidateSelectorB0ToDPiReduced.cxx
28-
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore
28+
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore O2Physics::MLCore
2929
COMPONENT_NAME Analysis)
3030

3131
o2physics_add_dpl_workflow(candidate-selector-bplus-to-d0-pi-reduced

PWGHF/D2H/TableProducer/candidateSelectorB0ToDPiReduced.cxx

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "Common/Core/TrackSelectorPID.h"
2222

2323
#include "PWGHF/Core/HfHelper.h"
24+
#include "PWGHF/Core/HfMlResponseB0ToDPi.h"
2425
#include "PWGHF/Core/SelectorCuts.h"
2526
#include "PWGHF/DataModel/CandidateReconstructionTables.h"
2627
#include "PWGHF/DataModel/CandidateSelectionTables.h"
@@ -33,6 +34,7 @@ using namespace o2::analysis;
3334

3435
struct HfCandidateSelectorB0ToDPiReduced {
3536
Produces<aod::HfSelB0ToDPi> hfSelB0ToDPiCandidate; // table defined in CandidateSelectionTables.h
37+
Produces<aod::HfMlB0ToDPi> hfMlB0ToDPiCandidate; // table defined in CandidateSelectionTables.h
3638

3739
Configurable<float> ptCandMin{"ptCandMin", 0., "Lower bound of candidate pT"};
3840
Configurable<float> ptCandMax{"ptCandMax", 50., "Upper bound of candidate pT"};
@@ -57,17 +59,34 @@ struct HfCandidateSelectorB0ToDPiReduced {
5759
Configurable<LabeledArray<double>> cutsDmesMl{"cutsDmesMl", {hf_cuts_ml::cuts[0], hf_cuts_ml::nBinsPt, hf_cuts_ml::nCutScores, hf_cuts_ml::labelsPt, hf_cuts_ml::labelsDmesCutScore}, "D-meson ML cuts per pT bin"};
5860
// QA switch
5961
Configurable<bool> activateQA{"activateQA", false, "Flag to enable QA histogram"};
60-
62+
// B0 ML inference
63+
Configurable<bool> applyB0Ml{"applyB0Ml", false, "Flag to apply ML selections"};
64+
Configurable<std::vector<double>> binsPtB0Ml{"binsPtB0Ml", std::vector<double>{hf_cuts_ml::vecBinsPt}, "pT bin limits for ML application"};
65+
Configurable<std::vector<int>> cutDirB0Ml{"cutDirB0Ml", std::vector<int>{hf_cuts_ml::vecCutDir}, "Whether to reject score values greater or smaller than the threshold"};
66+
Configurable<LabeledArray<double>> cutsB0Ml{"cutsB0Ml", {hf_cuts_ml::cuts[0], hf_cuts_ml::nBinsPt, hf_cuts_ml::nCutScores, hf_cuts_ml::labelsPt, hf_cuts_ml::labelsCutScore}, "ML selections per pT bin"};
67+
Configurable<int8_t> nClassesB0Ml{"nClassesB0Ml", (int8_t)hf_cuts_ml::nCutScores, "Number of classes in ML model"};
68+
Configurable<std::vector<std::string>> namesInputFeatures{"namesInputFeatures", std::vector<std::string>{"feature1", "feature2"}, "Names of ML model input features"};
69+
// CCDB configuration
70+
Configurable<std::string> ccdbUrl{"ccdbUrl", "http://alice-ccdb.cern.ch", "url of the ccdb repository"};
71+
Configurable<std::vector<std::string>> modelPathsCCDB{"modelPathsCCDB", std::vector<std::string>{"path_ccdb/BDT_B0/"}, "Paths of models on CCDB"};
72+
Configurable<std::vector<std::string>> onnxFileNames{"onnxFileNames", std::vector<std::string>{"ModelHandler_onnx_B0ToDPi.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path)"};
73+
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"};
74+
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
6175
// variable that will store the value of selectionFlagD (defined in dataCreatorDplusPiReduced.cxx)
6276
int mySelectionFlagD = -1;
6377

64-
HfHelper hfHelper;
65-
TrackSelectorPi selectorPion;
78+
o2::analysis::HfMlResponseB0ToDPi<float> hfMlResponse;
79+
float outputMlNotPreselected = -1.;
80+
std::vector<float> outputMl = {};
81+
o2::ccdb::CcdbApi ccdbApi;
6682

67-
HistogramRegistry registry{"registry"};
83+
TrackSelectorPi selectorPion;
84+
HfHelper hfHelper;
6885

6986
using TracksPion = soa::Join<HfRedTracks, HfRedTracksPid>;
7087

88+
HistogramRegistry registry{"registry"};
89+
7190
void init(InitContext const& initContext)
7291
{
7392
std::array<bool, 2> doprocess{doprocessSelection, doprocessSelectionWithDmesMl};
@@ -101,6 +120,18 @@ struct HfCandidateSelectorB0ToDPiReduced {
101120
registry.get<TH2>(HIST("hSelections"))->GetXaxis()->SetBinLabel(iBin + 1, labels[iBin].data());
102121
}
103122
}
123+
124+
if (applyB0Ml) {
125+
hfMlResponse.configure(binsPtB0Ml, cutsB0Ml, cutDirB0Ml, nClassesB0Ml);
126+
if (loadModelsFromCCDB) {
127+
ccdbApi.init(ccdbUrl);
128+
hfMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
129+
} else {
130+
hfMlResponse.setModelPathsLocal(onnxFileNames);
131+
}
132+
hfMlResponse.cacheInputFeaturesIndices(namesInputFeatures);
133+
hfMlResponse.init();
134+
}
104135
}
105136

106137
/// Main function to perform B0 candidate creation
@@ -125,6 +156,9 @@ struct HfCandidateSelectorB0ToDPiReduced {
125156
// check if flagged as B0 → D π
126157
if (!TESTBIT(hfCandB0.hfflag(), hf_cand_b0::DecayType::B0ToDPi)) {
127158
hfSelB0ToDPiCandidate(statusB0ToDPi);
159+
if (applyB0Ml) {
160+
hfMlB0ToDPiCandidate(outputMlNotPreselected);
161+
}
128162
if (activateQA) {
129163
registry.fill(HIST("hSelections"), 1, ptCandB0);
130164
}
@@ -139,13 +173,19 @@ struct HfCandidateSelectorB0ToDPiReduced {
139173
// topological cuts
140174
if (!hfHelper.selectionB0ToDPiTopol(hfCandB0, cuts, binsPt)) {
141175
hfSelB0ToDPiCandidate(statusB0ToDPi);
176+
if (applyB0Ml) {
177+
hfMlB0ToDPiCandidate(outputMlNotPreselected);
178+
}
142179
// LOGF(info, "B0 candidate selection failed at topology selection");
143180
continue;
144181
}
145182

146183
if constexpr (withDmesMl) { // we include it in the topological selections
147184
if (!hfHelper.selectionDmesMlScoresForB(hfCandB0, cutsDmesMl, binsPtDmesMl)) {
148185
hfSelB0ToDPiCandidate(statusB0ToDPi);
186+
if (applyB0Ml) {
187+
hfMlB0ToDPiCandidate(outputMlNotPreselected);
188+
}
149189
// LOGF(info, "B0 candidate selection failed at D-meson ML selection");
150190
continue;
151191
}
@@ -157,8 +197,8 @@ struct HfCandidateSelectorB0ToDPiReduced {
157197
}
158198

159199
// track-level PID selection
200+
auto trackPi = hfCandB0.template prong1_as<TracksPion>();
160201
if (usePionPid) {
161-
auto trackPi = hfCandB0.template prong1_as<TracksPion>();
162202
int pidTrackPi{TrackSelectorPID::Status::NotApplicable};
163203
if (usePionPid == 1) {
164204
pidTrackPi = selectorPion.statusTpcOrTof(trackPi);
@@ -168,13 +208,33 @@ struct HfCandidateSelectorB0ToDPiReduced {
168208
if (!hfHelper.selectionB0ToDPiPid(pidTrackPi, acceptPIDNotApplicable.value)) {
169209
// LOGF(info, "B0 candidate selection failed at PID selection");
170210
hfSelB0ToDPiCandidate(statusB0ToDPi);
211+
if (applyB0Ml) {
212+
hfMlB0ToDPiCandidate(outputMlNotPreselected);
213+
}
171214
continue;
172215
}
173216
SETBIT(statusB0ToDPi, SelectionStep::RecoPID); // RecoPID = 2 --> statusB0ToDPi = 7
174217
if (activateQA) {
175218
registry.fill(HIST("hSelections"), 2 + SelectionStep::RecoPID, ptCandB0);
176219
}
177220
}
221+
222+
if (applyB0Ml) {
223+
// B0 ML selections
224+
std::vector<float> inputFeatures = hfMlResponse.getInputFeatures<withDmesMl>(hfCandB0, trackPi);
225+
bool isSelectedMl = hfMlResponse.isSelectedMl(inputFeatures, ptCandB0, outputMl);
226+
hfMlB0ToDPiCandidate(outputMl[1]); // storing ML score for signal class
227+
228+
if (!isSelectedMl) {
229+
hfSelB0ToDPiCandidate(statusB0ToDPi);
230+
continue;
231+
}
232+
SETBIT(statusB0ToDPi, SelectionStep::RecoMl); // RecoML = 3 --> statusB0ToDPi = 15 if usePionPid, 11 otherwise
233+
if (activateQA) {
234+
registry.fill(HIST("hSelections"), 2 + SelectionStep::RecoMl, ptCandB0);
235+
}
236+
}
237+
178238
hfSelB0ToDPiCandidate(statusB0ToDPi);
179239
// LOGF(info, "B0 candidate selection passed all selections");
180240
}

0 commit comments

Comments
 (0)