Skip to content

Commit eb0c9f3

Browse files
jikim1290alibuild
andauthored
[PWGCF] adding ML selection in the correlations task (#9968)
Co-authored-by: ALICE Action Bot <alibuild@cern.ch>
1 parent 20f1e56 commit eb0c9f3

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

PWGCF/Tasks/correlations.cxx

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ struct CorrelationTask {
9696
O2_DEFINE_CONFIGURABLE(cfgMassAxis, int, 0, "Use invariant mass axis (0 = OFF, 1 = ON)")
9797
O2_DEFINE_CONFIGURABLE(cfgMcTriggerPDGs, std::vector<int>, {}, "MC PDG codes to use exclusively as trigger particles and exclude from associated particles. Empty = no selection.")
9898

99+
O2_DEFINE_CONFIGURABLE(cfgPtDepMLbkg, std::vector<float>, {}, "pT interval for ML training")
100+
O2_DEFINE_CONFIGURABLE(cfgPtCentDepMLbkgSel, std::vector<float>, {}, "Bkg ML selection")
101+
99102
ConfigurableAxis axisVertex{"axisVertex", {7, -7, 7}, "vertex axis for histograms"};
100103
ConfigurableAxis axisDeltaPhi{"axisDeltaPhi", {72, -PIHalf, PIHalf * 3}, "delta phi axis for histograms"};
101104
ConfigurableAxis axisDeltaEta{"axisDeltaEta", {40, -2, 2}, "delta eta axis for histograms"};
@@ -324,6 +327,8 @@ struct CorrelationTask {
324327
using HasProng0Id = decltype(std::declval<T&>().cfTrackProng0Id());
325328
template <class T>
326329
using HasProng1Id = decltype(std::declval<T&>().cfTrackProng1Id());
330+
template <class T>
331+
using HasMlProbD0 = decltype(std::declval<T&>().mlProbD0());
327332

328333
template <CorrelationContainer::CFStep step, typename TTarget, typename TTracks1, typename TTracks2>
329334
void fillCorrelations(TTarget target, TTracks1& tracks1, TTracks2& tracks2, float multiplicity, float posZ, int magField, float eventWeight)
@@ -371,6 +376,18 @@ struct CorrelationTask {
371376
}
372377
}
373378

379+
if constexpr (std::experimental::is_detected<HasMlProbD0, typename TTracks1::iterator>::value) {
380+
if (doprocessSame2Prong2ProngML || doprocessMixed2Prong2ProngML) {
381+
auto it = std::lower_bound(cfgPtDepMLbkg->begin(), cfgPtDepMLbkg->end(), track1.pt());
382+
int idx = std::distance(cfgPtDepMLbkg->begin(), it) - 1;
383+
if (track1.decay() == 0 && track1.mlProbD0()[0] > cfgPtCentDepMLbkgSel->at(idx)) {
384+
continue;
385+
} else if (track1.decay() == 1 && track1.mlProbD0bar()[0] > cfgPtCentDepMLbkgSel->at(idx)) {
386+
continue;
387+
}
388+
}
389+
}
390+
374391
if (cfgMassAxis) {
375392
if constexpr (std::experimental::is_detected<HasInvMass, typename TTracks1::iterator>::value)
376393
target->getTriggerHist()->Fill(step, track1.pt(), multiplicity, posZ, track1.invMass(), triggerWeight);
@@ -484,6 +501,18 @@ struct CorrelationTask {
484501

485502
float deltaPhi = RecoDecay::constrainAngle(track1.phi() - track2.phi(), -o2::constants::math::PIHalf);
486503

504+
if constexpr (std::experimental::is_detected<HasMlProbD0, typename TTracks2::iterator>::value) {
505+
if (doprocessSame2Prong2ProngML || doprocessMixed2Prong2ProngML) {
506+
auto it = std::lower_bound(cfgPtDepMLbkg->begin(), cfgPtDepMLbkg->end(), track2.pt());
507+
int idx = std::distance(cfgPtDepMLbkg->begin(), it) - 1;
508+
if (track2.decay() == 0 && track2.mlProbD0()[0] > cfgPtCentDepMLbkgSel->at(idx)) {
509+
continue;
510+
} else if (track2.decay() == 1 && track2.mlProbD0bar()[0] > cfgPtCentDepMLbkgSel->at(idx)) {
511+
continue;
512+
}
513+
}
514+
}
515+
487516
// last param is the weight
488517
if (cfgMassAxis && (doprocessSame2Prong2Prong || doprocessMixed2Prong2Prong) && !(doprocessSame2ProngDerived || doprocessMixed2ProngDerived)) {
489518
if constexpr (std::experimental::is_detected<HasInvMass, typename TTracks1::iterator>::value && std::experimental::is_detected<HasInvMass, typename TTracks2::iterator>::value)
@@ -644,6 +673,30 @@ struct CorrelationTask {
644673
}
645674
PROCESS_SWITCH(CorrelationTask, processSame2Prong2Prong, "Process same event on derived data", false);
646675

676+
void processSame2Prong2ProngML(DerivedCollisions::iterator const& collision, soa::Filtered<soa::Join<aod::CF2ProngTracks, aod::CF2ProngTrackmls>> const& p2tracks)
677+
{
678+
BinningTypeDerived configurableBinningDerived{{axisVertex, axisMultiplicity}, true}; // true is for 'ignore overflows' (true by default). Underflows and overflows will have bin -1.
679+
if (cfgVerbosity > 0) {
680+
LOGF(info, "processSame2ProngDerived: 2-prong candidates: %d | Vertex: %.1f | Multiplicity/Centrality: %.1f", p2tracks.size(), collision.posZ(), collision.multiplicity());
681+
}
682+
loadEfficiency(collision.timestamp());
683+
684+
const auto multiplicity = collision.multiplicity();
685+
686+
int bin = configurableBinningDerived.getBin({collision.posZ(), collision.multiplicity()});
687+
registry.fill(HIST("eventcount_same"), bin);
688+
fillQA(collision, multiplicity, p2tracks, p2tracks);
689+
690+
same->fillEvent(multiplicity, CorrelationContainer::kCFStepReconstructed);
691+
fillCorrelations<CorrelationContainer::kCFStepReconstructed>(same, p2tracks, p2tracks, multiplicity, collision.posZ(), 0, 1.0f);
692+
693+
if (cfg.mEfficiencyAssociated || cfg.mEfficiencyTrigger) {
694+
same->fillEvent(multiplicity, CorrelationContainer::kCFStepCorrected);
695+
fillCorrelations<CorrelationContainer::kCFStepCorrected>(same, p2tracks, p2tracks, multiplicity, collision.posZ(), 0, 1.0f);
696+
}
697+
}
698+
PROCESS_SWITCH(CorrelationTask, processSame2Prong2ProngML, "Process same event on derived data", false);
699+
647700
using BinningTypeAOD = ColumnBinningPolicy<aod::collision::PosZ, aod::cent::CentRun2V0M>;
648701
void processMixedAOD(AodCollisions const& collisions, AodTracks const& tracks, aod::BCsWithTimestamps const&)
649702
{
@@ -807,6 +860,45 @@ struct CorrelationTask {
807860
}
808861
PROCESS_SWITCH(CorrelationTask, processMixed2Prong2Prong, "Process mixed events on derived data", false);
809862

863+
void processMixed2Prong2ProngML(DerivedCollisions const& collisions, soa::Filtered<soa::Join<aod::CF2ProngTracks, aod::CF2ProngTrackmls>> const& p2tracks)
864+
{
865+
BinningTypeDerived configurableBinningDerived{{axisVertex, axisMultiplicity}, true}; // true is for 'ignore overflows' (true by default). Underflows and overflows will have bin -1.
866+
// Strictly upper categorised collisions, for cfgNoMixedEvents combinations per bin, skipping those in entry -1
867+
auto tracksTuple = std::make_tuple(p2tracks);
868+
SameKindPair<DerivedCollisions, soa::Filtered<soa::Join<aod::CF2ProngTracks, aod::CF2ProngTrackmls>>, BinningTypeDerived> pairs{configurableBinningDerived, cfgNoMixedEvents, -1, collisions, tracksTuple, &cache}; // -1 is the number of the bin to skip
869+
870+
for (auto it = pairs.begin(); it != pairs.end(); it++) {
871+
auto& [collision1, tracks1, collision2, tracks2] = *it;
872+
int bin = configurableBinningDerived.getBin({collision1.posZ(), collision1.multiplicity()});
873+
float eventWeight = 1.0f / it.currentWindowNeighbours();
874+
int field = 0;
875+
if (cfgTwoTrackCut > 0) {
876+
field = getMagneticField(collision1.timestamp());
877+
}
878+
879+
if (cfgVerbosity > 0) {
880+
LOGF(info, "processMixedDerived: Mixed collisions bin: %d pair: [%d, %d] %d (%.3f, %.3f), %d (%.3f, %.3f)", bin, it.isNewWindow(), it.currentWindowNeighbours(), collision1.globalIndex(), collision1.posZ(), collision1.multiplicity(), collision2.globalIndex(), collision2.posZ(), collision2.multiplicity());
881+
}
882+
883+
if (it.isNewWindow()) {
884+
loadEfficiency(collision1.timestamp());
885+
mixed->fillEvent(collision1.multiplicity(), CorrelationContainer::kCFStepReconstructed);
886+
}
887+
888+
// LOGF(info, "Tracks: %d and %d entries", tracks1.size(), tracks2.size());
889+
890+
registry.fill(HIST("eventcount_mixed"), bin);
891+
fillCorrelations<CorrelationContainer::kCFStepReconstructed>(mixed, tracks1, tracks2, collision1.multiplicity(), collision1.posZ(), field, eventWeight);
892+
if (cfg.mEfficiencyAssociated || cfg.mEfficiencyTrigger) {
893+
if (it.isNewWindow()) {
894+
mixed->fillEvent(collision1.multiplicity(), CorrelationContainer::kCFStepCorrected);
895+
}
896+
fillCorrelations<CorrelationContainer::kCFStepCorrected>(mixed, tracks1, tracks2, collision1.multiplicity(), collision1.posZ(), field, eventWeight);
897+
}
898+
}
899+
}
900+
PROCESS_SWITCH(CorrelationTask, processMixed2Prong2ProngML, "Process mixed events on derived data", false);
901+
810902
// Version with combinations
811903
/*void processWithCombinations(soa::Join<aod::Collisions, aod::CentRun2V0Ms>::iterator const& collision, aod::BCsWithTimestamps const&, soa::Filtered<aod::Tracks> const& tracks)
812904
{

0 commit comments

Comments
 (0)