Skip to content

Commit f447df3

Browse files
Jinjoo Seo (JinjooSeo)
authored and committed
Implement ML-based response class (DQMlResponse) for dielectron DQ-analysis selections.
Supports both binary and multiclass BDT evaluation using ONNX. Example JSON files are included for reference.
1 parent 0af789a commit f447df3

File tree

11 files changed

+722
-6
lines changed

11 files changed

+722
-6
lines changed

PWGDQ/Core/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ o2physics_add_library(PWGDQCore
2121
AnalysisCompositeCut.cxx
2222
MCProng.cxx
2323
MCSignal.cxx
24-
PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle)
24+
PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle O2Physics::MLCore)
2525

2626
o2physics_target_root_dictionary(PWGDQCore
2727
HEADERS AnalysisCut.h

PWGDQ/Core/CutsLibrary.cxx

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <vector>
1818
#include <string>
1919
#include <iostream>
20+
#include <set>
2021
#include "AnalysisCompositeCut.h"
2122
#include "VarManager.h"
2223

@@ -7100,3 +7101,154 @@ AnalysisCompositeCut* o2::aod::dqcuts::ParseJSONAnalysisCompositeCut(T cut, cons
71007101

71017102
return retCut;
71027103
}
7104+
7105+
//________________________________________________________________________________________________
7106+
/// Interpret a JSON string describing a BDT-score selection and build the matching configuration.
///
/// Only the first member of the top-level JSON object is interpreted. That member must be an
/// object with a string "type" ("Binary" or "MultiClass"). The optional arrays "inputFeatures"
/// and "modelFiles" are copied as-is. Every member whose key contains "AddCut" defines one pT bin
/// ("pTMin"/"pTMax") and, inside it, every member whose key contains "AddMLCut" defines one score
/// cut ("cut", default 0.5) with an optional "exclude" flag (default false).
///
/// \param json null-terminated JSON string
/// \return a BinaryBdtScoreConfig or MultiClassBdtScoreConfig wrapped in the variant. Every error
///         path reports through LOG(fatal) and then returns a default-constructed variant; an
///         empty top-level object returns a default-constructed variant without any message.
o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(const char* json)
{
  if (json == nullptr) {
    LOG(fatal) << "Null JSON string passed for the ML cut configuration";
    return {};
  }

  LOG(info) << "========================================== interpreting JSON for analysis cuts";
  LOG(info) << "JSON string: " << json;

  rapidjson::Document document;
  rapidjson::ParseResult ok = document.Parse(json);
  if (!ok) {
    // rapidjson::GetParseErrorFunc is a function-pointer typedef, not a function: the former
    // expression cast the code to a pointer instead of describing it. Log the numeric code.
    LOG(fatal) << "JSON parse error: code " << static_cast<int>(ok.Code()) << " (" << ok.Offset() << ")";
    return {}; // default-constructed variant (its first alternative, with empty members)
  }
  if (!document.IsObject()) {
    LOG(fatal) << "ML cut configuration must be a JSON object";
    return {};
  }
  if (document.MemberBegin() == document.MemberEnd()) {
    return {};
  }

  // Only the first top-level member is interpreted (every path below returns).
  const auto& obj = document.MemberBegin()->value;

  if (!obj.IsObject() || !obj.HasMember("type") || !obj["type"].IsString()) {
    LOG(fatal) << "Missing type (Binary/MultiClass)";
    return {};
  }
  TString typeStr = obj["type"].GetString();

  std::vector<std::string> namesInputFeatures;
  if (obj.HasMember("inputFeatures") && obj["inputFeatures"].IsArray()) {
    for (const auto& feature : obj["inputFeatures"].GetArray()) {
      if (!feature.IsString()) {
        LOG(fatal) << "Non-string entry in inputFeatures";
        return {};
      }
      namesInputFeatures.emplace_back(feature.GetString());
    }
  }

  std::vector<std::string> onnxFileNames;
  if (obj.HasMember("modelFiles") && obj["modelFiles"].IsArray()) {
    for (const auto& model : obj["modelFiles"].GetArray()) {
      if (!model.IsString()) {
        LOG(fatal) << "Non-string entry in modelFiles";
        return {};
      }
      onnxFileNames.emplace_back(model.GetString());
    }
  }

  // Cut storage: one row of score cuts per pT bin, in the order the bins appear in the JSON
  std::vector<std::pair<double, double>> ptBins;
  std::vector<std::vector<double>> cutsMl;
  std::vector<int> cutDirs;
  bool cutDirsFilled = false;

  for (const auto& member : obj.GetObject()) {
    TString key = member.name.GetString();
    if (!key.Contains("AddCut")) {
      continue;
    }

    const auto& cut = member.value;
    if (!cut.IsObject() || !cut.HasMember("pTMin") || !cut.HasMember("pTMax") ||
        !cut["pTMin"].IsNumber() || !cut["pTMax"].IsNumber()) {
      LOG(fatal) << "Missing pTMin/pTMax in ML cut";
      return {};
    }
    ptBins.emplace_back(cut["pTMin"].GetDouble(), cut["pTMax"].GetDouble());

    std::vector<double> binCuts;
    for (const auto& sub : cut.GetObject()) {
      TString subKey = sub.name.GetString();
      if (!subKey.Contains("AddMLCut")) {
        continue;
      }

      const auto& mlcut = sub.value;
      if (!mlcut.IsObject()) {
        LOG(fatal) << "ML cut " << subKey << " must be a JSON object";
        return {};
      }

      double cutVal = 0.5;
      if (mlcut.HasMember("cut")) {
        if (!mlcut["cut"].IsNumber()) {
          LOG(fatal) << "Non-numeric cut value in " << subKey;
          return {};
        }
        cutVal = mlcut["cut"].GetDouble();
      }

      bool exclude = false;
      if (mlcut.HasMember("exclude")) {
        if (!mlcut["exclude"].IsBool()) {
          LOG(fatal) << "Non-boolean exclude flag in " << subKey;
          return {};
        }
        exclude = mlcut["exclude"].GetBool();
      }

      binCuts.push_back(cutVal);

      // One direction per ML cut, taken from the first pT bin only.
      // NOTE(review): presumably the consumer expects one direction per class - confirm.
      if (!cutDirsFilled) {
        cutDirs.push_back(exclude ? 1 : 0);
      }
    }
    // Close the direction list only after the whole first bin was read; closing it after the
    // first ML cut left a single direction even when several cuts are defined per bin.
    cutDirsFilled = true;

    if (binCuts.empty()) {
      LOG(fatal) << "No AddMLCut entry found in pT bin " << key;
      return {};
    }
    // The cut table is flattened row by row later on, so every bin needs the same number of cuts.
    if (!cutsMl.empty() && binCuts.size() != cutsMl.front().size()) {
      LOG(fatal) << "Inconsistent number of ML cuts in pT bin " << key;
      return {};
    }
    cutsMl.push_back(binCuts);
  }

  // bin edges: every lower edge plus the upper edge of the last bin, sorted and de-duplicated
  std::vector<double> binsPt;
  if (!ptBins.empty()) {
    std::set<double> binEdges;
    for (const auto& b : ptBins) {
      binEdges.insert(b.first);
    }
    binEdges.insert(ptBins.back().second);
    binsPt = std::vector<double>(binEdges.begin(), binEdges.end());
  } else {
    LOG(fatal) << "No pT bins found in ML cuts";
    return {};
  }

  // Row labels use each bin's own lower edge: binsPt is de-duplicated and can therefore hold
  // fewer entries than there are rows, so indexing it by row could read out of bounds.
  std::vector<std::string> labelsPt, labelsClass;
  for (const auto& b : ptBins) {
    labelsPt.push_back(Form("pT%.1f", b.first));
  }
  for (size_t j = 0; j < cutsMl.front().size(); ++j) {
    labelsClass.push_back(Form("cls%d", static_cast<int>(j)));
  }

  // Binary
  if (typeStr == "Binary") {
    dqmlcuts::BinaryBdtScoreConfig binaryCfg;
    binaryCfg.inputFeatures = namesInputFeatures;
    binaryCfg.onnxFiles = onnxFileNames;
    binaryCfg.binsPt = binsPt;
    binaryCfg.cutDirs = cutDirs;
    binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass);

    return binaryCfg;
  }

  // MultiClass
  if (typeStr == "MultiClass") {
    dqmlcuts::MultiClassBdtScoreConfig multiCfg;
    multiCfg.inputFeatures = namesInputFeatures;
    multiCfg.onnxFiles = onnxFileNames;
    multiCfg.binsPt = binsPt;
    multiCfg.cutDirs = cutDirs;
    multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass);

    return multiCfg;
  }

  LOG(fatal) << "Unsupported classification type: " << typeStr;
  return {};
}
7239+
7240+
/// Pack a table of ML score cuts into a labeled 2D array.
///
/// \param cuts        one inner vector per row; the column count is taken from the first row
/// \param labelsPt    row labels
/// \param labelsClass column labels
/// \return LabeledArray built from a row-major copy of \p cuts
///
/// A row whose length differs from the first row is reported through LOG(fatal). The buffer is
/// always allocated as nRows x nCols and filled with bounded copies, so the array construction
/// never reads past the buffer even for ragged input.
o2::framework::LabeledArray<double> o2::aod::dqmlcuts::makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts,
                                                                         const std::vector<std::string>& labelsPt,
                                                                         const std::vector<std::string>& labelsClass)
{
  const size_t nRows = cuts.size();
  const size_t nCols = cuts.empty() ? 0 : cuts[0].size();

  // Rectangular, zero-initialised row-major buffer
  std::vector<double> flat(nRows * nCols, 0.);

  for (size_t iRow = 0; iRow < nRows; ++iRow) {
    const auto& row = cuts[iRow];
    if (row.size() != nCols) {
      LOG(fatal) << "Inconsistent number of ML cuts: row " << iRow << " has " << row.size()
                 << " entries, expected " << nCols;
    }
    const size_t nCopy = (row.size() < nCols) ? row.size() : nCols;
    for (size_t iCol = 0; iCol < nCopy; ++iCol) {
      flat[iRow * nCols + iCol] = row[iCol];
    }
  }

  // NOTE(review): Array2D is presumed to copy from the pointer, since `flat` is local - confirm.
  o2::framework::Array2D<double> arr(flat.data(), nRows, nCols);
  return o2::framework::LabeledArray<double>(arr, labelsPt, labelsClass);
}

