|
12 | 12 | // Contact: iarsene@cern.ch, i.c.arsene@fys.uio.no |
13 | 13 | // |
14 | 14 | #include "PWGDQ/Core/CutsLibrary.h" |
15 | | -#include <RtypesCore.h> |
16 | | -#include <TF1.h> |
17 | | -#include <vector> |
18 | | -#include <string> |
19 | | -#include <iostream> |
| 15 | + |
20 | 16 | #include "AnalysisCompositeCut.h" |
21 | 17 | #include "VarManager.h" |
22 | 18 |
|
| 19 | +#include <TF1.h> |
| 20 | + |
| 21 | +#include <RtypesCore.h> |
| 22 | + |
| 23 | +#include <iostream> |
| 24 | +#include <set> |
| 25 | +#include <string> |
| 26 | +#include <vector> |
| 27 | + |
23 | 28 | using std::cout; |
24 | 29 | using std::endl; |
25 | 30 |
|
@@ -7100,3 +7105,230 @@ AnalysisCompositeCut* o2::aod::dqcuts::ParseJSONAnalysisCompositeCut(T cut, cons |
7100 | 7105 |
|
7101 | 7106 | return retCut; |
7102 | 7107 | } |
| 7108 | + |
| 7109 | +//________________________________________________________________________________________________ |
| 7110 | +o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(const char* json) |
| 7111 | +{ |
| 7112 | + LOG(info) << "========================================== interpreting JSON for ML analysis cuts"; |
| 7113 | + if (!json) { |
| 7114 | + LOG(fatal) << "JSON config string is null!"; |
| 7115 | + return {}; |
| 7116 | + } |
| 7117 | + LOG(info) << "JSON string: " << json; |
| 7118 | + |
| 7119 | + rapidjson::Document document; |
| 7120 | + |
| 7121 | + // Check that the json is parsed correctly |
| 7122 | + rapidjson::ParseResult ok = document.Parse(json); |
| 7123 | + if (!ok) { |
| 7124 | + LOG(fatal) << "JSON parse error: " << rapidjson::GetParseErrorFunc(ok.Code()) << " (" << ok.Offset() << ")"; |
| 7125 | + return {}; |
| 7126 | + } |
| 7127 | + |
| 7128 | + for (auto it = document.MemberBegin(); it != document.MemberEnd(); ++it) { |
| 7129 | + const auto& obj = it->value; |
| 7130 | + |
| 7131 | + // Classification type |
| 7132 | + if (!obj.HasMember("type")) { |
| 7133 | + LOG(fatal) << "Missing type (Binary/MultiClass)"; |
| 7134 | + return {}; |
| 7135 | + } |
| 7136 | + TString typeStr = obj["type"].GetString(); |
| 7137 | + if (typeStr != "Binary" && typeStr != "MultiClass") { |
| 7138 | + LOG(fatal) << "Unsupported classification type: " << typeStr; |
| 7139 | + return {}; |
| 7140 | + } |
| 7141 | + |
| 7142 | + // Input features |
| 7143 | + if (!obj.HasMember("inputFeatures") || !obj["inputFeatures"].IsArray()) { |
| 7144 | + LOG(fatal) << "Missing inputFeatures member or array"; |
| 7145 | + return {}; |
| 7146 | + } |
| 7147 | + std::vector<std::string> namesInputFeatures; |
| 7148 | + for (const auto& feature : obj["inputFeatures"].GetArray()) { |
| 7149 | + namesInputFeatures.emplace_back(feature.GetString()); |
| 7150 | + LOG(debug) << "Input features: " << feature.GetString(); |
| 7151 | + } |
| 7152 | + |
| 7153 | + // Model files |
| 7154 | + if (!obj.HasMember("modelFiles") || !obj["modelFiles"].IsArray()) { |
| 7155 | + LOG(fatal) << "Missing modelFiles member or array"; |
| 7156 | + return {}; |
| 7157 | + } |
| 7158 | + std::vector<std::string> onnxFileNames; |
| 7159 | + for (const auto& model : obj["modelFiles"].GetArray()) { |
| 7160 | + onnxFileNames.emplace_back(model.GetString()); |
| 7161 | + LOG(debug) << "Model Files: " << model.GetString() << " "; |
| 7162 | + } |
| 7163 | + |
| 7164 | + // Centrality estimation type |
| 7165 | + if (!obj.HasMember("cent") || !obj["cent"].IsString()) { |
| 7166 | + LOG(fatal) << "Missing cent member"; |
| 7167 | + return {}; |
| 7168 | + } |
| 7169 | + std::string cent = obj["cent"].GetString(); |
| 7170 | + LOG(debug) << "Centrality type: " << cent; |
| 7171 | + if (cent != "kCentFT0C" && cent != "kCentFT0A" && cent != "kCentFT0M") { |
| 7172 | + LOG(fatal) << "Unsupported centrality type: " << cent; |
| 7173 | + return {}; |
| 7174 | + } |
| 7175 | + |
| 7176 | + // Cut storage |
| 7177 | + std::vector<std::pair<double, double>> centBins; |
| 7178 | + std::vector<std::pair<double, double>> ptBins; |
| 7179 | + std::vector<std::vector<double>> cutsMl; |
| 7180 | + std::vector<int> cutDirs; |
| 7181 | + std::vector<std::string> labelsFlatBin; |
| 7182 | + bool cutDirsFilled = false; |
| 7183 | + |
| 7184 | + for (auto centMember = obj.MemberBegin(); centMember != obj.MemberEnd(); ++centMember) { |
| 7185 | + TString centKey = centMember->name.GetString(); |
| 7186 | + if (!centKey.Contains("AddCentCut")) |
| 7187 | + continue; |
| 7188 | + |
| 7189 | + const auto& centCut = centMember->value; |
| 7190 | + |
| 7191 | + // Centrality info |
| 7192 | + if (!centCut.HasMember("centMin") || !centCut.HasMember("centMax")) { |
| 7193 | + LOG(fatal) << "Missing centMin/centMax in " << centKey; |
| 7194 | + return {}; |
| 7195 | + } |
| 7196 | + double centMin = centCut["centMin"].GetDouble(); |
| 7197 | + double centMax = centCut["centMax"].GetDouble(); |
| 7198 | + |
| 7199 | + for (auto ptMember = centCut.MemberBegin(); ptMember != centCut.MemberEnd(); ++ptMember) { |
| 7200 | + TString ptKey = ptMember->name.GetString(); |
| 7201 | + if (!ptKey.Contains("AddPtCut")) |
| 7202 | + continue; |
| 7203 | + |
| 7204 | + const auto& ptCut = ptMember->value; |
| 7205 | + |
| 7206 | + // Pt info |
| 7207 | + if (!ptCut.HasMember("pTMin") || !ptCut.HasMember("pTMax")) { |
| 7208 | + LOG(fatal) << "Missing pTMin/pTMax in " << ptKey; |
| 7209 | + return {}; |
| 7210 | + } |
| 7211 | + |
| 7212 | + double ptMin = ptCut["pTMin"].GetDouble(); |
| 7213 | + double ptMax = ptCut["pTMax"].GetDouble(); |
| 7214 | + |
| 7215 | + std::vector<double> binCuts; |
| 7216 | + bool exclude = false; |
| 7217 | + |
| 7218 | + for (auto mlMember = ptCut.MemberBegin(); mlMember != ptCut.MemberEnd(); ++mlMember) { |
| 7219 | + TString mlKey = mlMember->name.GetString(); |
| 7220 | + if (!mlKey.Contains("AddMLCut")) |
| 7221 | + continue; |
| 7222 | + |
| 7223 | + const auto& mlcut = mlMember->value; |
| 7224 | + |
| 7225 | + if (!mlcut.HasMember("cut")) { |
| 7226 | + LOG(fatal) << "Missing cut (score) in " << mlKey; |
| 7227 | + return {}; |
| 7228 | + } |
| 7229 | + |
| 7230 | + double cutVal = mlcut["cut"].GetDouble(); |
| 7231 | + exclude = mlcut.HasMember("exclude") ? mlcut["exclude"].GetBool() : false; |
| 7232 | + |
| 7233 | + binCuts.push_back(cutVal); |
| 7234 | + |
| 7235 | + if (!cutDirsFilled) { |
| 7236 | + cutDirs.push_back(exclude ? 0 : 1); |
| 7237 | + } |
| 7238 | + } |
| 7239 | + |
| 7240 | + if (!cutDirsFilled) { |
| 7241 | + cutDirsFilled = true; |
| 7242 | + } |
| 7243 | + |
| 7244 | + centBins.emplace_back(centMin, centMax); |
| 7245 | + ptBins.emplace_back(ptMin, ptMax); |
| 7246 | + cutsMl.push_back(binCuts); |
| 7247 | + labelsFlatBin.push_back(Form("%s_cent%.0f_%.0f_pt%.1f_%.1f", cent.c_str(), centMin, centMax, ptMin, ptMax)); |
| 7248 | + LOG(info) << "Added cut for " << Form("%s_cent%.0f_%.0f_pt%.1f_%.1f", cent.c_str(), centMin, centMax, ptMin, ptMax) << " with cuts: ["; |
| 7249 | + for (size_t i = 0; i < binCuts.size(); ++i) { |
| 7250 | + std::cout << binCuts[i]; |
| 7251 | + if (i != binCuts.size() - 1) |
| 7252 | + std::cout << ", "; |
| 7253 | + } |
| 7254 | + std::cout << "] and direction: " << (exclude ? "CutGreater" : "CutSmaller") << std::endl; |
| 7255 | + } |
| 7256 | + } |
| 7257 | + |
| 7258 | + if (cutDirs.size() != cutsMl[0].size()) { |
| 7259 | + LOG(fatal) << "Mismatch the cut size and direction size: cutsMl[0].size() = " << cutsMl[0].size() |
| 7260 | + << ", cutsMl[0].size() = " << cutDirs.size(); |
| 7261 | + return {}; |
| 7262 | + } |
| 7263 | + |
| 7264 | + std::vector<std::string> labelsClass; |
| 7265 | + for (size_t j = 0; j < cutsMl[0].size(); ++j) { |
| 7266 | + labelsClass.push_back(Form("score class %d", static_cast<int>(j))); |
| 7267 | + } |
| 7268 | + |
| 7269 | + size_t nFlatBins = cutsMl.size(); |
| 7270 | + std::vector<double> binsMl(nFlatBins + 1); |
| 7271 | + std::iota(binsMl.begin(), binsMl.end(), 0); |
| 7272 | + |
| 7273 | + // Binary |
| 7274 | + if (typeStr == "Binary") { |
| 7275 | + dqmlcuts::BinaryBdtScoreConfig binaryCfg; |
| 7276 | + binaryCfg.inputFeatures = namesInputFeatures; |
| 7277 | + binaryCfg.onnxFiles = onnxFileNames; |
| 7278 | + binaryCfg.binsCent = centBins; |
| 7279 | + binaryCfg.binsPt = ptBins; |
| 7280 | + binaryCfg.binsMl = binsMl; |
| 7281 | + binaryCfg.cutDirs = cutDirs; |
| 7282 | + binaryCfg.centType = cent; |
| 7283 | + binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsFlatBin, labelsClass); |
| 7284 | + |
| 7285 | + return binaryCfg; |
| 7286 | + |
| 7287 | + // MultiClass |
| 7288 | + } else if (typeStr == "MultiClass") { |
| 7289 | + dqmlcuts::MultiClassBdtScoreConfig multiCfg; |
| 7290 | + multiCfg.inputFeatures = namesInputFeatures; |
| 7291 | + multiCfg.onnxFiles = onnxFileNames; |
| 7292 | + multiCfg.binsCent = centBins; |
| 7293 | + multiCfg.binsPt = ptBins; |
| 7294 | + multiCfg.binsMl = binsMl; |
| 7295 | + multiCfg.cutDirs = cutDirs; |
| 7296 | + multiCfg.centType = cent; |
| 7297 | + multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsFlatBin, labelsClass); |
| 7298 | + |
| 7299 | + return multiCfg; |
| 7300 | + } |
| 7301 | + } |
| 7302 | + |
| 7303 | + return {}; |
| 7304 | +} |
| 7305 | + |
| 7306 | +o2::framework::LabeledArray<double> o2::aod::dqmlcuts::makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts, |
| 7307 | + const std::vector<std::string>& labelsflatBin, |
| 7308 | + const std::vector<std::string>& labelsClass) |
| 7309 | +{ |
| 7310 | + const size_t nRows = cuts.size(); |
| 7311 | + const size_t nCols = cuts.empty() ? 0 : cuts[0].size(); |
| 7312 | + std::vector<double> flat; |
| 7313 | + |
| 7314 | + for (const auto& row : cuts) { |
| 7315 | + flat.insert(flat.end(), row.begin(), row.end()); |
| 7316 | + } |
| 7317 | + |
| 7318 | + o2::framework::Array2D<double> arr(flat.data(), nRows, nCols); |
| 7319 | + return o2::framework::LabeledArray<double>(arr, labelsflatBin, labelsClass); |
| 7320 | +} |
| 7321 | + |
| 7322 | +int o2::aod::dqmlcuts::getMlBinIndex(double cent, double pt, |
| 7323 | + const std::vector<std::pair<double, double>>& binsCent, |
| 7324 | + const std::vector<std::pair<double, double>>& binsPt) |
| 7325 | +{ |
| 7326 | + LOG(debug) << "Searching for Ml bin index for cent: " << cent << ", pt: " << pt; |
| 7327 | + for (size_t i = 0; i < binsCent.size(); ++i) { |
| 7328 | + if (cent >= binsCent[i].first && cent < binsCent[i].second && pt >= binsPt[i].first && pt < binsPt[i].second) { |
| 7329 | + LOG(debug) << " - Found at index: " << i; |
| 7330 | + return static_cast<int>(i); |
| 7331 | + } |
| 7332 | + } |
| 7333 | + return -1; // not found |
| 7334 | +} |
0 commit comments