Skip to content

Commit 2c6afcb

Browse files
JinjooSeoJinjoo Seo
andauthored
[PWGDQ] ML response for DQ-analysis selections (#12169)
Co-authored-by: Jinjoo Seo <jseo@dhcp187-536.laptop-wlc.uni-heidelberg.de>
1 parent 1fe5610 commit 2c6afcb

File tree

11 files changed

+976
-12
lines changed

11 files changed

+976
-12
lines changed

PWGDQ/Core/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ o2physics_add_library(PWGDQCore
2121
AnalysisCompositeCut.cxx
2222
MCProng.cxx
2323
MCSignal.cxx
24-
PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle)
24+
PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle O2Physics::MLCore)
2525

2626
o2physics_target_root_dictionary(PWGDQCore
2727
HEADERS AnalysisCut.h

PWGDQ/Core/CutsLibrary.cxx

Lines changed: 237 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,19 @@
1212
// Contact: iarsene@cern.ch, i.c.arsene@fys.uio.no
1313
//
1414
#include "PWGDQ/Core/CutsLibrary.h"
15-
#include <RtypesCore.h>
16-
#include <TF1.h>
17-
#include <vector>
18-
#include <string>
19-
#include <iostream>
15+
2016
#include "AnalysisCompositeCut.h"
2117
#include "VarManager.h"
2218

19+
#include <TF1.h>
20+
21+
#include <RtypesCore.h>
22+
23+
#include <iostream>
24+
#include <set>
25+
#include <string>
26+
#include <vector>
27+
2328
using std::cout;
2429
using std::endl;
2530

@@ -7100,3 +7105,230 @@ AnalysisCompositeCut* o2::aod::dqcuts::ParseJSONAnalysisCompositeCut(T cut, cons
71007105

71017106
return retCut;
71027107
}
7108+
7109+
//________________________________________________________________________________________________
7110+
o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(const char* json)
7111+
{
7112+
LOG(info) << "========================================== interpreting JSON for ML analysis cuts";
7113+
if (!json) {
7114+
LOG(fatal) << "JSON config string is null!";
7115+
return {};
7116+
}
7117+
LOG(info) << "JSON string: " << json;
7118+
7119+
rapidjson::Document document;
7120+
7121+
// Check that the json is parsed correctly
7122+
rapidjson::ParseResult ok = document.Parse(json);
7123+
if (!ok) {
7124+
LOG(fatal) << "JSON parse error: " << rapidjson::GetParseErrorFunc(ok.Code()) << " (" << ok.Offset() << ")";
7125+
return {};
7126+
}
7127+
7128+
for (auto it = document.MemberBegin(); it != document.MemberEnd(); ++it) {
7129+
const auto& obj = it->value;
7130+
7131+
// Classification type
7132+
if (!obj.HasMember("type")) {
7133+
LOG(fatal) << "Missing type (Binary/MultiClass)";
7134+
return {};
7135+
}
7136+
TString typeStr = obj["type"].GetString();
7137+
if (typeStr != "Binary" && typeStr != "MultiClass") {
7138+
LOG(fatal) << "Unsupported classification type: " << typeStr;
7139+
return {};
7140+
}
7141+
7142+
// Input features
7143+
if (!obj.HasMember("inputFeatures") || !obj["inputFeatures"].IsArray()) {
7144+
LOG(fatal) << "Missing inputFeatures member or array";
7145+
return {};
7146+
}
7147+
std::vector<std::string> namesInputFeatures;
7148+
for (const auto& feature : obj["inputFeatures"].GetArray()) {
7149+
namesInputFeatures.emplace_back(feature.GetString());
7150+
LOG(debug) << "Input features: " << feature.GetString();
7151+
}
7152+
7153+
// Model files
7154+
if (!obj.HasMember("modelFiles") || !obj["modelFiles"].IsArray()) {
7155+
LOG(fatal) << "Missing modelFiles member or array";
7156+
return {};
7157+
}
7158+
std::vector<std::string> onnxFileNames;
7159+
for (const auto& model : obj["modelFiles"].GetArray()) {
7160+
onnxFileNames.emplace_back(model.GetString());
7161+
LOG(debug) << "Model Files: " << model.GetString() << " ";
7162+
}
7163+
7164+
// Centrality estimation type
7165+
if (!obj.HasMember("cent") || !obj["cent"].IsString()) {
7166+
LOG(fatal) << "Missing cent member";
7167+
return {};
7168+
}
7169+
std::string cent = obj["cent"].GetString();
7170+
LOG(debug) << "Centrality type: " << cent;
7171+
if (cent != "kCentFT0C" && cent != "kCentFT0A" && cent != "kCentFT0M") {
7172+
LOG(fatal) << "Unsupported centrality type: " << cent;
7173+
return {};
7174+
}
7175+
7176+
// Cut storage
7177+
std::vector<std::pair<double, double>> centBins;
7178+
std::vector<std::pair<double, double>> ptBins;
7179+
std::vector<std::vector<double>> cutsMl;
7180+
std::vector<int> cutDirs;
7181+
std::vector<std::string> labelsFlatBin;
7182+
bool cutDirsFilled = false;
7183+
7184+
for (auto centMember = obj.MemberBegin(); centMember != obj.MemberEnd(); ++centMember) {
7185+
TString centKey = centMember->name.GetString();
7186+
if (!centKey.Contains("AddCentCut"))
7187+
continue;
7188+
7189+
const auto& centCut = centMember->value;
7190+
7191+
// Centrality info
7192+
if (!centCut.HasMember("centMin") || !centCut.HasMember("centMax")) {
7193+
LOG(fatal) << "Missing centMin/centMax in " << centKey;
7194+
return {};
7195+
}
7196+
double centMin = centCut["centMin"].GetDouble();
7197+
double centMax = centCut["centMax"].GetDouble();
7198+
7199+
for (auto ptMember = centCut.MemberBegin(); ptMember != centCut.MemberEnd(); ++ptMember) {
7200+
TString ptKey = ptMember->name.GetString();
7201+
if (!ptKey.Contains("AddPtCut"))
7202+
continue;
7203+
7204+
const auto& ptCut = ptMember->value;
7205+
7206+
// Pt info
7207+
if (!ptCut.HasMember("pTMin") || !ptCut.HasMember("pTMax")) {
7208+
LOG(fatal) << "Missing pTMin/pTMax in " << ptKey;
7209+
return {};
7210+
}
7211+
7212+
double ptMin = ptCut["pTMin"].GetDouble();
7213+
double ptMax = ptCut["pTMax"].GetDouble();
7214+
7215+
std::vector<double> binCuts;
7216+
bool exclude = false;
7217+
7218+
for (auto mlMember = ptCut.MemberBegin(); mlMember != ptCut.MemberEnd(); ++mlMember) {
7219+
TString mlKey = mlMember->name.GetString();
7220+
if (!mlKey.Contains("AddMLCut"))
7221+
continue;
7222+
7223+
const auto& mlcut = mlMember->value;
7224+
7225+
if (!mlcut.HasMember("cut")) {
7226+
LOG(fatal) << "Missing cut (score) in " << mlKey;
7227+
return {};
7228+
}
7229+
7230+
double cutVal = mlcut["cut"].GetDouble();
7231+
exclude = mlcut.HasMember("exclude") ? mlcut["exclude"].GetBool() : false;
7232+
7233+
binCuts.push_back(cutVal);
7234+
7235+
if (!cutDirsFilled) {
7236+
cutDirs.push_back(exclude ? 0 : 1);
7237+
}
7238+
}
7239+
7240+
if (!cutDirsFilled) {
7241+
cutDirsFilled = true;
7242+
}
7243+
7244+
centBins.emplace_back(centMin, centMax);
7245+
ptBins.emplace_back(ptMin, ptMax);
7246+
cutsMl.push_back(binCuts);
7247+
labelsFlatBin.push_back(Form("%s_cent%.0f_%.0f_pt%.1f_%.1f", cent.c_str(), centMin, centMax, ptMin, ptMax));
7248+
LOG(info) << "Added cut for " << Form("%s_cent%.0f_%.0f_pt%.1f_%.1f", cent.c_str(), centMin, centMax, ptMin, ptMax) << " with cuts: [";
7249+
for (size_t i = 0; i < binCuts.size(); ++i) {
7250+
std::cout << binCuts[i];
7251+
if (i != binCuts.size() - 1)
7252+
std::cout << ", ";
7253+
}
7254+
std::cout << "] and direction: " << (exclude ? "CutGreater" : "CutSmaller") << std::endl;
7255+
}
7256+
}
7257+
7258+
if (cutDirs.size() != cutsMl[0].size()) {
7259+
LOG(fatal) << "Mismatch the cut size and direction size: cutsMl[0].size() = " << cutsMl[0].size()
7260+
<< ", cutsMl[0].size() = " << cutDirs.size();
7261+
return {};
7262+
}
7263+
7264+
std::vector<std::string> labelsClass;
7265+
for (size_t j = 0; j < cutsMl[0].size(); ++j) {
7266+
labelsClass.push_back(Form("score class %d", static_cast<int>(j)));
7267+
}
7268+
7269+
size_t nFlatBins = cutsMl.size();
7270+
std::vector<double> binsMl(nFlatBins + 1);
7271+
std::iota(binsMl.begin(), binsMl.end(), 0);
7272+
7273+
// Binary
7274+
if (typeStr == "Binary") {
7275+
dqmlcuts::BinaryBdtScoreConfig binaryCfg;
7276+
binaryCfg.inputFeatures = namesInputFeatures;
7277+
binaryCfg.onnxFiles = onnxFileNames;
7278+
binaryCfg.binsCent = centBins;
7279+
binaryCfg.binsPt = ptBins;
7280+
binaryCfg.binsMl = binsMl;
7281+
binaryCfg.cutDirs = cutDirs;
7282+
binaryCfg.centType = cent;
7283+
binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsFlatBin, labelsClass);
7284+
7285+
return binaryCfg;
7286+
7287+
// MultiClass
7288+
} else if (typeStr == "MultiClass") {
7289+
dqmlcuts::MultiClassBdtScoreConfig multiCfg;
7290+
multiCfg.inputFeatures = namesInputFeatures;
7291+
multiCfg.onnxFiles = onnxFileNames;
7292+
multiCfg.binsCent = centBins;
7293+
multiCfg.binsPt = ptBins;
7294+
multiCfg.binsMl = binsMl;
7295+
multiCfg.cutDirs = cutDirs;
7296+
multiCfg.centType = cent;
7297+
multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsFlatBin, labelsClass);
7298+
7299+
return multiCfg;
7300+
}
7301+
}
7302+
7303+
return {};
7304+
}
7305+
7306+
o2::framework::LabeledArray<double> o2::aod::dqmlcuts::makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts,
7307+
const std::vector<std::string>& labelsflatBin,
7308+
const std::vector<std::string>& labelsClass)
7309+
{
7310+
const size_t nRows = cuts.size();
7311+
const size_t nCols = cuts.empty() ? 0 : cuts[0].size();
7312+
std::vector<double> flat;
7313+
7314+
for (const auto& row : cuts) {
7315+
flat.insert(flat.end(), row.begin(), row.end());
7316+
}
7317+
7318+
o2::framework::Array2D<double> arr(flat.data(), nRows, nCols);
7319+
return o2::framework::LabeledArray<double>(arr, labelsflatBin, labelsClass);
7320+
}
7321+
7322+
int o2::aod::dqmlcuts::getMlBinIndex(double cent, double pt,
7323+
const std::vector<std::pair<double, double>>& binsCent,
7324+
const std::vector<std::pair<double, double>>& binsPt)
7325+
{
7326+
LOG(debug) << "Searching for Ml bin index for cent: " << cent << ", pt: " << pt;
7327+
for (size_t i = 0; i < binsCent.size(); ++i) {
7328+
if (cent >= binsCent[i].first && cent < binsCent[i].second && pt >= binsPt[i].first && pt < binsPt[i].second) {
7329+
LOG(debug) << " - Found at index: " << i;
7330+
return static_cast<int>(i);
7331+
}
7332+
}
7333+
return -1; // not found
7334+
}