PWGDQ/Core/CutsLibrary.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,32 @@ bool ValidateJSONAnalysisCompositeCut(T cut);
119119
template <typename T>
120120
AnalysisCompositeCut* ParseJSONAnalysisCompositeCut(T key, const char* cutName);
121121
} // namespace dqcuts
122+
namespace dqmlcuts
123+
{
124+
struct BinaryBdtScoreConfig {
125+
std::vector<std::string> inputFeatures;
126+
std::vector<std::string> onnxFiles;
127+
std::vector<double> binsPt;
128+
o2::framework::LabeledArray<double> cutsMl;
129+
std::vector<int> cutDirs;
130+
};
131+
132+
struct MultiClassBdtScoreConfig {
133+
std::vector<std::string> inputFeatures;
134+
std::vector<std::string> onnxFiles;
135+
std::vector<double> binsPt;
136+
o2::framework::LabeledArray<double> cutsMl;
137+
std::vector<int> cutDirs;
138+
};
139+
140+
using BdtScoreConfig = std::variant<BinaryBdtScoreConfig, MultiClassBdtScoreConfig>;
141+
142+
BdtScoreConfig GetBdtScoreCutsAndConfigFromJSON(const char* json);
143+
144+
o2::framework::LabeledArray<double> makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts,
145+
const std::vector<std::string>& labelsPt,
146+
const std::vector<std::string>& labelsClass);
147+
} // namespace dqmlcuts
122148
} // namespace o2::aod
123149

124150
AnalysisCompositeCut* o2::aod::dqcuts::GetCompositeCut(const char* cutName);

0 commit comments

Comments
 (0)