Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions PWGEM/Dilepton/Core/DielectronCut.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

Check failure on line 11 in PWGEM/Dilepton/Core/DielectronCut.h

View workflow job for this annotation

GitHub Actions / O2 linter

[doc/file]

Documentation for \file is missing, incorrect or misplaced.
//
// Class for dielectron selection
//
Expand All @@ -33,8 +33,8 @@
#include <utility>
#include <vector>

using namespace o2::aod::pwgem::dilepton::utils::emtrackutil;

Check failure on line 36 in PWGEM/Dilepton/Core/DielectronCut.h

View workflow job for this annotation

GitHub Actions / O2 linter

[using-directive]

Do not put using directives at global scope in headers.
using namespace o2::aod::pwgem::dilepton::utils::pairutil;

Check failure on line 37 in PWGEM/Dilepton/Core/DielectronCut.h

View workflow job for this annotation

GitHub Actions / O2 linter

[using-directive]

Do not put using directives at global scope in headers.

class DielectronCut : public TNamed
{
Expand Down Expand Up @@ -162,8 +162,8 @@
return true;
}

template <bool dont_require_pteta = false, bool isML = false, typename TTrack, typename TCollision = int>
bool IsSelectedTrack(TTrack const& track, TCollision const& collision = 0) const
template <bool dont_require_pteta = false, typename TTrack>
bool IsSelectedTrack(TTrack const& track) const
{
if (!track.hasITS()) {
return false;
Expand Down Expand Up @@ -252,36 +252,31 @@
}

// PID cuts
if constexpr (isML) {
if (!PassPIDML(track, collision)) {
if (track.hasITS() && !track.hasTPC() && !track.hasTRD() && !track.hasTOF()) { // ITSsa
float meanClusterSizeITS = track.meanClusterSizeITS() * std::cos(std::atan(track.tgl()));
if (meanClusterSizeITS < mMinMeanClusterSizeITS || mMaxMeanClusterSizeITS < meanClusterSizeITS) {
return false;
}
} else {
if (track.hasITS() && !track.hasTPC() && !track.hasTRD() && !track.hasTOF()) { // ITSsa
float meanClusterSizeITS = track.meanClusterSizeITS() * std::cos(std::atan(track.tgl()));
if (meanClusterSizeITS < mMinMeanClusterSizeITS || mMaxMeanClusterSizeITS < meanClusterSizeITS) {
return false;
}
} else { // not ITSsa
if (!PassPID(track)) {
return false;
}
} else { // not ITSsa
if (!PassPID(track)) {
return false;
}
}

return true;
}

template <typename TTrack, typename TCollision>
bool PassPIDML(TTrack const&, TCollision const&) const
template <typename TTrack>
bool PassPIDML(TTrack const& track) const
{
return false;
/*if (!PassTOFif(track)) { // Allows for pre-selection. But potentially dangerous if analyzers are not aware of it
return false;
}*/
// std::vector<float> inputFeatures = mPIDMlResponse->getInputFeatures(track, collision);
// float binningFeature = mPIDMlResponse->getBinningFeature(track, collision);
// return mPIDMlResponse->isSelectedMl(inputFeatures, binningFeature);
int pbin = lower_bound(mMLBins.begin(), mMLBins.end(), track.tpcInnerParam()) - mMLBins.begin() - 1;
if (pbin < 0) {
pbin = 0;
} else if (static_cast<int>(mMLBins.size()) - 2 < pbin) {

Check failure on line 275 in PWGEM/Dilepton/Core/DielectronCut.h

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
pbin = static_cast<int>(mMLBins.size()) - 2;
}
// LOGF(info, "track.tpcInnerParam() = %f, pbin = %d, track.probElBDT() = %f, mMLCuts[pbin] = %f", track.tpcInnerParam(), pbin, track.probElBDT(), mMLCuts[pbin]);
return track.probElBDT() > mMLCuts[pbin];
}

template <typename T>
Expand All @@ -307,7 +302,7 @@
return PassTOFif(track);

case static_cast<int>(PIDSchemes::kPIDML):
return true; // don't use kPIDML here.
return PassPIDML(track);

case static_cast<int>(PIDSchemes::kTPChadrejORTOFreq_woTOFif):
return PassTPConlyhadrej(track) || PassTOFreq(track);
Expand Down Expand Up @@ -404,8 +399,8 @@
bool is_in_phi_range = track.phi() > mMinTrackPhi && track.phi() < mMaxTrackPhi;
return mRejectTrackPhi ? !is_in_phi_range : is_in_phi_range;
} else {
double minTrackPhiMirror = mMinTrackPhi + TMath::Pi();

Check failure on line 402 in PWGEM/Dilepton/Core/DielectronCut.h

View workflow job for this annotation

GitHub Actions / O2 linter

[external-pi]

Use the PI constant (and its multiples and fractions) defined in o2::constants::math.
double maxTrackPhiMirror = mMaxTrackPhi + TMath::Pi();

Check failure on line 403 in PWGEM/Dilepton/Core/DielectronCut.h

View workflow job for this annotation

GitHub Actions / O2 linter

[external-pi]

Use the PI constant (and its multiples and fractions) defined in o2::constants::math.
bool is_in_phi_range = (track.phi() > mMinTrackPhi && track.phi() < mMaxTrackPhi) || (track.phi() > minTrackPhiMirror && track.phi() < maxTrackPhiMirror);
return mRejectTrackPhi ? !is_in_phi_range : is_in_phi_range;
}
Expand Down Expand Up @@ -464,7 +459,7 @@

void SetTrackPtRange(float minPt = 0.f, float maxPt = 1e10f);
void SetTrackEtaRange(float minEta = -1e10f, float maxEta = 1e10f);
void SetTrackPhiRange(float minPhi = 0.f, float maxPhi = 2.f * M_PI, bool mirror = false, bool reject = false);

Check failure on line 462 in PWGEM/Dilepton/Core/DielectronCut.h

View workflow job for this annotation

GitHub Actions / O2 linter

[pi-multiple-fraction]

Use multiples/fractions of PI defined in o2::constants::math.

Check failure on line 462 in PWGEM/Dilepton/Core/DielectronCut.h

View workflow job for this annotation

GitHub Actions / O2 linter

[external-pi]

Use the PI constant (and its multiples and fractions) defined in o2::constants::math.
void SetMinNClustersTPC(int minNClustersTPC);
void SetMinNCrossedRowsTPC(int minNCrossedRowsTPC);
void SetMinNCrossedRowsOverFindableClustersTPC(float minNCrossedRowsOverFindableClustersTPC);
Expand Down Expand Up @@ -517,6 +512,18 @@
mPIDMlResponse = mlResponse;
}

void SetMLThresholds(const std::vector<float> bins, const std::vector<float> cuts)
{
if (bins.size() != cuts.size() + 1) {
LOG(fatal) << "cuts.size() + 1 mutst be exactly the same as bins.size(). Check your bins and thresholds.";
}
mMLBins = bins;
mMLCuts = cuts;
// for (int i = 0; i < static_cast<int>(mMLBins.size()) - 1; i++) {
// printf("Dielectron cut: mMLBins[%d] = %3.2f, mMLBins[%d] = %3.2f, mMLCuts[%d] = %3.2f\n", i, mMLBins[i], i + 1, mMLBins[i + 1], i, mMLCuts[i]);
// }
}

// Getters
bool IsPhotonConversionSelected() const { return mSelectPC; }

Expand All @@ -541,7 +548,7 @@
// kinematic cuts
float mMinTrackPt{0.f}, mMaxTrackPt{1e10f}; // range in pT
float mMinTrackEta{-1e10f}, mMaxTrackEta{1e10f}; // range in eta
float mMinTrackPhi{0.f}, mMaxTrackPhi{2.f * M_PI}; // range in phi

Check failure on line 551 in PWGEM/Dilepton/Core/DielectronCut.h

View workflow job for this annotation

GitHub Actions / O2 linter

[pi-multiple-fraction]

Use multiples/fractions of PI defined in o2::constants::math.

Check failure on line 551 in PWGEM/Dilepton/Core/DielectronCut.h

View workflow job for this annotation

GitHub Actions / O2 linter

[external-pi]

Use the PI constant (and its multiples and fractions) defined in o2::constants::math.
bool mMirrorTrackPhi{false}, mRejectTrackPhi{false}; // phi cut mirror by Pi, rejected/accepted

// track quality cuts
Expand Down Expand Up @@ -597,6 +604,8 @@
// float mMinP_ITSNsigmaPr{0.0}, mMaxP_ITSNsigmaPr{0.0};

o2::analysis::MlResponseDielectronSingleTrack<float>* mPIDMlResponse{nullptr};
std::vector<float> mMLBins{}; // binning for a feature variable. e.g. tpcInnerParam
std::vector<float> mMLCuts{}; // threshold for each bin. mMLCuts.size() must be mMLBins.size()-1.

ClassDef(DielectronCut, 1);
};
Expand Down
78 changes: 45 additions & 33 deletions PWGEM/Dilepton/Core/Dilepton.h
Original file line number Diff line number Diff line change
Expand Up @@ -713,30 +713,42 @@ struct Dilepton {
// fDielectronCut.SetPRangeForITSNsigmaPr(dielectroncuts.cfg_min_p_ITSNsigmaPr, dielectroncuts.cfg_max_p_ITSNsigmaPr);

if (dielectroncuts.cfg_pid_scheme == static_cast<int>(DielectronCut::PIDSchemes::kPIDML)) { // please call this at the end of DefineDileptonCut
static constexpr int nClassesMl = 2;
const std::vector<int> cutDirMl = {o2::cuts_ml::CutNot, o2::cuts_ml::CutSmaller};
const std::vector<std::string> labelsClasses = {"Background", "Signal"};
const uint32_t nBinsMl = dielectroncuts.binsMl.value.size() - 1;
const std::vector<std::string> labelsBins(nBinsMl, "bin");
double cutsMlArr[nBinsMl][nClassesMl];
for (uint32_t i = 0; i < nBinsMl; i++) {
cutsMlArr[i][0] = 0.;
cutsMlArr[i][1] = dielectroncuts.cutsMl.value[i];
}
o2::framework::LabeledArray<double> cutsMl = {cutsMlArr[0], nBinsMl, nClassesMl, labelsBins, labelsClasses};

mlResponseSingleTrack.configure(dielectroncuts.binsMl.value, cutsMl, cutDirMl, nClassesMl);
if (dielectroncuts.loadModelsFromCCDB) {
ccdbApi.init(ccdburl);
mlResponseSingleTrack.setModelPathsCCDB(dielectroncuts.onnxFileNames.value, ccdbApi, dielectroncuts.onnxPathsCCDB.value, dielectroncuts.timestampCCDB.value);
} else {
mlResponseSingleTrack.setModelPathsLocal(dielectroncuts.onnxFileNames.value);
}
mlResponseSingleTrack.cacheInputFeaturesIndices(dielectroncuts.namesInputFeatures);
mlResponseSingleTrack.cacheBinningIndex(dielectroncuts.nameBinningFeature);
mlResponseSingleTrack.init(dielectroncuts.enableOptimizations.value);

fDielectronCut.SetPIDMlResponse(&mlResponseSingleTrack);
std::vector<float> binsML{};
binsML.reserve(dielectroncuts.binsMl.value.size());
for (size_t i = 0; i < dielectroncuts.binsMl.value.size(); i++) {
binsML.emplace_back(dielectroncuts.binsMl.value[i]);
}
std::vector<float> thresholdsML{};
thresholdsML.reserve(dielectroncuts.cutsMl.value.size());
for (size_t i = 0; i < dielectroncuts.cutsMl.value.size(); i++) {
thresholdsML.emplace_back(dielectroncuts.cutsMl.value[i]);
}
fDielectronCut.SetMLThresholds(binsML, thresholdsML);

// static constexpr int nClassesMl = 2;
// const std::vector<int> cutDirMl = {o2::cuts_ml::CutNot, o2::cuts_ml::CutSmaller};
// const std::vector<std::string> labelsClasses = {"Background", "Signal"};
// const uint32_t nBinsMl = dielectroncuts.binsMl.value.size() - 1;
// const std::vector<std::string> labelsBins(nBinsMl, "bin");
// double cutsMlArr[nBinsMl][nClassesMl];
// for (uint32_t i = 0; i < nBinsMl; i++) {
// cutsMlArr[i][0] = 0.;
// cutsMlArr[i][1] = dielectroncuts.cutsMl.value[i];
// }
// o2::framework::LabeledArray<double> cutsMl = {cutsMlArr[0], nBinsMl, nClassesMl, labelsBins, labelsClasses};

// mlResponseSingleTrack.configure(dielectroncuts.binsMl.value, cutsMl, cutDirMl, nClassesMl);
// if (dielectroncuts.loadModelsFromCCDB) {
// ccdbApi.init(ccdburl);
// mlResponseSingleTrack.setModelPathsCCDB(dielectroncuts.onnxFileNames.value, ccdbApi, dielectroncuts.onnxPathsCCDB.value, dielectroncuts.timestampCCDB.value);
// } else {
// mlResponseSingleTrack.setModelPathsLocal(dielectroncuts.onnxFileNames.value);
// }
// mlResponseSingleTrack.cacheInputFeaturesIndices(dielectroncuts.namesInputFeatures);
// mlResponseSingleTrack.cacheBinningIndex(dielectroncuts.nameBinningFeature);
// mlResponseSingleTrack.init(dielectroncuts.enableOptimizations.value);

// fDielectronCut.SetPIDMlResponse(&mlResponseSingleTrack);
} // end of PID ML
}

