Skip to content

Commit d1699c9

Browse files
author
Luigi Dello Stritto
committed
Saving ML score in LcpK0s tree creator
1 parent be3b919 commit d1699c9

File tree

3 files changed

+122
-54
lines changed

3 files changed

+122
-54
lines changed

PWGHF/DataModel/CandidateSelectionTables.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,12 @@ DECLARE_SOA_TABLE(HfSelJpsi, "AOD", "HFSELJPSI", //!
231231
namespace hf_sel_candidate_lc_to_k0s_p
232232
{
233233
DECLARE_SOA_COLUMN(IsSelLcToK0sP, isSelLcToK0sP, int);
234+
DECLARE_SOA_COLUMN(MlProbLcToK0sP, mlProbLcToK0sP, std::vector<float>); //!
234235
} // namespace hf_sel_candidate_lc_to_k0s_p
235-
236236
DECLARE_SOA_TABLE(HfSelLcToK0sP, "AOD", "HFSELLCK0SP", //!
237237
hf_sel_candidate_lc_to_k0s_p::IsSelLcToK0sP);
238+
DECLARE_SOA_TABLE(HfMlLcToK0sP, "AOD", "HFMLLcK0sP", //!
239+
hf_sel_candidate_lc_to_k0s_p::MlProbLcToK0sP);
238240

239241
namespace hf_sel_candidate_b0
240242
{

PWGHF/TableProducer/candidateSelectorLcToK0sP.cxx

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ using namespace o2::framework;
5454

5555
struct HfCandidateSelectorLcToK0sP {
5656
Produces<aod::HfSelLcToK0sP> hfSelLcToK0sPCandidate;
57+
Produces<aod::HfMlLcToK0sP> hfMlLcToK0sPCandidate;
5758

5859
Configurable<double> ptCandMin{"ptCandMin", 0., "Lower bound of candidate pT"};
5960
Configurable<double> ptCandMax{"ptCandMax", 50., "Upper bound of candidate pT"};
@@ -95,6 +96,7 @@ struct HfCandidateSelectorLcToK0sP {
9596
TrackSelectorPr selectorProtonHighP;
9697

9798
o2::analysis::HfMlResponseLcToK0sP<float> hfMlResponse;
99+
std::vector<float> outputMl = {};
98100

99101
o2::ccdb::CcdbApi ccdbApi;
100102

@@ -239,12 +241,11 @@ struct HfCandidateSelectorLcToK0sP {
239241
}
240242

241243
template <typename T, typename U>
242-
bool selectionMl(const T& hfCandCascade, const U& bach)
244+
bool selectionMl(const T& hfCandCascade, const U& bach, std::vector<float>& outputMl)
243245
{
244246

245247
auto ptCand = hfCandCascade.pt();
246248
std::vector<float> inputFeatures = hfMlResponse.getInputFeatures(hfCandCascade, bach);
247-
std::vector<float> outputMl = {};
248249

249250
bool isSelectedMl = hfMlResponse.isSelectedMl(inputFeatures, ptCand, outputMl);
250251

@@ -265,26 +266,39 @@ struct HfCandidateSelectorLcToK0sP {
265266
const auto& bach = candidate.prong0_as<TracksSel>(); // bachelor track
266267

267268
statusLc = 0;
269+
outputMl.clear();
268270

269271
// implement filter bit 4 cut - should be done before this task at the track selection level
270272
// need to add special cuts (additional cuts on decay length and d0 norm)
271273
if (!selectionTopol(candidate)) {
272274
hfSelLcToK0sPCandidate(statusLc);
275+
if (applyMl) {
276+
hfMlLcToK0sPCandidate(outputMl);
277+
}
273278
continue;
274279
}
275280

276281
if (!selectionStandardPID(bach)) {
277282
hfSelLcToK0sPCandidate(statusLc);
283+
if (applyMl) {
284+
hfMlLcToK0sPCandidate(outputMl);
285+
}
278286
continue;
279287
}
280288

281-
if (applyMl && !selectionMl(candidate, bach)) {
282-
hfSelLcToK0sPCandidate(statusLc);
283-
continue;
289+
bool isSelectedMlLcToK0sP = true;
290+
if (applyMl) {
291+
isSelectedMlLcToK0sP = false;
292+
isSelectedMlLcToK0sP = selectionMl(candidate, bach, outputMl);
293+
hfMlLcToK0sPCandidate(outputMl);
294+
295+
if (!isSelectedMlLcToK0sP) {
296+
hfSelLcToK0sPCandidate(statusLc);
297+
continue;
298+
}
284299
}
285300

286301
statusLc = 1;
287-
288302
hfSelLcToK0sPCandidate(statusLc);
289303
}
290304
}
@@ -299,24 +313,37 @@ struct HfCandidateSelectorLcToK0sP {
299313
const auto& bach = candidate.prong0_as<TracksSelBayes>(); // bachelor track
300314

301315
statusLc = 0;
316+
outputMl.clear();
302317

303318
if (!selectionTopol(candidate)) {
304319
hfSelLcToK0sPCandidate(statusLc);
320+
if (applyMl) {
321+
hfMlLcToK0sPCandidate(outputMl);
322+
}
305323
continue;
306324
}
307325

308326
if (!selectionBayesPID(bach)) {
309327
hfSelLcToK0sPCandidate(statusLc);
328+
if (applyMl) {
329+
hfMlLcToK0sPCandidate(outputMl);
330+
}
310331
continue;
311332
}
312333

313-
if (applyMl && !selectionMl(candidate, bach)) {
314-
hfSelLcToK0sPCandidate(statusLc);
315-
continue;
334+
bool isSelectedMlLcToK0sP = true;
335+
if (applyMl) {
336+
isSelectedMlLcToK0sP = false;
337+
isSelectedMlLcToK0sP = selectionMl(candidate, bach, outputMl);
338+
hfMlLcToK0sPCandidate(outputMl);
339+
340+
if (!isSelectedMlLcToK0sP) {
341+
hfSelLcToK0sPCandidate(statusLc);
342+
continue;
343+
}
316344
}
317345

318346
statusLc = 1;
319-
320347
hfSelLcToK0sPCandidate(statusLc);
321348
}
322349
}

PWGHF/TableProducer/treeCreatorLcToK0sP.cxx

Lines changed: 82 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@
3333
#include <Framework/InitContext.h>
3434
#include <Framework/runDataProcessing.h>
3535

36+
#include <algorithm>
3637
#include <cstdint>
3738
#include <cstdlib>
39+
#include <vector>
3840

3941
using namespace o2;
4042
using namespace o2::framework;
@@ -84,6 +86,9 @@ DECLARE_SOA_COLUMN(V0CtLambda, v0CtLambda, float);
8486
DECLARE_SOA_COLUMN(FlagMc, flagMc, int8_t);
8587
DECLARE_SOA_COLUMN(OriginMcRec, originMcRec, int8_t);
8688
DECLARE_SOA_COLUMN(OriginMcGen, originMcGen, int8_t);
89+
DECLARE_SOA_COLUMN(MlScoreFirstClass, mlScoreFirstClass, float);
90+
DECLARE_SOA_COLUMN(MlScoreSecondClass, mlScoreSecondClass, float);
91+
DECLARE_SOA_COLUMN(MlScoreThirdClass, mlScoreThirdClass, float);
8792
// Events
8893
DECLARE_SOA_COLUMN(IsEventReject, isEventReject, int);
8994
DECLARE_SOA_COLUMN(RunNumber, runNumber, int);
@@ -126,7 +131,10 @@ DECLARE_SOA_TABLE(HfCandCascLites, "AOD", "HFCANDCASCLITE",
126131
full::Y,
127132
full::E,
128133
full::FlagMc,
129-
full::OriginMcRec);
134+
full::OriginMcRec,
135+
full::MlScoreFirstClass,
136+
full::MlScoreSecondClass,
137+
full::MlScoreThirdClass);
130138

131139
DECLARE_SOA_TABLE(HfCandCascFulls, "AOD", "HFCANDCASCFULL",
132140
collision::BCId,
@@ -196,7 +204,10 @@ DECLARE_SOA_TABLE(HfCandCascFulls, "AOD", "HFCANDCASCFULL",
196204
full::Y,
197205
full::E,
198206
full::FlagMc,
199-
full::OriginMcRec);
207+
full::OriginMcRec,
208+
full::MlScoreFirstClass,
209+
full::MlScoreSecondClass,
210+
full::MlScoreThirdClass);
200211

201212
DECLARE_SOA_TABLE(HfCandCascFullEs, "AOD", "HFCANDCASCFULLE",
202213
collision::BCId,
@@ -228,23 +239,56 @@ struct HfTreeCreatorLcToK0sP {
228239
Configurable<float> ptMaxForDownSample{"ptMaxForDownSample", 24., "Maximum pt for the application of the downsampling factor"};
229240
Configurable<bool> fillOnlySignal{"fillOnlySignal", false, "Flag to fill derived tables with signal for ML trainings"};
230241
Configurable<bool> fillOnlyBackground{"fillOnlyBackground", false, "Flag to fill derived tables with background for ML trainings"};
242+
Configurable<bool> applyMl{"applyMl", false, "Whether ML was used in candidateSelectorLc"};
243+
244+
constexpr static float UndefValueFloat = -999.f;
231245

232246
HfHelper hfHelper;
233247

234248
Filter filterSelectCandidates = aod::hf_sel_candidate_lc_to_k0s_p::isSelLcToK0sP >= 1;
235249
using TracksWPid = soa::Join<aod::Tracks, aod::TracksPidPr>;
236250
using SelectedCandidatesMc = soa::Filtered<soa::Join<aod::HfCandCascade, aod::HfCandCascadeMcRec, aod::HfSelLcToK0sP>>;
237251

238-
Partition<SelectedCandidatesMc> recSig = nabs(aod::hf_cand_casc::flagMcMatchRec) != int8_t(0);
239-
Partition<SelectedCandidatesMc> recBkg = nabs(aod::hf_cand_casc::flagMcMatchRec) == int8_t(0);
240-
241252
void init(InitContext const&)
242253
{
243254
}
244255

256+
/// \brief function to get ML score values for the current candidate and assign them to input parameters
257+
/// \param candidate candidate instance
258+
/// \param candidateMlScore instance of handler of vectors with ML scores associated with the current candidate
259+
/// \param mlScoreFirstClass ML score for belonging to the first class
260+
/// \param mlScoreSecondClass ML score for belonging to the second class
261+
/// \param mlScoreThirdClass ML score for belonging to the third class
262+
void assignMlScores(aod::HfMlLcToK0sP::iterator const& candidateMlScore, float& mlScoreFirstClass, float& mlScoreSecondClass, float& mlScoreThirdClass)
263+
{
264+
std::vector<float> mlScores;
265+
std::copy(candidateMlScore.mlProbLcToK0sP().begin(), candidateMlScore.mlProbLcToK0sP().end(), std::back_inserter(mlScores));
266+
267+
constexpr int IndexFirstClass{0};
268+
constexpr int IndexSecondClass{1};
269+
constexpr int IndexThirdClass{2};
270+
if (mlScores.size() == 0) {
271+
return; // when candidateSelectorLcK0sP rejects a candidate by "usual", non-ML cut, the ml score vector remains empty
272+
}
273+
mlScoreFirstClass = mlScores.at(IndexFirstClass);
274+
mlScoreSecondClass = mlScores.at(IndexSecondClass);
275+
if (mlScores.size() > IndexThirdClass) {
276+
mlScoreThirdClass = mlScores.at(IndexThirdClass);
277+
}
278+
}
279+
245280
template <typename T, typename U>
246-
void fillCandidate(const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec)
281+
void fillCandidate(const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec, aod::HfMlLcToK0sP::iterator const& candidateMlScore)
247282
{
283+
284+
float mlScoreFirstClass{UndefValueFloat};
285+
float mlScoreSecondClass{UndefValueFloat};
286+
float mlScoreThirdClass{UndefValueFloat};
287+
288+
if (applyMl) {
289+
assignMlScores(candidateMlScore, mlScoreFirstClass, mlScoreSecondClass, mlScoreThirdClass);
290+
}
291+
248292
if (fillCandidateLiteTable) {
249293
rowCandidateLite(
250294
candidate.chi2PCA(),
@@ -283,7 +327,10 @@ struct HfTreeCreatorLcToK0sP {
283327
hfHelper.yLc(candidate),
284328
hfHelper.eLc(candidate),
285329
flagMc,
286-
originMcRec);
330+
originMcRec,
331+
mlScoreFirstClass,
332+
mlScoreSecondClass,
333+
mlScoreThirdClass);
287334
} else {
288335
rowCandidateFull(
289336
bach.collision().bcId(),
@@ -353,7 +400,10 @@ struct HfTreeCreatorLcToK0sP {
353400
hfHelper.yLc(candidate),
354401
hfHelper.eLc(candidate),
355402
flagMc,
356-
originMcRec);
403+
originMcRec,
404+
mlScoreFirstClass,
405+
mlScoreSecondClass,
406+
mlScoreThirdClass);
357407
}
358408
}
359409
template <typename T>
@@ -370,6 +420,7 @@ struct HfTreeCreatorLcToK0sP {
370420
void processMc(aod::Collisions const& collisions,
371421
aod::McCollisions const&,
372422
SelectedCandidatesMc const& candidates,
423+
aod::HfMlLcToK0sP const& candidateMlScores,
373424
soa::Join<aod::McParticles, aod::HfCandCascadeMcGen> const& particles,
374425
TracksWPid const&)
375426
{
@@ -380,42 +431,25 @@ struct HfTreeCreatorLcToK0sP {
380431
fillEvent(collision);
381432
}
382433

383-
if (fillOnlySignal) {
384-
if (fillCandidateLiteTable) {
385-
rowCandidateLite.reserve(recSig.size());
386-
} else {
387-
rowCandidateFull.reserve(recSig.size());
388-
}
389-
for (const auto& candidate : recSig) {
390-
auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
391-
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec());
392-
}
393-
} else if (fillOnlyBackground) {
394-
if (fillCandidateLiteTable) {
395-
rowCandidateLite.reserve(recBkg.size());
396-
} else {
397-
rowCandidateFull.reserve(recBkg.size());
398-
}
399-
for (const auto& candidate : recBkg) {
400-
if (downSampleBkgFactor < 1.) {
401-
float pseudoRndm = candidate.ptProng0() * 1000. - static_cast<int64_t>(candidate.ptProng0() * 1000);
402-
if (candidate.pt() < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) {
403-
continue;
404-
}
405-
}
406-
auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
407-
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec());
408-
}
434+
if (fillCandidateLiteTable) {
435+
rowCandidateLite.reserve(candidates.size());
409436
} else {
410-
// Filling candidate properties
411-
if (fillCandidateLiteTable) {
412-
rowCandidateLite.reserve(candidates.size());
437+
rowCandidateFull.reserve(candidates.size());
438+
}
439+
440+
int iCand{0};
441+
for (const auto& candidate : candidates) {
442+
auto candidateMlScore = candidateMlScores.rawIteratorAt(iCand);
443+
++iCand;
444+
auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
445+
const int flag = candidate.flagMcMatchRec();
446+
447+
if (fillOnlySignal && flag != 0) {
448+
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore);
449+
} else if (fillOnlyBackground && flag == 0) {
450+
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore);
413451
} else {
414-
rowCandidateFull.reserve(candidates.size());
415-
}
416-
for (const auto& candidate : candidates) {
417-
auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
418-
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec());
452+
fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore);
419453
}
420454
}
421455

@@ -439,6 +473,7 @@ struct HfTreeCreatorLcToK0sP {
439473

440474
void processData(aod::Collisions const& collisions,
441475
soa::Join<aod::HfCandCascade, aod::HfSelLcToK0sP> const& candidates,
476+
aod::HfMlLcToK0sP const& candidateMlScores,
442477
TracksWPid const&)
443478
{
444479

@@ -454,11 +489,15 @@ struct HfTreeCreatorLcToK0sP {
454489
} else {
455490
rowCandidateFull.reserve(candidates.size());
456491
}
492+
493+
int iCand{0};
457494
for (const auto& candidate : candidates) {
495+
auto candidateMlScore = candidateMlScores.rawIteratorAt(iCand);
496+
++iCand;
458497
auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
459498
double pseudoRndm = bach.pt() * 1000. - static_cast<int16_t>(bach.pt() * 1000);
460499
if (candidate.isSelLcToK0sP() >= 1 && pseudoRndm < downSampleBkgFactor) {
461-
fillCandidate(candidate, bach, 0, 0);
500+
fillCandidate(candidate, bach, 0, 0, candidateMlScore);
462501
}
463502
}
464503
}

0 commit comments

Comments
 (0)