PWGDQ/Core/CutsLibrary.h

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
#ifndef PWGDQ_CORE_CUTSLIBRARY_H_
1616
#define PWGDQ_CORE_CUTSLIBRARY_H_
1717

18-
#include <string>
19-
#include <vector>
20-
#include "PWGDQ/Core/AnalysisCut.h"
2118
#include "PWGDQ/Core/AnalysisCompositeCut.h"
19+
#include "PWGDQ/Core/AnalysisCut.h"
2220
#include "PWGDQ/Core/VarManager.h"
2321

22+
#include <string>
23+
#include <vector>
24+
2425
// ///////////////////////////////////////////////
2526
// These are the Cuts used in the CEFP Task //
2627
// to select tracks in the event selection //
@@ -119,6 +120,41 @@ bool ValidateJSONAnalysisCompositeCut(T cut);
119120
template <typename T>
120121
AnalysisCompositeCut* ParseJSONAnalysisCompositeCut(T key, const char* cutName);
121122
} // namespace dqcuts
123+
namespace dqmlcuts
124+
{
125+
struct BinaryBdtScoreConfig {
126+
std::vector<std::string> inputFeatures;
127+
std::vector<std::string> onnxFiles;
128+
std::vector<std::pair<double, double>> binsCent; // bins for centrality
129+
std::vector<std::pair<double, double>> binsPt; // bins for pT
130+
std::vector<double> binsMl; // bins for flattened binning
131+
std::string centType;
132+
o2::framework::LabeledArray<double> cutsMl; // BDT score cuts for each bin
133+
std::vector<int> cutDirs; // direction of the cuts on the BDT score
134+
};
135+
136+
struct MultiClassBdtScoreConfig {
137+
std::vector<std::string> inputFeatures;
138+
std::vector<std::string> onnxFiles;
139+
std::vector<std::pair<double, double>> binsCent;
140+
std::vector<std::pair<double, double>> binsPt;
141+
std::vector<double> binsMl;
142+
std::string centType;
143+
o2::framework::LabeledArray<double> cutsMl;
144+
std::vector<int> cutDirs;
145+
};
146+
147+
using BdtScoreConfig = std::variant<BinaryBdtScoreConfig, MultiClassBdtScoreConfig>;
148+
149+
BdtScoreConfig GetBdtScoreCutsAndConfigFromJSON(const char* json);
150+
151+
o2::framework::LabeledArray<double> makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts,
152+
const std::vector<std::string>& labelsPt,
153+
const std::vector<std::string>& labelsClass);
154+
int getMlBinIndex(double cent, double pt,
155+
const std::vector<std::pair<double, double>>& binsCent,
156+
const std::vector<std::pair<double, double>>& binsPt);
157+
} // namespace dqmlcuts
122158
} // namespace o2::aod
123159

124160
AnalysisCompositeCut* o2::aod::dqcuts::GetCompositeCut(const char* cutName);

0 commit comments

Comments
 (0)