Expand Down Expand Up @@ -836,11 +848,11 @@ struct Dilepton {
if constexpr (ev_id == 0) {
if constexpr (pairtype == o2::aod::pwgem::dilepton::utils::pairutil::DileptonPairType::kDielectron) {
if (dielectroncuts.cfg_pid_scheme == static_cast<int>(DielectronCut::PIDSchemes::kPIDML)) {
if (!cut.template IsSelectedTrack<false, true>(t1, collision) || !cut.template IsSelectedTrack<false, true>(t2, collision)) {
if (!cut.template IsSelectedTrack<false>(t1) || !cut.template IsSelectedTrack<false>(t2)) {
return false;
}
} else { // cut-based
if (!cut.template IsSelectedTrack<false, false>(t1) || !cut.template IsSelectedTrack<false, false>(t2)) {
if (!cut.template IsSelectedTrack<false>(t1) || !cut.template IsSelectedTrack<false>(t2)) {
return false;
}
}
Expand Down Expand Up @@ -1377,16 +1389,16 @@ struct Dilepton {

} // end of DF

template <typename TCollision, typename TTrack1, typename TTrack2, typename TCut, typename TAllTracks>
bool isPairOK(TCollision const& collision, TTrack1 const& t1, TTrack2 const& t2, TCut const& cut, TAllTracks const& tracks)
template <typename TTrack1, typename TTrack2, typename TCut, typename TAllTracks>
bool isPairOK(TTrack1 const& t1, TTrack2 const& t2, TCut const& cut, TAllTracks const& tracks)
{
if constexpr (pairtype == o2::aod::pwgem::dilepton::utils::pairutil::DileptonPairType::kDielectron) {
if (dielectroncuts.cfg_pid_scheme == static_cast<int>(DielectronCut::PIDSchemes::kPIDML)) {
if (!cut.template IsSelectedTrack<false, true>(t1, collision) || !cut.template IsSelectedTrack<false, true>(t2, collision)) {
if (!cut.template IsSelectedTrack<false>(t1) || !cut.template IsSelectedTrack<false>(t2)) {
return false;
}
} else { // cut-based
if (!cut.template IsSelectedTrack<false, false>(t1) || !cut.template IsSelectedTrack<false, false>(t2)) {
if (!cut.template IsSelectedTrack<false>(t1) || !cut.template IsSelectedTrack<false>(t2)) {
return false;
}
}
Expand Down Expand Up @@ -1470,17 +1482,17 @@ struct Dilepton {
auto negTracks_per_coll = negTracks.sliceByCached(perCollision, collision.globalIndex(), cache);

for (const auto& [pos, neg] : combinations(CombinationsFullIndexPolicy(posTracks_per_coll, negTracks_per_coll))) { // ULS
if (isPairOK(collision, pos, neg, cut, tracks)) {
if (isPairOK(pos, neg, cut, tracks)) {
passed_pairIds.emplace_back(std::make_pair(pos.globalIndex(), neg.globalIndex()));
}
}
for (const auto& [pos1, pos2] : combinations(CombinationsStrictlyUpperIndexPolicy(posTracks_per_coll, posTracks_per_coll))) { // LS++
if (isPairOK(collision, pos1, pos2, cut, tracks)) {
if (isPairOK(pos1, pos2, cut, tracks)) {
passed_pairIds.emplace_back(std::make_pair(pos1.globalIndex(), pos2.globalIndex()));
}
}
for (const auto& [neg1, neg2] : combinations(CombinationsStrictlyUpperIndexPolicy(negTracks_per_coll, negTracks_per_coll))) { // LS--
if (isPairOK(collision, neg1, neg2, cut, tracks)) {
if (isPairOK(neg1, neg2, cut, tracks)) {
passed_pairIds.emplace_back(std::make_pair(neg1.globalIndex(), neg2.globalIndex()));
}
}
Expand Down
Loading
Loading