Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion PWGDQ/Core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ o2physics_add_library(PWGDQCore
AnalysisCompositeCut.cxx
MCProng.cxx
MCSignal.cxx
PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle)
PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle O2Physics::MLCore)

o2physics_target_root_dictionary(PWGDQCore
HEADERS AnalysisCut.h
Expand Down
166 changes: 161 additions & 5 deletions PWGDQ/Core/CutsLibrary.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
// Contact: iarsene@cern.ch, i.c.arsene@fys.uio.no
//
#include "PWGDQ/Core/CutsLibrary.h"
#include <RtypesCore.h>
#include <TF1.h>
#include <vector>
#include <string>
#include <iostream>

#include "AnalysisCompositeCut.h"
#include "VarManager.h"

#include <TF1.h>

#include <RtypesCore.h>

#include <iostream>

Check failure on line 23 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[include-iostream]

Do not include iostream. Use O2 logging instead.
#include <set>
#include <string>
#include <vector>

using std::cout;
using std::endl;

Expand Down Expand Up @@ -1163,7 +1168,7 @@
return cut;
}

for (int iCut = 0; iCut < 10; iCut++) {

Check failure on line 1171 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
if (!nameStr.compare(Form("jpsiEleSel%d_ionut", iCut))) {
cut->AddCut(GetAnalysisCut("kineJpsiEle_ionut"));
cut->AddCut(GetAnalysisCut("dcaCut1_ionut"));
Expand Down Expand Up @@ -1463,7 +1468,7 @@
return cut;
}

for (int i = 1; i <= 8; i++) {

Check failure on line 1471 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
if (!nameStr.compare(Form("dalitzSelected%d", i))) {
cut->AddCut(GetAnalysisCut(Form("dalitzLeg%d", i)));
return cut;
Expand Down Expand Up @@ -1965,7 +1970,7 @@
return cut;
}

for (unsigned int i = 0; i < 30; i++) {

Check failure on line 1973 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
if (!nameStr.compare(Form("ElSelCutVar%s%i", vecPIDcase.at(icase).Data(), i))) {
cut->AddCut(GetAnalysisCut("lmeeStandardKine"));
cut->AddCut(GetAnalysisCut(Form("lmeeCutVarTrackCuts%i", i)));
Expand Down Expand Up @@ -2735,7 +2740,7 @@
return cut;
}

for (int i = 1; i <= 8; i++) {

Check failure on line 2743 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
if (!nameStr.compare(Form("lmee%s_pp502TeV_PID%s_UsePrefilter%d", vecTypetrackWithPID.at(jcase).Data(), vecPIDcase.at(icase).Data(), i))) {
cut->AddCut(GetAnalysisCut(Form("notDalitzLeg%d", i)));
cut->AddCut(GetAnalysisCut("lmeeStandardKine"));
Expand Down Expand Up @@ -2792,7 +2797,7 @@
return cut;
}

for (int i = 1; i <= 8; i++) {

Check failure on line 2800 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
if (!nameStr.compare(Form("lmee%s_eNSigmaRun3%s_UsePrefilter%d", vecTypetrackWithPID.at(jcase).Data(), vecPIDcase.at(icase).Data(), i))) {
cut->AddCut(GetAnalysisCut(Form("notDalitzLeg%d", i)));
cut->AddCut(GetAnalysisCut("lmeeStandardKine"));
Expand Down Expand Up @@ -4476,7 +4481,7 @@
cut->AddCut(VarManager::kITSncls, 6.5, 7.5);
cut->AddCut(VarManager::kTPCnclsCR, 80.0, 161.);
cut->AddCut(VarManager::kTPCncls, 90.0, 170.);
} else if (icase == 2) {

Check failure on line 4484 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
cut->AddCut(VarManager::kIsSPDfirst, 0.5, 1.5);
cut->AddCut(VarManager::kITSchi2, 0.0, 5.0);
cut->AddCut(VarManager::kITSncls, 4.5, 7.5);
Expand Down Expand Up @@ -7099,3 +7104,154 @@

return retCut;
}

//________________________________________________________________________________________________
o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(const char* json)
{
LOG(info) << "========================================== interpreting JSON for analysis cuts";
LOG(info) << "JSON string: " << json;

rapidjson::Document document;
rapidjson::ParseResult ok = document.Parse(json);
if (!ok) {
LOG(fatal) << "JSON parse error: " << rapidjson::GetParseErrorFunc(ok.Code()) << " (" << ok.Offset() << ")";
return {}; // empty variant
}

for (auto it = document.MemberBegin(); it != document.MemberEnd(); ++it) {
const auto& obj = it->value;

if (!obj.HasMember("type")) {
LOG(fatal) << "Missing type (Binary/MultiClass)";
return {};
}

TString typeStr = obj["type"].GetString();
// int nClasses = (typeStr == "MultiClass") ? 3 : 1;

std::vector<std::string> namesInputFeatures;
if (obj.HasMember("inputFeatures") && obj["inputFeatures"].IsArray()) {
for (auto& feature : obj["inputFeatures"].GetArray()) {

Check failure on line 7134 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[const-ref-in-for-loop]

Use constant references for non-modified iterators in range-based for loops.
namesInputFeatures.emplace_back(feature.GetString());
}
}

std::vector<std::string> onnxFileNames;
if (obj.HasMember("modelFiles") && obj["modelFiles"].IsArray()) {
for (const auto& model : obj["modelFiles"].GetArray()) {
onnxFileNames.emplace_back(model.GetString());
}
}

// Cut storage
std::vector<std::pair<double, double>> ptBins;
std::vector<std::vector<double>> cutsMl;
std::vector<int> cutDirs;
bool cutDirsFilled = false;

for (auto member = obj.MemberBegin(); member != obj.MemberEnd(); ++member) {
TString key = member->name.GetString();
if (!key.Contains("AddCut"))
continue;

const auto& cut = member->value;

if (!cut.HasMember("pTMin") || !cut.HasMember("pTMax")) {
LOG(fatal) << "Missing pTMin/pTMax in ML cut";
return {};
}

double pTMin = cut["pTMin"].GetDouble();
double pTMax = cut["pTMax"].GetDouble();
ptBins.emplace_back(pTMin, pTMax);

std::vector<double> binCuts;
bool exclude = false;

for (auto& sub : cut.GetObject()) {

Check failure on line 7171 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[const-ref-in-for-loop]

Use constant references for non-modified iterators in range-based for loops.
TString subKey = sub.name.GetString();
if (!subKey.Contains("AddMLCut"))
continue;

const auto& mlcut = sub.value;
// const char* var = mlcut["var"].GetString();
double cutVal = mlcut.HasMember("cut") ? mlcut["cut"].GetDouble() : 0.5;
exclude = mlcut.HasMember("exclude") ? mlcut["exclude"].GetBool() : false;

binCuts.push_back(cutVal);

if (!cutDirsFilled) {
cutDirs.push_back(exclude ? 1 : 0);
cutDirsFilled = true;
}
}

cutsMl.push_back(binCuts);
}

// bin edges
std::vector<double> binsPt;
if (!ptBins.empty()) {
std::set<double> binEdges;
for (auto& b : ptBins)

Check failure on line 7196 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[const-ref-in-for-loop]

Use constant references for non-modified iterators in range-based for loops.
binEdges.insert(b.first);
binEdges.insert(ptBins.back().second);
binsPt = std::vector<double>(binEdges.begin(), binEdges.end());
} else {
LOG(fatal) << "No pT bins found in ML cuts";
return {};
}

std::vector<std::string> labelsPt, labelsClass;
for (size_t i = 0; i < cutsMl.size(); ++i) {
labelsPt.push_back(Form("pT%.1f", binsPt[i]));
}
for (size_t j = 0; j < cutsMl[0].size(); ++j) {
labelsClass.push_back(Form("cls%d", static_cast<int>(j)));
}

// Binary
if (typeStr == "Binary") {
dqmlcuts::BinaryBdtScoreConfig binaryCfg;
binaryCfg.inputFeatures = namesInputFeatures;
binaryCfg.onnxFiles = onnxFileNames;
binaryCfg.binsPt = binsPt;
binaryCfg.cutDirs = cutDirs;
binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass);

return binaryCfg;

// MultiClass
} else if (typeStr == "MultiClass") {
dqmlcuts::MultiClassBdtScoreConfig multiCfg;
multiCfg.inputFeatures = namesInputFeatures;
multiCfg.onnxFiles = onnxFileNames;
multiCfg.binsPt = binsPt;
multiCfg.cutDirs = cutDirs;
multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass);

return multiCfg;
}

LOG(fatal) << "Unsupported classification type: " << typeStr;
return {};
}

return {};
}

o2::framework::LabeledArray<double> o2::aod::dqmlcuts::makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts,
const std::vector<std::string>& labelsPt,
const std::vector<std::string>& labelsClass)
{
const size_t nRows = cuts.size();
const size_t nCols = cuts.empty() ? 0 : cuts[0].size();
std::vector<double> flat;

for (const auto& row : cuts) {
flat.insert(flat.end(), row.begin(), row.end());
}

o2::framework::Array2D<double> arr(flat.data(), nRows, nCols);
return o2::framework::LabeledArray<double>(arr, labelsPt, labelsClass);
}
26 changes: 26 additions & 0 deletions PWGDQ/Core/CutsLibrary.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,32 @@ bool ValidateJSONAnalysisCompositeCut(T cut);
template <typename T>
AnalysisCompositeCut* ParseJSONAnalysisCompositeCut(T key, const char* cutName);
} // namespace dqcuts
namespace dqmlcuts
{
struct BinaryBdtScoreConfig {
std::vector<std::string> inputFeatures;
std::vector<std::string> onnxFiles;
std::vector<double> binsPt;
o2::framework::LabeledArray<double> cutsMl;
std::vector<int> cutDirs;
};

struct MultiClassBdtScoreConfig {
std::vector<std::string> inputFeatures;
std::vector<std::string> onnxFiles;
std::vector<double> binsPt;
o2::framework::LabeledArray<double> cutsMl;
std::vector<int> cutDirs;
};

using BdtScoreConfig = std::variant<BinaryBdtScoreConfig, MultiClassBdtScoreConfig>;

BdtScoreConfig GetBdtScoreCutsAndConfigFromJSON(const char* json);

o2::framework::LabeledArray<double> makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts,
const std::vector<std::string>& labelsPt,
const std::vector<std::string>& labelsClass);
} // namespace dqmlcuts
} // namespace o2::aod

AnalysisCompositeCut* o2::aod::dqcuts::GetCompositeCut(const char* cutName);
Expand Down
Loading
Loading