Skip to content

Commit 90060ee

Browse files
dsekihatalibuild
andauthored
[PWGEM/Dilepton] 1st version of PID ML (#12680)
Co-authored-by: ALICE Action Bot <alibuild@cern.ch>
1 parent c7cc561 commit 90060ee

File tree

11 files changed

+420
-297
lines changed

11 files changed

+420
-297
lines changed

PWGEM/Dilepton/Core/DielectronCut.h

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ class DielectronCut : public TNamed
162162
return true;
163163
}
164164

165-
template <bool dont_require_pteta = false, bool isML = false, typename TTrack, typename TCollision = int>
166-
bool IsSelectedTrack(TTrack const& track, TCollision const& collision = 0) const
165+
template <bool dont_require_pteta = false, typename TTrack>
166+
bool IsSelectedTrack(TTrack const& track) const
167167
{
168168
if (!track.hasITS()) {
169169
return false;
@@ -252,36 +252,31 @@ class DielectronCut : public TNamed
252252
}
253253

254254
// PID cuts
255-
if constexpr (isML) {
256-
if (!PassPIDML(track, collision)) {
255+
if (track.hasITS() && !track.hasTPC() && !track.hasTRD() && !track.hasTOF()) { // ITSsa
256+
float meanClusterSizeITS = track.meanClusterSizeITS() * std::cos(std::atan(track.tgl()));
257+
if (meanClusterSizeITS < mMinMeanClusterSizeITS || mMaxMeanClusterSizeITS < meanClusterSizeITS) {
257258
return false;
258259
}
259-
} else {
260-
if (track.hasITS() && !track.hasTPC() && !track.hasTRD() && !track.hasTOF()) { // ITSsa
261-
float meanClusterSizeITS = track.meanClusterSizeITS() * std::cos(std::atan(track.tgl()));
262-
if (meanClusterSizeITS < mMinMeanClusterSizeITS || mMaxMeanClusterSizeITS < meanClusterSizeITS) {
263-
return false;
264-
}
265-
} else { // not ITSsa
266-
if (!PassPID(track)) {
267-
return false;
268-
}
260+
} else { // not ITSsa
261+
if (!PassPID(track)) {
262+
return false;
269263
}
270264
}
271265

272266
return true;
273267
}
274268

275-
template <typename TTrack, typename TCollision>
276-
bool PassPIDML(TTrack const&, TCollision const&) const
269+
template <typename TTrack>
270+
bool PassPIDML(TTrack const& track) const
277271
{
278-
return false;
279-
/*if (!PassTOFif(track)) { // Allows for pre-selection. But potentially dangerous if analyzers are not aware of it
280-
return false;
281-
}*/
282-
// std::vector<float> inputFeatures = mPIDMlResponse->getInputFeatures(track, collision);
283-
// float binningFeature = mPIDMlResponse->getBinningFeature(track, collision);
284-
// return mPIDMlResponse->isSelectedMl(inputFeatures, binningFeature);
272+
int pbin = lower_bound(mMLBins.begin(), mMLBins.end(), track.tpcInnerParam()) - mMLBins.begin() - 1;
273+
if (pbin < 0) {
274+
pbin = 0;
275+
} else if (static_cast<int>(mMLBins.size()) - 2 < pbin) {
276+
pbin = static_cast<int>(mMLBins.size()) - 2;
277+
}
278+
// LOGF(info, "track.tpcInnerParam() = %f, pbin = %d, track.probElBDT() = %f, mMLCuts[pbin] = %f", track.tpcInnerParam(), pbin, track.probElBDT(), mMLCuts[pbin]);
279+
return track.probElBDT() > mMLCuts[pbin];
285280
}
286281

287282
template <typename T>
@@ -307,7 +302,7 @@ class DielectronCut : public TNamed
307302
return PassTOFif(track);
308303

309304
case static_cast<int>(PIDSchemes::kPIDML):
310-
return true; // don't use kPIDML here.
305+
return PassPIDML(track);
311306

312307
case static_cast<int>(PIDSchemes::kTPChadrejORTOFreq_woTOFif):
313308
return PassTPConlyhadrej(track) || PassTOFreq(track);
@@ -517,6 +512,18 @@ class DielectronCut : public TNamed
517512
mPIDMlResponse = mlResponse;
518513
}
519514

515+
void SetMLThresholds(const std::vector<float> bins, const std::vector<float> cuts)
516+
{
517+
if (bins.size() != cuts.size() + 1) {
518+
LOG(fatal) << "cuts.size() + 1 mutst be exactly the same as bins.size(). Check your bins and thresholds.";
519+
}
520+
mMLBins = bins;
521+
mMLCuts = cuts;
522+
// for (int i = 0; i < static_cast<int>(mMLBins.size()) - 1; i++) {
523+
// printf("Dielectron cut: mMLBins[%d] = %3.2f, mMLBins[%d] = %3.2f, mMLCuts[%d] = %3.2f\n", i, mMLBins[i], i + 1, mMLBins[i + 1], i, mMLCuts[i]);
524+
// }
525+
}
526+
520527
// Getters
521528
bool IsPhotonConversionSelected() const { return mSelectPC; }
522529

@@ -597,6 +604,8 @@ class DielectronCut : public TNamed
597604
// float mMinP_ITSNsigmaPr{0.0}, mMaxP_ITSNsigmaPr{0.0};
598605

599606
o2::analysis::MlResponseDielectronSingleTrack<float>* mPIDMlResponse{nullptr};
607+
std::vector<float> mMLBins{}; // binning for a feature variable. e.g. tpcInnerParam
608+
std::vector<float> mMLCuts{}; // threshold for each bin. mMLCuts.size() must be mMLBins.size()-1.
600609

601610
ClassDef(DielectronCut, 1);
602611
};

PWGEM/Dilepton/Core/Dilepton.h

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -713,30 +713,42 @@ struct Dilepton {
713713
// fDielectronCut.SetPRangeForITSNsigmaPr(dielectroncuts.cfg_min_p_ITSNsigmaPr, dielectroncuts.cfg_max_p_ITSNsigmaPr);
714714

715715
if (dielectroncuts.cfg_pid_scheme == static_cast<int>(DielectronCut::PIDSchemes::kPIDML)) { // please call this at the end of DefineDileptonCut
716-
static constexpr int nClassesMl = 2;
717-
const std::vector<int> cutDirMl = {o2::cuts_ml::CutNot, o2::cuts_ml::CutSmaller};
718-
const std::vector<std::string> labelsClasses = {"Background", "Signal"};
719-
const uint32_t nBinsMl = dielectroncuts.binsMl.value.size() - 1;
720-
const std::vector<std::string> labelsBins(nBinsMl, "bin");
721-
double cutsMlArr[nBinsMl][nClassesMl];
722-
for (uint32_t i = 0; i < nBinsMl; i++) {
723-
cutsMlArr[i][0] = 0.;
724-
cutsMlArr[i][1] = dielectroncuts.cutsMl.value[i];
725-
}
726-
o2::framework::LabeledArray<double> cutsMl = {cutsMlArr[0], nBinsMl, nClassesMl, labelsBins, labelsClasses};
727-
728-
mlResponseSingleTrack.configure(dielectroncuts.binsMl.value, cutsMl, cutDirMl, nClassesMl);
729-
if (dielectroncuts.loadModelsFromCCDB) {
730-
ccdbApi.init(ccdburl);
731-
mlResponseSingleTrack.setModelPathsCCDB(dielectroncuts.onnxFileNames.value, ccdbApi, dielectroncuts.onnxPathsCCDB.value, dielectroncuts.timestampCCDB.value);
732-
} else {
733-
mlResponseSingleTrack.setModelPathsLocal(dielectroncuts.onnxFileNames.value);
734-
}
735-
mlResponseSingleTrack.cacheInputFeaturesIndices(dielectroncuts.namesInputFeatures);
736-
mlResponseSingleTrack.cacheBinningIndex(dielectroncuts.nameBinningFeature);
737-
mlResponseSingleTrack.init(dielectroncuts.enableOptimizations.value);
738-
739-
fDielectronCut.SetPIDMlResponse(&mlResponseSingleTrack);
716+
std::vector<float> binsML{};
717+
binsML.reserve(dielectroncuts.binsMl.value.size());
718+
for (size_t i = 0; i < dielectroncuts.binsMl.value.size(); i++) {
719+
binsML.emplace_back(dielectroncuts.binsMl.value[i]);
720+
}
721+
std::vector<float> thresholdsML{};
722+
thresholdsML.reserve(dielectroncuts.cutsMl.value.size());
723+
for (size_t i = 0; i < dielectroncuts.cutsMl.value.size(); i++) {
724+
thresholdsML.emplace_back(dielectroncuts.cutsMl.value[i]);
725+
}
726+
fDielectronCut.SetMLThresholds(binsML, thresholdsML);
727+
728+
// static constexpr int nClassesMl = 2;
729+
// const std::vector<int> cutDirMl = {o2::cuts_ml::CutNot, o2::cuts_ml::CutSmaller};
730+
// const std::vector<std::string> labelsClasses = {"Background", "Signal"};
731+
// const uint32_t nBinsMl = dielectroncuts.binsMl.value.size() - 1;
732+
// const std::vector<std::string> labelsBins(nBinsMl, "bin");
733+
// double cutsMlArr[nBinsMl][nClassesMl];
734+
// for (uint32_t i = 0; i < nBinsMl; i++) {
735+
// cutsMlArr[i][0] = 0.;
736+
// cutsMlArr[i][1] = dielectroncuts.cutsMl.value[i];
737+
// }
738+
// o2::framework::LabeledArray<double> cutsMl = {cutsMlArr[0], nBinsMl, nClassesMl, labelsBins, labelsClasses};
739+
740+
// mlResponseSingleTrack.configure(dielectroncuts.binsMl.value, cutsMl, cutDirMl, nClassesMl);
741+
// if (dielectroncuts.loadModelsFromCCDB) {
742+
// ccdbApi.init(ccdburl);
743+
// mlResponseSingleTrack.setModelPathsCCDB(dielectroncuts.onnxFileNames.value, ccdbApi, dielectroncuts.onnxPathsCCDB.value, dielectroncuts.timestampCCDB.value);
744+
// } else {
745+
// mlResponseSingleTrack.setModelPathsLocal(dielectroncuts.onnxFileNames.value);
746+
// }
747+
// mlResponseSingleTrack.cacheInputFeaturesIndices(dielectroncuts.namesInputFeatures);
748+
// mlResponseSingleTrack.cacheBinningIndex(dielectroncuts.nameBinningFeature);
749+
// mlResponseSingleTrack.init(dielectroncuts.enableOptimizations.value);
750+
751+
// fDielectronCut.SetPIDMlResponse(&mlResponseSingleTrack);
740752
} // end of PID ML
741753
}
742754

@@ -836,11 +848,11 @@ struct Dilepton {
836848
if constexpr (ev_id == 0) {
837849
if constexpr (pairtype == o2::aod::pwgem::dilepton::utils::pairutil::DileptonPairType::kDielectron) {
838850
if (dielectroncuts.cfg_pid_scheme == static_cast<int>(DielectronCut::PIDSchemes::kPIDML)) {
839-
if (!cut.template IsSelectedTrack<false, true>(t1, collision) || !cut.template IsSelectedTrack<false, true>(t2, collision)) {
851+
if (!cut.template IsSelectedTrack<false>(t1) || !cut.template IsSelectedTrack<false>(t2)) {
840852
return false;
841853
}
842854
} else { // cut-based
843-
if (!cut.template IsSelectedTrack<false, false>(t1) || !cut.template IsSelectedTrack<false, false>(t2)) {
855+
if (!cut.template IsSelectedTrack<false>(t1) || !cut.template IsSelectedTrack<false>(t2)) {
844856
return false;
845857
}
846858
}
@@ -1377,16 +1389,16 @@ struct Dilepton {
13771389

13781390
} // end of DF
13791391

1380-
template <typename TCollision, typename TTrack1, typename TTrack2, typename TCut, typename TAllTracks>
1381-
bool isPairOK(TCollision const& collision, TTrack1 const& t1, TTrack2 const& t2, TCut const& cut, TAllTracks const& tracks)
1392+
template <typename TTrack1, typename TTrack2, typename TCut, typename TAllTracks>
1393+
bool isPairOK(TTrack1 const& t1, TTrack2 const& t2, TCut const& cut, TAllTracks const& tracks)
13821394
{
13831395
if constexpr (pairtype == o2::aod::pwgem::dilepton::utils::pairutil::DileptonPairType::kDielectron) {
13841396
if (dielectroncuts.cfg_pid_scheme == static_cast<int>(DielectronCut::PIDSchemes::kPIDML)) {
1385-
if (!cut.template IsSelectedTrack<false, true>(t1, collision) || !cut.template IsSelectedTrack<false, true>(t2, collision)) {
1397+
if (!cut.template IsSelectedTrack<false>(t1) || !cut.template IsSelectedTrack<false>(t2)) {
13861398
return false;
13871399
}
13881400
} else { // cut-based
1389-
if (!cut.template IsSelectedTrack<false, false>(t1) || !cut.template IsSelectedTrack<false, false>(t2)) {
1401+
if (!cut.template IsSelectedTrack<false>(t1) || !cut.template IsSelectedTrack<false>(t2)) {
13901402
return false;
13911403
}
13921404
}
@@ -1470,17 +1482,17 @@ struct Dilepton {
14701482
auto negTracks_per_coll = negTracks.sliceByCached(perCollision, collision.globalIndex(), cache);
14711483

14721484
for (const auto& [pos, neg] : combinations(CombinationsFullIndexPolicy(posTracks_per_coll, negTracks_per_coll))) { // ULS
1473-
if (isPairOK(collision, pos, neg, cut, tracks)) {
1485+
if (isPairOK(pos, neg, cut, tracks)) {
14741486
passed_pairIds.emplace_back(std::make_pair(pos.globalIndex(), neg.globalIndex()));
14751487
}
14761488
}
14771489
for (const auto& [pos1, pos2] : combinations(CombinationsStrictlyUpperIndexPolicy(posTracks_per_coll, posTracks_per_coll))) { // LS++
1478-
if (isPairOK(collision, pos1, pos2, cut, tracks)) {
1490+
if (isPairOK(pos1, pos2, cut, tracks)) {
14791491
passed_pairIds.emplace_back(std::make_pair(pos1.globalIndex(), pos2.globalIndex()));
14801492
}
14811493
}
14821494
for (const auto& [neg1, neg2] : combinations(CombinationsStrictlyUpperIndexPolicy(negTracks_per_coll, negTracks_per_coll))) { // LS--
1483-
if (isPairOK(collision, neg1, neg2, cut, tracks)) {
1495+
if (isPairOK(neg1, neg2, cut, tracks)) {
14841496
passed_pairIds.emplace_back(std::make_pair(neg1.globalIndex(), neg2.globalIndex()));
14851497
}
14861498
}

0 commit comments

Comments
 (0)