|
17 | 17 | #include <vector> |
18 | 18 | #include <string> |
19 | 19 | #include <iostream> |
| 20 | +#include <set> |
20 | 21 | #include "AnalysisCompositeCut.h" |
21 | 22 | #include "VarManager.h" |
22 | 23 |
|
@@ -7100,3 +7101,154 @@ AnalysisCompositeCut* o2::aod::dqcuts::ParseJSONAnalysisCompositeCut(T cut, cons |
7100 | 7101 |
|
7101 | 7102 | return retCut; |
7102 | 7103 | } |
| 7104 | + |
| 7105 | +//________________________________________________________________________________________________ |
| 7106 | +o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(const char* json) |
| 7107 | +{ |
| 7108 | + LOG(info) << "========================================== interpreting JSON for analysis cuts"; |
| 7109 | + LOG(info) << "JSON string: " << json; |
| 7110 | + |
| 7111 | + rapidjson::Document document; |
| 7112 | + rapidjson::ParseResult ok = document.Parse(json); |
| 7113 | + if (!ok) { |
| 7114 | + LOG(fatal) << "JSON parse error: " << rapidjson::GetParseErrorFunc(ok.Code()) << " (" << ok.Offset() << ")"; |
| 7115 | + return {}; // empty variant |
| 7116 | + } |
| 7117 | + |
| 7118 | + for (auto it = document.MemberBegin(); it != document.MemberEnd(); ++it) { |
| 7119 | + const auto& obj = it->value; |
| 7120 | + |
| 7121 | + if (!obj.HasMember("type")) { |
| 7122 | + LOG(fatal) << "Missing type (Binary/MultiClass)"; |
| 7123 | + return {}; |
| 7124 | + } |
| 7125 | + |
| 7126 | + TString typeStr = obj["type"].GetString(); |
| 7127 | + // int nClasses = (typeStr == "MultiClass") ? 3 : 1; |
| 7128 | + |
| 7129 | + std::vector<std::string> namesInputFeatures; |
| 7130 | + if (obj.HasMember("inputFeatures") && obj["inputFeatures"].IsArray()) { |
| 7131 | + for (auto& feature : obj["inputFeatures"].GetArray()) { |
| 7132 | + namesInputFeatures.emplace_back(feature.GetString()); |
| 7133 | + } |
| 7134 | + } |
| 7135 | + |
| 7136 | + std::vector<std::string> onnxFileNames; |
| 7137 | + if (obj.HasMember("modelFiles") && obj["modelFiles"].IsArray()) { |
| 7138 | + for (const auto& model : obj["modelFiles"].GetArray()) { |
| 7139 | + onnxFileNames.emplace_back(model.GetString()); |
| 7140 | + } |
| 7141 | + } |
| 7142 | + |
| 7143 | + // Cut storage |
| 7144 | + std::vector<std::pair<double, double>> ptBins; |
| 7145 | + std::vector<std::vector<double>> cutsMl; |
| 7146 | + std::vector<int> cutDirs; |
| 7147 | + bool cutDirsFilled = false; |
| 7148 | + |
| 7149 | + for (auto member = obj.MemberBegin(); member != obj.MemberEnd(); ++member) { |
| 7150 | + TString key = member->name.GetString(); |
| 7151 | + if (!key.Contains("AddCut")) |
| 7152 | + continue; |
| 7153 | + |
| 7154 | + const auto& cut = member->value; |
| 7155 | + |
| 7156 | + if (!cut.HasMember("pTMin") || !cut.HasMember("pTMax")) { |
| 7157 | + LOG(fatal) << "Missing pTMin/pTMax in ML cut"; |
| 7158 | + return {}; |
| 7159 | + } |
| 7160 | + |
| 7161 | + double pTMin = cut["pTMin"].GetDouble(); |
| 7162 | + double pTMax = cut["pTMax"].GetDouble(); |
| 7163 | + ptBins.emplace_back(pTMin, pTMax); |
| 7164 | + |
| 7165 | + std::vector<double> binCuts; |
| 7166 | + bool exclude = false; |
| 7167 | + |
| 7168 | + for (auto& sub : cut.GetObject()) { |
| 7169 | + TString subKey = sub.name.GetString(); |
| 7170 | + if (!subKey.Contains("AddMLCut")) |
| 7171 | + continue; |
| 7172 | + |
| 7173 | + const auto& mlcut = sub.value; |
| 7174 | + // const char* var = mlcut["var"].GetString(); |
| 7175 | + double cutVal = mlcut.HasMember("cut") ? mlcut["cut"].GetDouble() : 0.5; |
| 7176 | + exclude = mlcut.HasMember("exclude") ? mlcut["exclude"].GetBool() : false; |
| 7177 | + |
| 7178 | + binCuts.push_back(cutVal); |
| 7179 | + |
| 7180 | + if (!cutDirsFilled) { |
| 7181 | + cutDirs.push_back(exclude ? 1 : 0); |
| 7182 | + cutDirsFilled = true; |
| 7183 | + } |
| 7184 | + } |
| 7185 | + |
| 7186 | + cutsMl.push_back(binCuts); |
| 7187 | + } |
| 7188 | + |
| 7189 | + // bin edges |
| 7190 | + std::vector<double> binsPt; |
| 7191 | + if (!ptBins.empty()) { |
| 7192 | + std::set<double> binEdges; |
| 7193 | + for (auto& b : ptBins) |
| 7194 | + binEdges.insert(b.first); |
| 7195 | + binEdges.insert(ptBins.back().second); |
| 7196 | + binsPt = std::vector<double>(binEdges.begin(), binEdges.end()); |
| 7197 | + } else { |
| 7198 | + LOG(fatal) << "No pT bins found in ML cuts"; |
| 7199 | + return {}; |
| 7200 | + } |
| 7201 | + |
| 7202 | + std::vector<std::string> labelsPt, labelsClass; |
| 7203 | + for (size_t i = 0; i < cutsMl.size(); ++i) { |
| 7204 | + labelsPt.push_back(Form("pT%.1f", binsPt[i])); |
| 7205 | + } |
| 7206 | + for (size_t j = 0; j < cutsMl[0].size(); ++j) { |
| 7207 | + labelsClass.push_back(Form("cls%d", static_cast<int>(j))); |
| 7208 | + } |
| 7209 | + |
| 7210 | + // Binary |
| 7211 | + if (typeStr == "Binary") { |
| 7212 | + dqmlcuts::BinaryBdtScoreConfig binaryCfg; |
| 7213 | + binaryCfg.inputFeatures = namesInputFeatures; |
| 7214 | + binaryCfg.onnxFiles = onnxFileNames; |
| 7215 | + binaryCfg.binsPt = binsPt; |
| 7216 | + binaryCfg.cutDirs = cutDirs; |
| 7217 | + binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass); |
| 7218 | + |
| 7219 | + return binaryCfg; |
| 7220 | + |
| 7221 | + // MultiClass |
| 7222 | + } else if (typeStr == "MultiClass") { |
| 7223 | + dqmlcuts::MultiClassBdtScoreConfig multiCfg; |
| 7224 | + multiCfg.inputFeatures = namesInputFeatures; |
| 7225 | + multiCfg.onnxFiles = onnxFileNames; |
| 7226 | + multiCfg.binsPt = binsPt; |
| 7227 | + multiCfg.cutDirs = cutDirs; |
| 7228 | + multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass); |
| 7229 | + |
| 7230 | + return multiCfg; |
| 7231 | + } |
| 7232 | + |
| 7233 | + LOG(fatal) << "Unsupported classification type: " << typeStr; |
| 7234 | + return {}; |
| 7235 | + } |
| 7236 | + |
| 7237 | + return {}; |
| 7238 | +} |
| 7239 | + |
| 7240 | +o2::framework::LabeledArray<double> o2::aod::dqmlcuts::makeLabeledCutsMl(const std::vector<std::vector<double>>& cuts, |
| 7241 | + const std::vector<std::string>& labelsPt, |
| 7242 | + const std::vector<std::string>& labelsClass) |
| 7243 | +{ |
| 7244 | + const size_t nRows = cuts.size(); |
| 7245 | + const size_t nCols = cuts.empty() ? 0 : cuts[0].size(); |
| 7246 | + std::vector<double> flat; |
| 7247 | + |
| 7248 | + for (const auto& row : cuts) { |
| 7249 | + flat.insert(flat.end(), row.begin(), row.end()); |
| 7250 | + } |
| 7251 | + |
| 7252 | + o2::framework::Array2D<double> arr(flat.data(), nRows, nCols); |
| 7253 | + return o2::framework::LabeledArray<double>(arr, labelsPt, labelsClass); |
| 7254 | +} |
0 commit comments