Skip to content

Commit 5707bee

Browse files
committed
Address comments
1 parent 03a3844 commit 5707bee

File tree

7 files changed

+581
-339
lines changed

7 files changed

+581
-339
lines changed

PWGDQ/Core/CutsLibrary.cxx

Lines changed: 136 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7109,140 +7109,202 @@ AnalysisCompositeCut* o2::aod::dqcuts::ParseJSONAnalysisCompositeCut(T cut, cons
71097109
//________________________________________________________________________________________________
71107110
o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(const char* json)
71117111
{
7112-
LOG(info) << "========================================== interpreting JSON for analysis cuts";
7112+
LOG(info) << "========================================== interpreting JSON for ML analysis cuts";
7113+
if (!json) {
7114+
LOG(fatal) << "JSON config string is null!";
7115+
return {};
7116+
}
71137117
LOG(info) << "JSON string: " << json;
71147118

71157119
rapidjson::Document document;
7120+
7121+
// Check that the json is parsed correctly
71167122
rapidjson::ParseResult ok = document.Parse(json);
71177123
if (!ok) {
71187124
LOG(fatal) << "JSON parse error: " << rapidjson::GetParseErrorFunc(ok.Code()) << " (" << ok.Offset() << ")";
7119-
return {}; // empty variant
7125+
return {};
71207126
}
71217127

71227128
for (auto it = document.MemberBegin(); it != document.MemberEnd(); ++it) {
71237129
const auto& obj = it->value;
71247130

7131+
// Classification type
71257132
if (!obj.HasMember("type")) {
71267133
LOG(fatal) << "Missing type (Binary/MultiClass)";
71277134
return {};
71287135
}
7129-
71307136
TString typeStr = obj["type"].GetString();
7131-
// int nClasses = (typeStr == "MultiClass") ? 3 : 1;
7137+
if (typeStr != "Binary" && typeStr != "MultiClass") {
7138+
LOG(fatal) << "Unsupported classification type: " << typeStr;
7139+
return {};
7140+
}
71327141

7142+
// Input features
7143+
if (!obj.HasMember("inputFeatures") || !obj["inputFeatures"].IsArray()) {
7144+
LOG(fatal) << "Missing inputFeatures member or array";
7145+
return {};
7146+
}
71337147
std::vector<std::string> namesInputFeatures;
7134-
if (obj.HasMember("inputFeatures") && obj["inputFeatures"].IsArray()) {
7135-
for (const auto& feature : obj["inputFeatures"].GetArray()) {
7136-
namesInputFeatures.emplace_back(feature.GetString());
7137-
}
7148+
for (const auto& feature : obj["inputFeatures"].GetArray()) {
7149+
namesInputFeatures.emplace_back(feature.GetString());
7150+
LOG(debug) << "Input features: " << feature.GetString();
71387151
}
71397152

7153+
// Model files
7154+
if (!obj.HasMember("modelFiles") || !obj["modelFiles"].IsArray()) {
7155+
LOG(fatal) << "Missing modelFiles member or array";
7156+
return {};
7157+
}
71407158
std::vector<std::string> onnxFileNames;
7141-
if (obj.HasMember("modelFiles") && obj["modelFiles"].IsArray()) {
7142-
for (const auto& model : obj["modelFiles"].GetArray()) {
7143-
onnxFileNames.emplace_back(model.GetString());
7144-
}
7159+
for (const auto& model : obj["modelFiles"].GetArray()) {
7160+
onnxFileNames.emplace_back(model.GetString());
7161+
LOG(debug) << "Model Files: " << model.GetString() << " ";
7162+
}
7163+
7164+
// Centrality estimation type
7165+
if (!obj.HasMember("cent") || !obj["cent"].IsString()) {
7166+
LOG(fatal) << "Missing cent member";
7167+
return {};
7168+
}
7169+
std::string cent = obj["cent"].GetString();
7170+
LOG(debug) << "Centrality type: " << cent;
7171+
if (cent != "kCentFT0C" && cent != "kCentFT0A" && cent != "kCentFT0M") {
7172+
LOG(fatal) << "Unsupported centrality type: " << cent;
7173+
return {};
71457174
}
71467175

71477176
// Cut storage
7177+
std::vector<std::pair<double, double>> centBins;
71487178
std::vector<std::pair<double, double>> ptBins;
71497179
std::vector<std::vector<double>> cutsMl;
71507180
std::vector<int> cutDirs;
7181+
std::vector<std::string> labelsFlatBin;
71517182
bool cutDirsFilled = false;
71527183

7153-
for (auto member = obj.MemberBegin(); member != obj.MemberEnd(); ++member) {
7154-
TString key = member->name.GetString();
7155-
if (!key.Contains("AddCut"))
7184+
for (auto centMember = obj.MemberBegin(); centMember != obj.MemberEnd(); ++centMember) {
7185+
TString centKey = centMember->name.GetString();
7186+
if (!centKey.Contains("AddCentCut"))
71567187
continue;
71577188

7158-
const auto& cut = member->value;
7189+
const auto& centCut = centMember->value;
71597190

7160-
if (!cut.HasMember("pTMin") || !cut.HasMember("pTMax")) {
7161-
LOG(fatal) << "Missing pTMin/pTMax in ML cut";
7191+
// Centrality info
7192+
if (!centCut.HasMember("centMin") || !centCut.HasMember("centMax")) {
7193+
LOG(fatal) << "Missing centMin/centMax in " << centKey;
71627194
return {};
71637195
}
7196+
double centMin = centCut["centMin"].GetDouble();
7197+
double centMax = centCut["centMax"].GetDouble();
71647198

7165-
double pTMin = cut["pTMin"].GetDouble();
7166-
double pTMax = cut["pTMax"].GetDouble();
7167-
ptBins.emplace_back(pTMin, pTMax);
7199+
for (auto ptMember = centCut.MemberBegin(); ptMember != centCut.MemberEnd(); ++ptMember) {
7200+
TString ptKey = ptMember->name.GetString();
7201+
if (!ptKey.Contains("AddPtCut"))
7202+
continue;
71687203

7169-
std::vector<double> binCuts;
7170-
bool exclude = false;
7204+
const auto& ptCut = ptMember->value;
71717205

7172-
for (const auto& sub : cut.GetObject()) {
7173-
TString subKey = sub.name.GetString();
7174-
if (!subKey.Contains("AddMLCut"))
7175-
continue;
7206+
// Pt info
7207+
if (!ptCut.HasMember("pTMin") || !ptCut.HasMember("pTMax")) {
7208+
LOG(fatal) << "Missing pTMin/pTMax in " << ptKey;
7209+
return {};
7210+
}
7211+
7212+
double ptMin = ptCut["pTMin"].GetDouble();
7213+
double ptMax = ptCut["pTMax"].GetDouble();
7214+
7215+
std::vector<double> binCuts;
7216+
bool exclude = false;
7217+
7218+
for (auto mlMember = ptCut.MemberBegin(); mlMember != ptCut.MemberEnd(); ++mlMember) {
7219+
TString mlKey = mlMember->name.GetString();
7220+
if (!mlKey.Contains("AddMLCut"))
7221+
continue;
71767222

7177-
const auto& mlcut = sub.value;
7178-
// const char* var = mlcut["var"].GetString();
7179-
double cutVal = mlcut.HasMember("cut") ? mlcut["cut"].GetDouble() : 0.5;
7180-
exclude = mlcut.HasMember("exclude") ? mlcut["exclude"].GetBool() : false;
7223+
const auto& mlcut = mlMember->value;
71817224

7182-
binCuts.push_back(cutVal);
7225+
if (!mlcut.HasMember("cut")) {
7226+
LOG(fatal) << "Missing cut (score) in " << mlKey;
7227+
return {};
7228+
}
7229+
7230+
double cutVal = mlcut["cut"].GetDouble();
7231+
exclude = mlcut.HasMember("exclude") ? mlcut["exclude"].GetBool() : false;
7232+
7233+
binCuts.push_back(cutVal);
7234+
7235+
if (!cutDirsFilled) {
7236+
cutDirs.push_back(exclude ? 0 : 1);
7237+
}
7238+
}
71837239

71847240
if (!cutDirsFilled) {
7185-
cutDirs.push_back(exclude ? 1 : 0);
71867241
cutDirsFilled = true;
71877242
}
7188-
}
71897243

7190-
cutsMl.push_back(binCuts);
7244+
centBins.emplace_back(centMin, centMax);
7245+
ptBins.emplace_back(ptMin, ptMax);
7246+
cutsMl.push_back(binCuts);
7247+
labelsFlatBin.push_back(Form("%s_cent%.0f_%.0f_pt%.1f_%.1f", cent.c_str(), centMin, centMax, ptMin, ptMax));
7248+
LOG(info) << "Added cut for " << Form("%s_cent%.0f_%.0f_pt%.1f_%.1f", cent.c_str(), centMin, centMax, ptMin, ptMax) << " with cuts: [";
7249+
for (size_t i = 0; i < binCuts.size(); ++i) {
7250+
std::cout << binCuts[i];
7251+
if (i != binCuts.size() - 1)
7252+
std::cout << ", ";
7253+
}
7254+
std::cout << "] and direction: " << (exclude ? "CutGreater" : "CutSmaller") << std::endl;
7255+
}
71917256
}
71927257

7193-
// bin edges
7194-
std::vector<double> binsPt;
7195-
if (!ptBins.empty()) {
7196-
std::set<double> binEdges;
7197-
for (const auto& b : ptBins)
7198-
binEdges.insert(b.first);
7199-
binEdges.insert(ptBins.back().second);
7200-
binsPt = std::vector<double>(binEdges.begin(), binEdges.end());
7201-
} else {
7202-
LOG(fatal) << "No pT bins found in ML cuts";
7258+
if (cutDirs.size() != cutsMl[0].size()) {
7259+
LOG(fatal) << "Mismatch the cut size and direction size: cutsMl[0].size() = " << cutsMl[0].size()
7260+
<< ", cutsMl[0].size() = " << cutDirs.size();
72037261
return {};
72047262
}
72057263

7206-
std::vector<std::string> labelsPt, labelsClass;
7207-
for (size_t i = 0; i < cutsMl.size(); ++i) {
7208-
labelsPt.push_back(Form("pT%.1f", binsPt[i]));
7209-
}
7264+
std::vector<std::string> labelsClass;
72107265
for (size_t j = 0; j < cutsMl[0].size(); ++j) {
7211-
labelsClass.push_back(Form("cls%d", static_cast<int>(j)));
7266+
labelsClass.push_back(Form("score class %d", static_cast<int>(j)));
72127267
}
72137268

7269+
size_t nFlatBins = cutsMl.size();
7270+
std::vector<double> binsMl(nFlatBins + 1);
7271+
std::iota(binsMl.begin(), binsMl.end(), 0);
7272+
72147273
// Binary
72157274
if (typeStr == "Binary") {
72167275
dqmlcuts::BinaryBdtScoreConfig binaryCfg;
72177276
binaryCfg.inputFeatures = namesInputFeatures;
72187277
binaryCfg.onnxFiles = onnxFileNames;
7219-
binaryCfg.binsPt = binsPt;
7278+
binaryCfg.binsCent = centBins;
7279+
binaryCfg.binsPt = ptBins;
7280+
binaryCfg.binsMl = binsMl;
72207281
binaryCfg.cutDirs = cutDirs;
7221-
binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass);
7282+
binaryCfg.centType = cent;
7283+
binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsFlatBin, labelsClass);
72227284

72237285
return binaryCfg;
72247286

7225-
// MultiClass
7287+
// MultiClass
72267288
} else if (typeStr == "MultiClass") {
72277289
dqmlcuts::MultiClassBdtScoreConfig multiCfg;
72287290
multiCfg.inputFeatures = namesInputFeatures;
72297291
multiCfg.onnxFiles = onnxFileNames;
7230-
multiCfg.binsPt = binsPt;
7292+
multiCfg.binsCent = centBins;
7293+
multiCfg.binsPt = ptBins;
7294+
multiCfg.binsMl = binsMl;
72317295
multiCfg.cutDirs = cutDirs;
7232-
multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass);
7296+
multiCfg.centType = cent;
7297+
multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsFlatBin, labelsClass);
72337298

72347299
return multiCfg;
72357300
}
7236-
7237-
LOG(fatal) << "Unsupported classification type: " << typeStr;
7238-
return {};
72397301
}
72407302

72417303
return {};
72427304
}
72437305

72447306
o2::framework::LabeledArray<double> o2::aod::dqmlcuts::makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts,
7245-
const std::vector<std::string>& labelsPt,
7307+
const std::vector<std::string>& labelsflatBin,
72467308
const std::vector<std::string>& labelsClass)
72477309
{
72487310
const size_t nRows = cuts.size();
@@ -7254,5 +7316,19 @@ o2::framework::LabeledArray<double> o2::aod::dqmlcuts::makeLabeledCutsMl(const s
72547316
}
72557317

72567318
o2::framework::Array2D<double> arr(flat.data(), nRows, nCols);
7257-
return o2::framework::LabeledArray<double>(arr, labelsPt, labelsClass);
7319+
return o2::framework::LabeledArray<double>(arr, labelsflatBin, labelsClass);
72587320
}
7321+
7322+
int o2::aod::dqmlcuts::getMlBinIndex(double cent, double pt,
7323+
const std::vector<std::pair<double, double>>& binsCent,
7324+
const std::vector<std::pair<double, double>>& binsPt)
7325+
{
7326+
LOG(debug) << "Searching for Ml bin index for cent: " << cent << ", pt: " << pt; //here
7327+
for (size_t i = 0; i < binsCent.size(); ++i) {
7328+
if (cent >= binsCent[i].first && cent < binsCent[i].second && pt >= binsPt[i].first && pt < binsPt[i].second) {
7329+
LOG(debug) << " - Found at index: " << i; //here
7330+
return static_cast<int>(i);
7331+
}
7332+
}
7333+
return -1; // not found
7334+
}

PWGDQ/Core/CutsLibrary.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,21 @@ namespace dqmlcuts
124124
struct BinaryBdtScoreConfig {
125125
std::vector<std::string> inputFeatures;
126126
std::vector<std::string> onnxFiles;
127-
std::vector<double> binsPt;
128-
o2::framework::LabeledArray<double> cutsMl;
129-
std::vector<int> cutDirs;
127+
std::vector<std::pair<double, double>> binsCent; // bins for centrality
128+
std::vector<std::pair<double, double>> binsPt; // bins for pT
129+
std::vector<double> binsMl; // bins for flattened binning
130+
std::string centType;
131+
o2::framework::LabeledArray<double> cutsMl; // BDT score cuts for each bin
132+
std::vector<int> cutDirs; // direction of the cuts on the BDT score
130133
};
131134

132135
struct MultiClassBdtScoreConfig {
133136
std::vector<std::string> inputFeatures;
134137
std::vector<std::string> onnxFiles;
135-
std::vector<double> binsPt;
138+
std::vector<std::pair<double, double>> binsCent;
139+
std::vector<std::pair<double, double>> binsPt;
140+
std::vector<double> binsMl;
141+
std::string centType;
136142
o2::framework::LabeledArray<double> cutsMl;
137143
std::vector<int> cutDirs;
138144
};
@@ -144,6 +150,9 @@ BdtScoreConfig GetBdtScoreCutsAndConfigFromJSON(const char* json);
144150
o2::framework::LabeledArray<double> makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts,
145151
const std::vector<std::string>& labelsPt,
146152
const std::vector<std::string>& labelsClass);
153+
int getMlBinIndex(double cent, double pt,
154+
const std::vector<std::pair<double, double>>& binsCent,
155+
const std::vector<std::pair<double, double>>& binsPt);
147156
} // namespace dqmlcuts
148157
} // namespace o2::aod
149158

0 commit comments

Comments
 (0)