Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion PWGDQ/Core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ o2physics_add_library(PWGDQCore
AnalysisCompositeCut.cxx
MCProng.cxx
MCSignal.cxx
PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle)
PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle O2Physics::MLCore)

o2physics_target_root_dictionary(PWGDQCore
HEADERS AnalysisCut.h
Expand Down
242 changes: 237 additions & 5 deletions PWGDQ/Core/CutsLibrary.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
// Contact: iarsene@cern.ch, i.c.arsene@fys.uio.no
//
#include "PWGDQ/Core/CutsLibrary.h"
#include <RtypesCore.h>
#include <TF1.h>
#include <vector>
#include <string>
#include <iostream>

#include "AnalysisCompositeCut.h"
#include "VarManager.h"

#include <TF1.h>

#include <RtypesCore.h>

#include <iostream>

Check failure on line 23 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[include-iostream]

Do not include iostream. Use O2 logging instead.
#include <set>
#include <string>
#include <vector>

using std::cout;
using std::endl;

Expand Down Expand Up @@ -1164,7 +1169,7 @@
return cut;
}

for (int iCut = 0; iCut < 10; iCut++) {

Check failure on line 1172 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
if (!nameStr.compare(Form("jpsiEleSel%d_ionut", iCut))) {
cut->AddCut(GetAnalysisCut("kineJpsiEle_ionut"));
cut->AddCut(GetAnalysisCut("dcaCut1_ionut"));
Expand Down Expand Up @@ -1464,7 +1469,7 @@
return cut;
}

for (int i = 1; i <= 8; i++) {

Check failure on line 1472 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
if (!nameStr.compare(Form("dalitzSelected%d", i))) {
cut->AddCut(GetAnalysisCut(Form("dalitzLeg%d", i)));
return cut;
Expand Down Expand Up @@ -1966,7 +1971,7 @@
return cut;
}

for (unsigned int i = 0; i < 30; i++) {

Check failure on line 1974 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
if (!nameStr.compare(Form("ElSelCutVar%s%i", vecPIDcase.at(icase).Data(), i))) {
cut->AddCut(GetAnalysisCut("lmeeStandardKine"));
cut->AddCut(GetAnalysisCut(Form("lmeeCutVarTrackCuts%i", i)));
Expand Down Expand Up @@ -2736,7 +2741,7 @@
return cut;
}

for (int i = 1; i <= 8; i++) {

Check failure on line 2744 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
if (!nameStr.compare(Form("lmee%s_pp502TeV_PID%s_UsePrefilter%d", vecTypetrackWithPID.at(jcase).Data(), vecPIDcase.at(icase).Data(), i))) {
cut->AddCut(GetAnalysisCut(Form("notDalitzLeg%d", i)));
cut->AddCut(GetAnalysisCut("lmeeStandardKine"));
Expand Down Expand Up @@ -2793,7 +2798,7 @@
return cut;
}

for (int i = 1; i <= 8; i++) {

Check failure on line 2801 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
if (!nameStr.compare(Form("lmee%s_eNSigmaRun3%s_UsePrefilter%d", vecTypetrackWithPID.at(jcase).Data(), vecPIDcase.at(icase).Data(), i))) {
cut->AddCut(GetAnalysisCut(Form("notDalitzLeg%d", i)));
cut->AddCut(GetAnalysisCut("lmeeStandardKine"));
Expand Down Expand Up @@ -4477,7 +4482,7 @@
cut->AddCut(VarManager::kITSncls, 6.5, 7.5);
cut->AddCut(VarManager::kTPCnclsCR, 80.0, 161.);
cut->AddCut(VarManager::kTPCncls, 90.0, 170.);
} else if (icase == 2) {

Check failure on line 4485 in PWGDQ/Core/CutsLibrary.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[magic-number]

Avoid magic numbers in expressions. Assign the value to a clearly named variable or constant.
cut->AddCut(VarManager::kIsSPDfirst, 0.5, 1.5);
cut->AddCut(VarManager::kITSchi2, 0.0, 5.0);
cut->AddCut(VarManager::kITSncls, 4.5, 7.5);
Expand Down Expand Up @@ -7100,3 +7105,230 @@

return retCut;
}

//________________________________________________________________________________________________
/// Interpret a JSON string describing a BDT (ML) selection and build the matching configuration.
///
/// Only the first top-level member of the JSON document is interpreted. It must be an object with:
///   - "type":          "Binary" or "MultiClass"
///   - "inputFeatures": array of strings
///   - "modelFiles":    array of strings
///   - "cent":          "kCentFT0C", "kCentFT0A" or "kCentFT0M"
///   - members whose name contains "AddCentCut": objects with "centMin", "centMax" and
///     members whose name contains "AddPtCut": objects with "pTMin", "pTMax" and
///     members whose name contains "AddMLCut": objects with "cut" and an optional boolean "exclude".
/// Every (centrality, pT) pair becomes one flattened bin, in order of appearance.
/// The cut directions are taken from the first bin only (0 if "exclude" is true, 1 otherwise);
/// every bin must provide the same number of ML cuts.
///
/// \param json null-terminated JSON configuration string
/// \return a BinaryBdtScoreConfig or MultiClassBdtScoreConfig, depending on "type";
///         a default-constructed BdtScoreConfig if an error was logged or the document is empty
o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(const char* json)
{
  LOG(info) << "========================================== interpreting JSON for ML analysis cuts";
  if (!json) {
    LOG(fatal) << "JSON config string is null!";
    return {};
  }
  LOG(info) << "JSON string: " << json;

  rapidjson::Document document;

  // Check that the json is parsed correctly
  rapidjson::ParseResult ok = document.Parse(json);
  if (!ok) {
    // rapidjson::GetParseErrorFunc is a function-pointer typedef, not a function: "calling" it only
    // casts the error code. Report the numeric ParseErrorCode and the offset instead.
    LOG(fatal) << "JSON parse error: code " << static_cast<int>(ok.Code()) << " (" << ok.Offset() << ")";
    return {};
  }
  // Member iteration is only valid on JSON objects
  if (!document.IsObject()) {
    LOG(fatal) << "JSON top-level value is not an object";
    return {};
  }
  if (document.MemberBegin() == document.MemberEnd()) {
    return {};
  }
  if (document.MemberCount() > 1) {
    LOG(warn) << "Only the first top-level entry of the ML cut JSON is used; the others are ignored";
  }

  const auto& obj = document.MemberBegin()->value;
  if (!obj.IsObject()) {
    LOG(fatal) << "Configuration entry " << document.MemberBegin()->name.GetString() << " is not a JSON object";
    return {};
  }

  // Classification type
  if (!obj.HasMember("type") || !obj["type"].IsString()) {
    LOG(fatal) << "Missing type (Binary/MultiClass)";
    return {};
  }
  const TString typeStr = obj["type"].GetString();
  const bool isBinary = (typeStr == "Binary");
  if (!isBinary && typeStr != "MultiClass") {
    LOG(fatal) << "Unsupported classification type: " << typeStr;
    return {};
  }

  // Input features
  if (!obj.HasMember("inputFeatures") || !obj["inputFeatures"].IsArray()) {
    LOG(fatal) << "Missing inputFeatures member or array";
    return {};
  }
  std::vector<std::string> namesInputFeatures;
  for (const auto& feature : obj["inputFeatures"].GetArray()) {
    if (!feature.IsString()) {
      LOG(fatal) << "Non-string entry in inputFeatures";
      return {};
    }
    namesInputFeatures.emplace_back(feature.GetString());
    LOG(debug) << "Input features: " << feature.GetString();
  }

  // Model files
  if (!obj.HasMember("modelFiles") || !obj["modelFiles"].IsArray()) {
    LOG(fatal) << "Missing modelFiles member or array";
    return {};
  }
  std::vector<std::string> onnxFileNames;
  for (const auto& model : obj["modelFiles"].GetArray()) {
    if (!model.IsString()) {
      LOG(fatal) << "Non-string entry in modelFiles";
      return {};
    }
    onnxFileNames.emplace_back(model.GetString());
    LOG(debug) << "Model Files: " << model.GetString() << " ";
  }

  // Centrality estimation type
  if (!obj.HasMember("cent") || !obj["cent"].IsString()) {
    LOG(fatal) << "Missing cent member";
    return {};
  }
  const std::string cent = obj["cent"].GetString();
  LOG(debug) << "Centrality type: " << cent;
  if (cent != "kCentFT0C" && cent != "kCentFT0A" && cent != "kCentFT0M") {
    LOG(fatal) << "Unsupported centrality type: " << cent;
    return {};
  }

  // Cut storage: one entry per flattened (centrality, pT) bin
  std::vector<std::pair<double, double>> centBins;
  std::vector<std::pair<double, double>> ptBins;
  std::vector<std::vector<double>> cutsMl;
  std::vector<int> cutDirs;
  std::vector<std::string> labelsFlatBin;
  bool cutDirsFilled = false;

  for (auto centMember = obj.MemberBegin(); centMember != obj.MemberEnd(); ++centMember) {
    const TString centKey = centMember->name.GetString();
    if (!centKey.Contains("AddCentCut")) {
      continue;
    }

    const auto& centCut = centMember->value;

    // Centrality info
    if (!centCut.IsObject() || !centCut.HasMember("centMin") || !centCut.HasMember("centMax") ||
        !centCut["centMin"].IsNumber() || !centCut["centMax"].IsNumber()) {
      LOG(fatal) << "Missing or non-numeric centMin/centMax in " << centKey;
      return {};
    }
    const double centMin = centCut["centMin"].GetDouble();
    const double centMax = centCut["centMax"].GetDouble();

    for (auto ptMember = centCut.MemberBegin(); ptMember != centCut.MemberEnd(); ++ptMember) {
      const TString ptKey = ptMember->name.GetString();
      if (!ptKey.Contains("AddPtCut")) {
        continue;
      }

      const auto& ptCut = ptMember->value;

      // Pt info
      if (!ptCut.IsObject() || !ptCut.HasMember("pTMin") || !ptCut.HasMember("pTMax") ||
          !ptCut["pTMin"].IsNumber() || !ptCut["pTMax"].IsNumber()) {
        LOG(fatal) << "Missing or non-numeric pTMin/pTMax in " << ptKey;
        return {};
      }
      const double ptMin = ptCut["pTMin"].GetDouble();
      const double ptMax = ptCut["pTMax"].GetDouble();

      std::vector<double> binCuts;
      bool exclude = false;

      for (auto mlMember = ptCut.MemberBegin(); mlMember != ptCut.MemberEnd(); ++mlMember) {
        const TString mlKey = mlMember->name.GetString();
        if (!mlKey.Contains("AddMLCut")) {
          continue;
        }

        const auto& mlcut = mlMember->value;

        if (!mlcut.IsObject() || !mlcut.HasMember("cut") || !mlcut["cut"].IsNumber()) {
          LOG(fatal) << "Missing cut (score) in " << mlKey;
          return {};
        }
        exclude = false;
        if (mlcut.HasMember("exclude")) {
          if (!mlcut["exclude"].IsBool()) {
            LOG(fatal) << "Member exclude must be a boolean in " << mlKey;
            return {};
          }
          exclude = mlcut["exclude"].GetBool();
        }

        binCuts.push_back(mlcut["cut"].GetDouble());

        // The cut directions are defined by the first bin and reused for all the others
        if (!cutDirsFilled) {
          cutDirs.push_back(exclude ? 0 : 1);
        }
      }

      if (!cutDirsFilled) {
        cutDirsFilled = true;
      } else if (binCuts.size() != cutDirs.size()) {
        // A ragged cut table cannot be flattened into a rectangular array
        LOG(fatal) << "Mismatch the cut size and direction size in " << ptKey << ": binCuts.size() = " << binCuts.size()
                   << ", cutDirs.size() = " << cutDirs.size();
        return {};
      }

      const TString binLabel = Form("%s_cent%.0f_%.0f_pt%.1f_%.1f", cent.c_str(), centMin, centMax, ptMin, ptMax);

      // Assemble the list of cut values so that the whole message goes through the O2 logger
      TString cutList;
      for (size_t i = 0; i < binCuts.size(); ++i) {
        if (i > 0) {
          cutList += ", ";
        }
        cutList += Form("%g", binCuts[i]);
      }
      LOG(info) << "Added cut for " << binLabel << " with cuts: [" << cutList
                << "] and direction: " << (exclude ? "CutGreater" : "CutSmaller");

      centBins.emplace_back(centMin, centMax);
      ptBins.emplace_back(ptMin, ptMax);
      cutsMl.push_back(binCuts);
      labelsFlatBin.emplace_back(binLabel.Data());
    }
  }

  // Without at least one bin holding at least one score cut there is nothing to configure
  if (cutsMl.empty() || cutDirs.empty()) {
    LOG(fatal) << "No ML cuts found: expected AddCentCut/AddPtCut/AddMLCut entries";
    return {};
  }

  std::vector<std::string> labelsClass;
  for (size_t j = 0; j < cutDirs.size(); ++j) {
    labelsClass.push_back(Form("score class %d", static_cast<int>(j)));
  }

  // Bin edges 0, 1, ..., nFlatBins of the flattened binning
  std::vector<double> binsMl(cutsMl.size() + 1);
  for (size_t i = 0; i < binsMl.size(); ++i) {
    binsMl[i] = static_cast<double>(i);
  }

  // Both configuration types carry the same fields, so they are filled by the same code
  const auto fillConfig = [&](auto& cfg) {
    cfg.inputFeatures = namesInputFeatures;
    cfg.onnxFiles = onnxFileNames;
    cfg.binsCent = centBins;
    cfg.binsPt = ptBins;
    cfg.binsMl = binsMl;
    cfg.cutDirs = cutDirs;
    cfg.centType = cent;
    cfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsFlatBin, labelsClass);
  };

  // Binary
  if (isBinary) {
    dqmlcuts::BinaryBdtScoreConfig binaryCfg;
    fillConfig(binaryCfg);
    return binaryCfg;
  }

  // MultiClass
  dqmlcuts::MultiClassBdtScoreConfig multiCfg;
  fillConfig(multiCfg);
  return multiCfg;
}

/// Pack a vector of cut rows into a labeled 2D array.
///
/// The number of rows is cuts.size() and the number of columns is the size of the first row
/// (0 if there are no rows). The rows are appended one after another into a contiguous buffer
/// that is handed to the Array2D constructor.
///
/// \param cuts          cut values, one inner vector per row
/// \param labelsflatBin labels passed to LabeledArray together with the rows
/// \param labelsClass   labels passed to LabeledArray together with the columns
/// \return the LabeledArray built from the buffer and the two label vectors
o2::framework::LabeledArray<double> o2::aod::dqmlcuts::makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts,
                                                                         const std::vector<std::string>& labelsflatBin,
                                                                         const std::vector<std::string>& labelsClass)
{
  const size_t rowCount = cuts.size();
  size_t colCount = 0;
  if (!cuts.empty()) {
    colCount = cuts[0].size();
  }

  // Concatenate the rows, in order, into one contiguous buffer
  std::vector<double> buffer;
  for (size_t iRow = 0; iRow < rowCount; ++iRow) {
    for (const double value : cuts[iRow]) {
      buffer.push_back(value);
    }
  }

  o2::framework::Array2D<double> matrix(buffer.data(), rowCount, colCount);
  return o2::framework::LabeledArray<double>(matrix, labelsflatBin, labelsClass);
}

int o2::aod::dqmlcuts::getMlBinIndex(double cent, double pt,
const std::vector<std::pair<double, double>>& binsCent,
const std::vector<std::pair<double, double>>& binsPt)
{
LOG(debug) << "Searching for Ml bin index for cent: " << cent << ", pt: " << pt;
for (size_t i = 0; i < binsCent.size(); ++i) {
if (cent >= binsCent[i].first && cent < binsCent[i].second && pt >= binsPt[i].first && pt < binsPt[i].second) {
LOG(debug) << " - Found at index: " << i;
return static_cast<int>(i);
}
}
return -1; // not found
}
42 changes: 39 additions & 3 deletions PWGDQ/Core/CutsLibrary.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
#ifndef PWGDQ_CORE_CUTSLIBRARY_H_
#define PWGDQ_CORE_CUTSLIBRARY_H_

#include <string>
#include <vector>
#include "PWGDQ/Core/AnalysisCut.h"
#include "PWGDQ/Core/AnalysisCompositeCut.h"
#include "PWGDQ/Core/AnalysisCut.h"
#include "PWGDQ/Core/VarManager.h"

#include <string>
#include <vector>

// ///////////////////////////////////////////////
// These are the Cuts used in the CEFP Task //
// to select tracks in the event selection //
Expand Down Expand Up @@ -119,6 +120,41 @@ bool ValidateJSONAnalysisCompositeCut(T cut);
template <typename T>
AnalysisCompositeCut* ParseJSONAnalysisCompositeCut(T key, const char* cutName);
} // namespace dqcuts
namespace dqmlcuts
{
/// Configuration of a BDT selection of type "Binary", as produced by GetBdtScoreCutsAndConfigFromJSON.
/// Entry i of binsCent and binsPt together describes flattened bin i.
struct BinaryBdtScoreConfig {
  std::vector<std::string> inputFeatures;          // names of the model input features
  std::vector<std::string> onnxFiles;              // model file names
  std::vector<std::pair<double, double>> binsCent; // bins for centrality (min, max)
  std::vector<std::pair<double, double>> binsPt;   // bins for pT (min, max)
  std::vector<double> binsMl;                      // bins for flattened binning
  std::string centType;                            // centrality estimator name, e.g. "kCentFT0C"
  o2::framework::LabeledArray<double> cutsMl;      // BDT score cuts for each bin
  std::vector<int> cutDirs;                        // direction of the cuts on the BDT score
};

/// Configuration of a BDT selection of type "MultiClass"; same fields as BinaryBdtScoreConfig.
struct MultiClassBdtScoreConfig {
  std::vector<std::string> inputFeatures;          // names of the model input features
  std::vector<std::string> onnxFiles;              // model file names
  std::vector<std::pair<double, double>> binsCent; // bins for centrality (min, max)
  std::vector<std::pair<double, double>> binsPt;   // bins for pT (min, max)
  std::vector<double> binsMl;                      // bins for flattened binning
  std::string centType;                            // centrality estimator name, e.g. "kCentFT0C"
  o2::framework::LabeledArray<double> cutsMl;      // BDT score cuts for each bin
  std::vector<int> cutDirs;                        // direction of the cuts on the BDT score
};

/// Either of the two configurations; the held alternative tells which classification type was requested.
using BdtScoreConfig = std::variant<BinaryBdtScoreConfig, MultiClassBdtScoreConfig>;

/// Build the BDT score configuration from a JSON string.
BdtScoreConfig GetBdtScoreCutsAndConfigFromJSON(const char* json);

/// Pack the cut rows (one inner vector per flattened bin) into a labeled 2D array.
o2::framework::LabeledArray<double> makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts,
                                                      const std::vector<std::string>& labelsflatBin,
                                                      const std::vector<std::string>& labelsClass);

/// Return the index of the first flattened bin containing (cent, pt), or -1 if there is none.
int getMlBinIndex(double cent, double pt,
                  const std::vector<std::pair<double, double>>& binsCent,
                  const std::vector<std::pair<double, double>>& binsPt);
} // namespace dqmlcuts
} // namespace o2::aod

AnalysisCompositeCut* o2::aod::dqcuts::GetCompositeCut(const char* cutName);
Expand Down
Loading
Loading