Skip to content

Commit df88af4

Browse files
DelloStritto (Luigi Dello Stritto), vkucera
authored
[PWGHF] Save ML scores in LcToK0sP tree creator (#12331)
Co-authored-by: Luigi Dello Stritto <ldellost@alicecerno2.cern.ch> Co-authored-by: Vít Kučera <vit.kucera@cern.ch>
1 parent 27fde20 commit df88af4

File tree

3 files changed

+135
-61
lines changed

3 files changed

+135
-61
lines changed

PWGHF/DataModel/CandidateSelectionTables.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,12 @@ DECLARE_SOA_TABLE(HfSelJpsi, "AOD", "HFSELJPSI", //!
231231
namespace hf_sel_candidate_lc_to_k0s_p
232232
{
233233
DECLARE_SOA_COLUMN(IsSelLcToK0sP, isSelLcToK0sP, int);
234+
DECLARE_SOA_COLUMN(MlProbLcToK0sP, mlProbLcToK0sP, std::vector<float>); //!
234235
} // namespace hf_sel_candidate_lc_to_k0s_p
235-
236236
DECLARE_SOA_TABLE(HfSelLcToK0sP, "AOD", "HFSELLCK0SP", //!
237237
hf_sel_candidate_lc_to_k0s_p::IsSelLcToK0sP);
238+
DECLARE_SOA_TABLE(HfMlLcToK0sP, "AOD", "HFMLLcK0sP", //!
239+
hf_sel_candidate_lc_to_k0s_p::MlProbLcToK0sP);
238240

239241
namespace hf_sel_candidate_b0
240242
{

PWGHF/TableProducer/candidateSelectorLcToK0sP.cxx

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ using namespace o2::framework;
5454

5555
struct HfCandidateSelectorLcToK0sP {
5656
Produces<aod::HfSelLcToK0sP> hfSelLcToK0sPCandidate;
57+
Produces<aod::HfMlLcToK0sP> hfMlLcToK0sPCandidate;
5758

5859
Configurable<double> ptCandMin{"ptCandMin", 0., "Lower bound of candidate pT"};
5960
Configurable<double> ptCandMax{"ptCandMax", 50., "Upper bound of candidate pT"};
@@ -95,6 +96,7 @@ struct HfCandidateSelectorLcToK0sP {
9596
TrackSelectorPr selectorProtonHighP;
9697

9798
o2::analysis::HfMlResponseLcToK0sP<float> hfMlResponse;
99+
std::vector<float> outputMl = {};
98100

99101
o2::ccdb::CcdbApi ccdbApi;
100102

@@ -239,12 +241,11 @@ struct HfCandidateSelectorLcToK0sP {
239241
}
240242

241243
template <typename T, typename U>
242-
bool selectionMl(const T& hfCandCascade, const U& bach)
244+
bool selectionMl(const T& hfCandCascade, const U& bach, std::vector<float>& outputMl)
243245
{
244246

245247
auto ptCand = hfCandCascade.pt();
246248
std::vector<float> inputFeatures = hfMlResponse.getInputFeatures(hfCandCascade, bach);
247-
std::vector<float> outputMl = {};
248249

249250
bool isSelectedMl = hfMlResponse.isSelectedMl(inputFeatures, ptCand, outputMl);
250251

@@ -265,26 +266,37 @@ struct HfCandidateSelectorLcToK0sP {
265266
const auto& bach = candidate.prong0_as<TracksSel>(); // bachelor track
266267

267268
statusLc = 0;
269+
outputMl.clear();
268270

269271
// implement filter bit 4 cut - should be done before this task at the track selection level
270272
// need to add special cuts (additional cuts on decay length and d0 norm)
271273
if (!selectionTopol(candidate)) {
272274
hfSelLcToK0sPCandidate(statusLc);
275+
if (applyMl) {
276+
hfMlLcToK0sPCandidate(outputMl);
277+
}
273278
continue;
274279
}
275280

276281
if (!selectionStandardPID(bach)) {
277282
hfSelLcToK0sPCandidate(statusLc);
283+
if (applyMl) {
284+
hfMlLcToK0sPCandidate(outputMl);
285+
}
278286
continue;
279287
}
280288

281-
if (applyMl && !selectionMl(candidate, bach)) {
282-
hfSelLcToK0sPCandidate(statusLc);
283-
continue;
289+
if (applyMl) {
290+
bool isSelectedMlLcToK0sP = selectionMl(candidate, bach, outputMl);
291+
hfMlLcToK0sPCandidate(outputMl);
292+
293+
if (!isSelectedMlLcToK0sP) {
294+
hfSelLcToK0sPCandidate(statusLc);
295+
continue;
296+
}
284297
}
285298

286299
statusLc = 1;
287-
288300
hfSelLcToK0sPCandidate(statusLc);
289301
}
290302
}
@@ -299,24 +311,35 @@ struct HfCandidateSelectorLcToK0sP {
299311
const auto& bach = candidate.prong0_as<TracksSelBayes>(); // bachelor track
300312

301313
statusLc = 0;
314+
outputMl.clear();
302315

303316
if (!selectionTopol(candidate)) {
304317
hfSelLcToK0sPCandidate(statusLc);
318+
if (applyMl) {
319+
hfMlLcToK0sPCandidate(outputMl);
320+
}
305321
continue;
306322
}
307323

308324
if (!selectionBayesPID(bach)) {
309325
hfSelLcToK0sPCandidate(statusLc);
326+
if (applyMl) {
327+
hfMlLcToK0sPCandidate(outputMl);
328+
}
310329
continue;
311330
}
312331

313-
if (applyMl && !selectionMl(candidate, bach)) {
314-
hfSelLcToK0sPCandidate(statusLc);
315-
continue;
332+
if (applyMl) {
333+
bool isSelectedMlLcToK0sP = selectionMl(candidate, bach, outputMl);
334+
hfMlLcToK0sPCandidate(outputMl);
335+
336+
if (!isSelectedMlLcToK0sP) {
337+
hfSelLcToK0sPCandidate(statusLc);
338+
continue;
339+
}
316340
}
317341

318342
statusLc = 1;
319-
320343
hfSelLcToK0sPCandidate(statusLc);
321344
}
322345
}

PWGHF/TableProducer/treeCreatorLcToK0sP.cxx

Lines changed: 99 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@
3333
#include <Framework/InitContext.h>
3434
#include <Framework/runDataProcessing.h>
3535

36+
#include <algorithm>
3637
#include <cstdint>
3738
#include <cstdlib>
39+
#include <vector>
3840

3941
using namespace o2;
4042
using namespace o2::framework;
@@ -68,8 +70,8 @@ DECLARE_SOA_COLUMN(DecayLength, decayLength, float);
6870
DECLARE_SOA_COLUMN(DecayLengthXY, decayLengthXY, float);
6971
DECLARE_SOA_COLUMN(DecayLengthNormalised, decayLengthNormalised, float);
7072
DECLARE_SOA_COLUMN(DecayLengthXYNormalised, decayLengthXYNormalised, float);
71-
DECLARE_SOA_COLUMN(CPA, cpa, float);
72-
DECLARE_SOA_COLUMN(CPAXY, cpaXY, float);
73+
DECLARE_SOA_COLUMN(Cpa, cpa, float);
74+
DECLARE_SOA_COLUMN(CpaXY, cpaXY, float);
7375
DECLARE_SOA_COLUMN(Ct, ct, float);
7476
DECLARE_SOA_COLUMN(PtV0Pos, ptV0Pos, float);
7577
DECLARE_SOA_COLUMN(PtV0Neg, ptV0Neg, float);
@@ -84,6 +86,9 @@ DECLARE_SOA_COLUMN(V0CtLambda, v0CtLambda, float);
8486
DECLARE_SOA_COLUMN(FlagMc, flagMc, int8_t);
8587
DECLARE_SOA_COLUMN(OriginMcRec, originMcRec, int8_t);
8688
DECLARE_SOA_COLUMN(OriginMcGen, originMcGen, int8_t);
89+
DECLARE_SOA_COLUMN(MlScoreFirstClass, mlScoreFirstClass, float);
90+
DECLARE_SOA_COLUMN(MlScoreSecondClass, mlScoreSecondClass, float);
91+
DECLARE_SOA_COLUMN(MlScoreThirdClass, mlScoreThirdClass, float);
8792
// Events
8893
DECLARE_SOA_COLUMN(IsEventReject, isEventReject, int);
8994
DECLARE_SOA_COLUMN(RunNumber, runNumber, int);
@@ -118,15 +123,18 @@ DECLARE_SOA_TABLE(HfCandCascLites, "AOD", "HFCANDCASCLITE",
118123
full::NSigmaTOFPr0,
119124
full::M,
120125
full::Pt,
121-
full::CPA,
122-
full::CPAXY,
126+
full::Cpa,
127+
full::CpaXY,
123128
full::Ct,
124129
full::Eta,
125130
full::Phi,
126131
full::Y,
127132
full::E,
128133
full::FlagMc,
129-
full::OriginMcRec);
134+
full::OriginMcRec,
135+
full::MlScoreFirstClass,
136+
full::MlScoreSecondClass,
137+
full::MlScoreThirdClass);
130138

131139
DECLARE_SOA_TABLE(HfCandCascFulls, "AOD", "HFCANDCASCFULL",
132140
collision::BCId,
@@ -188,15 +196,18 @@ DECLARE_SOA_TABLE(HfCandCascFulls, "AOD", "HFCANDCASCFULL",
188196
full::M,
189197
full::Pt,
190198
full::P,
191-
full::CPA,
192-
full::CPAXY,
199+
full::Cpa,
200+
full::CpaXY,
193201
full::Ct,
194202
full::Eta,
195203
full::Phi,
196204
full::Y,
197205
full::E,
198206
full::FlagMc,
199-
full::OriginMcRec);
207+
full::OriginMcRec,
208+
full::MlScoreFirstClass,
209+
full::MlScoreSecondClass,
210+
full::MlScoreThirdClass);
200211

201212
DECLARE_SOA_TABLE(HfCandCascFullEs, "AOD", "HFCANDCASCFULLE",
202213
collision::BCId,
@@ -228,23 +239,56 @@ struct HfTreeCreatorLcToK0sP {
228239
Configurable<float> ptMaxForDownSample{"ptMaxForDownSample", 24., "Maximum pt for the application of the downsampling factor"};
229240
Configurable<bool> fillOnlySignal{"fillOnlySignal", false, "Flag to fill derived tables with signal for ML trainings"};
230241
Configurable<bool> fillOnlyBackground{"fillOnlyBackground", false, "Flag to fill derived tables with background for ML trainings"};
242+
Configurable<bool> applyMl{"applyMl", false, "Whether ML was used in candidateSelectorLc"};
243+
244+
constexpr static float UndefValueFloat = -999.f;
231245

232246
HfHelper hfHelper;
233247

234-
Filter filterSelectCandidates = aod::hf_sel_candidate_lc_to_k0s_p::isSelLcToK0sP >= 1;
235248
using TracksWPid = soa::Join<aod::Tracks, aod::TracksPidPr>;
236249
using SelectedCandidatesMc = soa::Filtered<soa::Join<aod::HfCandCascade, aod::HfCandCascadeMcRec, aod::HfSelLcToK0sP>>;
237-
238-
Partition<SelectedCandidatesMc> recSig = nabs(aod::hf_cand_casc::flagMcMatchRec) != int8_t(0);
239-
Partition<SelectedCandidatesMc> recBkg = nabs(aod::hf_cand_casc::flagMcMatchRec) == int8_t(0);
250+
Filter filterSelectCandidates = aod::hf_sel_candidate_lc_to_k0s_p::isSelLcToK0sP >= 1;
240251

241252
void init(InitContext const&)
242253
{
243254
}
244255

256+
/// \brief function to get ML score values for the current candidate and assign them to input parameters
257+
/// \param candidate candidate instance
258+
/// \param candidateMlScore instance of handler of vectors with ML scores associated with the current candidate
259+
/// \param mlScoreFirstClass ML score for belonging to the first class
260+
/// \param mlScoreSecondClass ML score for belonging to the second class
261+
/// \param mlScoreThirdClass ML score for belonging to the third class
262+
void assignMlScores(aod::HfMlLcToK0sP::iterator const& candidateMlScore, float& mlScoreFirstClass, float& mlScoreSecondClass, float& mlScoreThirdClass)
263+
{
264+
std::vector<float> mlScores;
265+
std::copy(candidateMlScore.mlProbLcToK0sP().begin(), candidateMlScore.mlProbLcToK0sP().end(), std::back_inserter(mlScores));
266+
267+
constexpr int IndexFirstClass{0};
268+
constexpr int IndexSecondClass{1};
269+
constexpr int IndexThirdClass{2};
270+
if (mlScores.size() == 0) {
271+
return; // when candidateSelectorLcK0sP rejects a candidate by "usual", non-ML cut, the ml score vector remains empty
272+
}
273+
mlScoreFirstClass = mlScores.at(IndexFirstClass);
274+
mlScoreSecondClass = mlScores.at(IndexSecondClass);
275+
if (mlScores.size() > IndexThirdClass) {
276+
mlScoreThirdClass = mlScores.at(IndexThirdClass);
277+
}
278+
}
279+
245280
template <typename T, typename U>
246-
void fillCandidate(const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec)
281+
void fillCandidate(const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec, aod::HfMlLcToK0sP::iterator const& candidateMlScore)
247282
{
283+
284+
float mlScoreFirstClass{UndefValueFloat};
285+
float mlScoreSecondClass{UndefValueFloat};
286+
float mlScoreThirdClass{UndefValueFloat};
287+
288+
if (applyMl) {
289+
assignMlScores(candidateMlScore, mlScoreFirstClass, mlScoreSecondClass, mlScoreThirdClass);
290+
}
291+
248292
if (fillCandidateLiteTable) {
249293
rowCandidateLite(
250294
candidate.chi2PCA(),
@@ -283,7 +327,10 @@ struct HfTreeCreatorLcToK0sP {
283327
hfHelper.yLc(candidate),
284328
hfHelper.eLc(candidate),
285329
flagMc,
286-
originMcRec);
330+
originMcRec,
331+
mlScoreFirstClass,
332+
mlScoreSecondClass,
333+
mlScoreThirdClass);
287334
} else {
288335
rowCandidateFull(
289336
bach.collision().bcId(),
@@ -353,7 +400,10 @@ struct HfTreeCreatorLcToK0sP {
353400
hfHelper.yLc(candidate),
354401
hfHelper.eLc(candidate),
355402
flagMc,
356-
originMcRec);
403+
originMcRec,
404+
mlScoreFirstClass,
405+
mlScoreSecondClass,
406+
mlScoreThirdClass);
357407
}
358408
}
359409
template <typename T>
@@ -370,52 +420,41 @@ struct HfTreeCreatorLcToK0sP {
370420
void processMc(aod::Collisions const& collisions,
371421
aod::McCollisions const&,
372422
SelectedCandidatesMc const& candidates,
423+
aod::HfMlLcToK0sP const& candidateMlScores,
373424
soa::Join<aod::McParticles, aod::HfCandCascadeMcGen> const& particles,
374425
TracksWPid const&)
375426
{
376427

428+
if (applyMl && candidateMlScores.size() == 0) {
429+
LOG(fatal) << "ML enabled but table with the ML scores is empty! Please check your configurables.";
430+
return;
431+
}
432+
377433
// Filling event properties
378434
rowCandidateFullEvents.reserve(collisions.size());
379435
for (const auto& collision : collisions) {
380436
fillEvent(collision);
381437
}
382438

383-
if (fillOnlySignal) {
384-
if (fillCandidateLiteTable) {
385-
rowCandidateLite.reserve(recSig.size());
386-
} else {
387-
rowCandidateFull.reserve(recSig.size());
388-
}
389-
for (const auto& candidate : recSig) {
390-
auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
391-
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec());
392-
}
393-
} else if (fillOnlyBackground) {
394-
if (fillCandidateLiteTable) {
395-
rowCandidateLite.reserve(recBkg.size());
396-
} else {
397-
rowCandidateFull.reserve(recBkg.size());
398-
}
399-
for (const auto& candidate : recBkg) {
400-
if (downSampleBkgFactor < 1.) {
401-
float pseudoRndm = candidate.ptProng0() * 1000. - static_cast<int64_t>(candidate.ptProng0() * 1000);
402-
if (candidate.pt() < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) {
403-
continue;
404-
}
405-
}
406-
auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
407-
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec());
408-
}
439+
if (fillCandidateLiteTable) {
440+
rowCandidateLite.reserve(candidates.size());
409441
} else {
410-
// Filling candidate properties
411-
if (fillCandidateLiteTable) {
412-
rowCandidateLite.reserve(candidates.size());
442+
rowCandidateFull.reserve(candidates.size());
443+
}
444+
445+
int iCand{0};
446+
for (const auto& candidate : candidates) {
447+
auto candidateMlScore = candidateMlScores.rawIteratorAt(iCand);
448+
++iCand;
449+
auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
450+
const int flag = candidate.flagMcMatchRec();
451+
452+
if (fillOnlySignal && flag != 0) {
453+
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore);
454+
} else if (fillOnlyBackground && flag == 0) {
455+
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore);
413456
} else {
414-
rowCandidateFull.reserve(candidates.size());
415-
}
416-
for (const auto& candidate : candidates) {
417-
auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
418-
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec());
457+
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore);
419458
}
420459
}
421460

@@ -439,9 +478,15 @@ struct HfTreeCreatorLcToK0sP {
439478

440479
void processData(aod::Collisions const& collisions,
441480
soa::Join<aod::HfCandCascade, aod::HfSelLcToK0sP> const& candidates,
481+
aod::HfMlLcToK0sP const& candidateMlScores,
442482
TracksWPid const&)
443483
{
444484

485+
if (applyMl && candidateMlScores.size() == 0) {
486+
LOG(fatal) << "ML enabled but table with the ML scores is empty! Please check your configurables.";
487+
return;
488+
}
489+
445490
// Filling event properties
446491
rowCandidateFullEvents.reserve(collisions.size());
447492
for (const auto& collision : collisions) {
@@ -454,11 +499,15 @@ struct HfTreeCreatorLcToK0sP {
454499
} else {
455500
rowCandidateFull.reserve(candidates.size());
456501
}
502+
503+
int iCand{0};
457504
for (const auto& candidate : candidates) {
505+
auto candidateMlScore = candidateMlScores.rawIteratorAt(iCand);
506+
++iCand;
458507
auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
459508
double pseudoRndm = bach.pt() * 1000. - static_cast<int16_t>(bach.pt() * 1000);
460509
if (candidate.isSelLcToK0sP() >= 1 && pseudoRndm < downSampleBkgFactor) {
461-
fillCandidate(candidate, bach, 0, 0);
510+
fillCandidate(candidate, bach, 0, 0, candidateMlScore);
462511
}
463512
}
464513
}

0 commit comments

Comments
 (0)