@@ -7109,140 +7109,202 @@ AnalysisCompositeCut* o2::aod::dqcuts::ParseJSONAnalysisCompositeCut(T cut, cons
71097109// ________________________________________________________________________________________________
71107110o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON (const char * json)
71117111{
7112- LOG (info) << " ========================================== interpreting JSON for analysis cuts" ;
7112+ LOG (info) << " ========================================== interpreting JSON for ML analysis cuts" ;
7113+ if (!json) {
7114+ LOG (fatal) << " JSON config string is null!" ;
7115+ return {};
7116+ }
71137117 LOG (info) << " JSON string: " << json;
71147118
71157119 rapidjson::Document document;
7120+
7121+ // Check that the json is parsed correctly
71167122 rapidjson::ParseResult ok = document.Parse (json);
71177123 if (!ok) {
71187124 LOG (fatal) << " JSON parse error: " << rapidjson::GetParseErrorFunc (ok.Code ()) << " (" << ok.Offset () << " )" ;
7119- return {}; // empty variant
7125+ return {};
71207126 }
71217127
71227128 for (auto it = document.MemberBegin (); it != document.MemberEnd (); ++it) {
71237129 const auto & obj = it->value ;
71247130
7131+ // Classification type
71257132 if (!obj.HasMember (" type" )) {
71267133 LOG (fatal) << " Missing type (Binary/MultiClass)" ;
71277134 return {};
71287135 }
7129-
71307136 TString typeStr = obj[" type" ].GetString ();
7131- // int nClasses = (typeStr == "MultiClass") ? 3 : 1;
7137+ if (typeStr != " Binary" && typeStr != " MultiClass" ) {
7138+ LOG (fatal) << " Unsupported classification type: " << typeStr;
7139+ return {};
7140+ }
71327141
7142+ // Input features
7143+ if (!obj.HasMember (" inputFeatures" ) || !obj[" inputFeatures" ].IsArray ()) {
7144+ LOG (fatal) << " Missing inputFeatures member or array" ;
7145+ return {};
7146+ }
71337147 std::vector<std::string> namesInputFeatures;
7134- if (obj.HasMember (" inputFeatures" ) && obj[" inputFeatures" ].IsArray ()) {
7135- for (const auto & feature : obj[" inputFeatures" ].GetArray ()) {
7136- namesInputFeatures.emplace_back (feature.GetString ());
7137- }
7148+ for (const auto & feature : obj[" inputFeatures" ].GetArray ()) {
7149+ namesInputFeatures.emplace_back (feature.GetString ());
7150+ LOG (debug) << " Input features: " << feature.GetString ();
71387151 }
71397152
7153+ // Model files
7154+ if (!obj.HasMember (" modelFiles" ) || !obj[" modelFiles" ].IsArray ()) {
7155+ LOG (fatal) << " Missing modelFiles member or array" ;
7156+ return {};
7157+ }
71407158 std::vector<std::string> onnxFileNames;
7141- if (obj.HasMember (" modelFiles" ) && obj[" modelFiles" ].IsArray ()) {
7142- for (const auto & model : obj[" modelFiles" ].GetArray ()) {
7143- onnxFileNames.emplace_back (model.GetString ());
7144- }
7159+ for (const auto & model : obj[" modelFiles" ].GetArray ()) {
7160+ onnxFileNames.emplace_back (model.GetString ());
7161+ LOG (debug) << " Model Files: " << model.GetString () << " " ;
7162+ }
7163+
7164+ // Centrality estimation type
7165+ if (!obj.HasMember (" cent" ) || !obj[" cent" ].IsString ()) {
7166+ LOG (fatal) << " Missing cent member" ;
7167+ return {};
7168+ }
7169+ std::string cent = obj[" cent" ].GetString ();
7170+ LOG (debug) << " Centrality type: " << cent;
7171+ if (cent != " kCentFT0C" && cent != " kCentFT0A" && cent != " kCentFT0M" ) {
7172+ LOG (fatal) << " Unsupported centrality type: " << cent;
7173+ return {};
71457174 }
71467175
71477176 // Cut storage
7177+ std::vector<std::pair<double , double >> centBins;
71487178 std::vector<std::pair<double , double >> ptBins;
71497179 std::vector<std::vector<double >> cutsMl;
71507180 std::vector<int > cutDirs;
7181+ std::vector<std::string> labelsFlatBin;
71517182 bool cutDirsFilled = false ;
71527183
7153- for (auto member = obj.MemberBegin (); member != obj.MemberEnd (); ++member ) {
7154- TString key = member ->name .GetString ();
7155- if (!key .Contains (" AddCut " ))
7184+ for (auto centMember = obj.MemberBegin (); centMember != obj.MemberEnd (); ++centMember ) {
7185+ TString centKey = centMember ->name .GetString ();
7186+ if (!centKey .Contains (" AddCentCut " ))
71567187 continue ;
71577188
7158- const auto & cut = member ->value ;
7189+ const auto & centCut = centMember ->value ;
71597190
7160- if (!cut.HasMember (" pTMin" ) || !cut.HasMember (" pTMax" )) {
7161- LOG (fatal) << " Missing pTMin/pTMax in ML cut" ;
7191+ // Centrality info
7192+ if (!centCut.HasMember (" centMin" ) || !centCut.HasMember (" centMax" )) {
7193+ LOG (fatal) << " Missing centMin/centMax in " << centKey;
71627194 return {};
71637195 }
7196+ double centMin = centCut[" centMin" ].GetDouble ();
7197+ double centMax = centCut[" centMax" ].GetDouble ();
71647198
7165- double pTMin = cut[" pTMin" ].GetDouble ();
7166- double pTMax = cut[" pTMax" ].GetDouble ();
7167- ptBins.emplace_back (pTMin, pTMax);
7199+ for (auto ptMember = centCut.MemberBegin (); ptMember != centCut.MemberEnd (); ++ptMember) {
7200+ TString ptKey = ptMember->name .GetString ();
7201+ if (!ptKey.Contains (" AddPtCut" ))
7202+ continue ;
71687203
7169- std::vector<double > binCuts;
7170- bool exclude = false ;
7204+ const auto & ptCut = ptMember->value ;
71717205
7172- for (const auto & sub : cut.GetObject ()) {
7173- TString subKey = sub.name .GetString ();
7174- if (!subKey.Contains (" AddMLCut" ))
7175- continue ;
7206+ // Pt info
7207+ if (!ptCut.HasMember (" pTMin" ) || !ptCut.HasMember (" pTMax" )) {
7208+ LOG (fatal) << " Missing pTMin/pTMax in " << ptKey;
7209+ return {};
7210+ }
7211+
7212+ double ptMin = ptCut[" pTMin" ].GetDouble ();
7213+ double ptMax = ptCut[" pTMax" ].GetDouble ();
7214+
7215+ std::vector<double > binCuts;
7216+ bool exclude = false ;
7217+
7218+ for (auto mlMember = ptCut.MemberBegin (); mlMember != ptCut.MemberEnd (); ++mlMember) {
7219+ TString mlKey = mlMember->name .GetString ();
7220+ if (!mlKey.Contains (" AddMLCut" ))
7221+ continue ;
71767222
7177- const auto & mlcut = sub.value ;
7178- // const char* var = mlcut["var"].GetString();
7179- double cutVal = mlcut.HasMember (" cut" ) ? mlcut[" cut" ].GetDouble () : 0.5 ;
7180- exclude = mlcut.HasMember (" exclude" ) ? mlcut[" exclude" ].GetBool () : false ;
7223+ const auto & mlcut = mlMember->value ;
71817224
7182- binCuts.push_back (cutVal);
7225+ if (!mlcut.HasMember (" cut" )) {
7226+ LOG (fatal) << " Missing cut (score) in " << mlKey;
7227+ return {};
7228+ }
7229+
7230+ double cutVal = mlcut[" cut" ].GetDouble ();
7231+ exclude = mlcut.HasMember (" exclude" ) ? mlcut[" exclude" ].GetBool () : false ;
7232+
7233+ binCuts.push_back (cutVal);
7234+
7235+ if (!cutDirsFilled) {
7236+ cutDirs.push_back (exclude ? 0 : 1 );
7237+ }
7238+ }
71837239
71847240 if (!cutDirsFilled) {
7185- cutDirs.push_back (exclude ? 1 : 0 );
71867241 cutDirsFilled = true ;
71877242 }
7188- }
71897243
7190- cutsMl.push_back (binCuts);
7244+ centBins.emplace_back (centMin, centMax);
7245+ ptBins.emplace_back (ptMin, ptMax);
7246+ cutsMl.push_back (binCuts);
7247+ labelsFlatBin.push_back (Form (" %s_cent%.0f_%.0f_pt%.1f_%.1f" , cent.c_str (), centMin, centMax, ptMin, ptMax));
7248+ LOG (info) << " Added cut for " << Form (" %s_cent%.0f_%.0f_pt%.1f_%.1f" , cent.c_str (), centMin, centMax, ptMin, ptMax) << " with cuts: [" ;
7249+ for (size_t i = 0 ; i < binCuts.size (); ++i) {
7250+ std::cout << binCuts[i];
7251+ if (i != binCuts.size () - 1 )
7252+ std::cout << " , " ;
7253+ }
7254+ std::cout << " ] and direction: " << (exclude ? " CutGreater" : " CutSmaller" ) << std::endl;
7255+ }
71917256 }
71927257
7193- // bin edges
7194- std::vector<double > binsPt;
7195- if (!ptBins.empty ()) {
7196- std::set<double > binEdges;
7197- for (const auto & b : ptBins)
7198- binEdges.insert (b.first );
7199- binEdges.insert (ptBins.back ().second );
7200- binsPt = std::vector<double >(binEdges.begin (), binEdges.end ());
7201- } else {
7202- LOG (fatal) << " No pT bins found in ML cuts" ;
7258+ if (cutDirs.size () != cutsMl[0 ].size ()) {
7259+ LOG (fatal) << " Mismatch the cut size and direction size: cutsMl[0].size() = " << cutsMl[0 ].size ()
7260+ << " , cutsMl[0].size() = " << cutDirs.size ();
72037261 return {};
72047262 }
72057263
7206- std::vector<std::string> labelsPt, labelsClass;
7207- for (size_t i = 0 ; i < cutsMl.size (); ++i) {
7208- labelsPt.push_back (Form (" pT%.1f" , binsPt[i]));
7209- }
7264+ std::vector<std::string> labelsClass;
72107265 for (size_t j = 0 ; j < cutsMl[0 ].size (); ++j) {
7211- labelsClass.push_back (Form (" cls %d" , static_cast <int >(j)));
7266+ labelsClass.push_back (Form (" score class %d" , static_cast <int >(j)));
72127267 }
72137268
7269+ size_t nFlatBins = cutsMl.size ();
7270+ std::vector<double > binsMl (nFlatBins + 1 );
7271+ std::iota (binsMl.begin (), binsMl.end (), 0 );
7272+
72147273 // Binary
72157274 if (typeStr == " Binary" ) {
72167275 dqmlcuts::BinaryBdtScoreConfig binaryCfg;
72177276 binaryCfg.inputFeatures = namesInputFeatures;
72187277 binaryCfg.onnxFiles = onnxFileNames;
7219- binaryCfg.binsPt = binsPt;
7278+ binaryCfg.binsCent = centBins;
7279+ binaryCfg.binsPt = ptBins;
7280+ binaryCfg.binsMl = binsMl;
72207281 binaryCfg.cutDirs = cutDirs;
7221- binaryCfg.cutsMl = makeLabeledCutsMl (cutsMl, labelsPt, labelsClass);
7282+ binaryCfg.centType = cent;
7283+ binaryCfg.cutsMl = makeLabeledCutsMl (cutsMl, labelsFlatBin, labelsClass);
72227284
72237285 return binaryCfg;
72247286
7225- // MultiClass
7287+ // MultiClass
72267288 } else if (typeStr == " MultiClass" ) {
72277289 dqmlcuts::MultiClassBdtScoreConfig multiCfg;
72287290 multiCfg.inputFeatures = namesInputFeatures;
72297291 multiCfg.onnxFiles = onnxFileNames;
7230- multiCfg.binsPt = binsPt;
7292+ multiCfg.binsCent = centBins;
7293+ multiCfg.binsPt = ptBins;
7294+ multiCfg.binsMl = binsMl;
72317295 multiCfg.cutDirs = cutDirs;
7232- multiCfg.cutsMl = makeLabeledCutsMl (cutsMl, labelsPt, labelsClass);
7296+ multiCfg.centType = cent;
7297+ multiCfg.cutsMl = makeLabeledCutsMl (cutsMl, labelsFlatBin, labelsClass);
72337298
72347299 return multiCfg;
72357300 }
7236-
7237- LOG (fatal) << " Unsupported classification type: " << typeStr;
7238- return {};
72397301 }
72407302
72417303 return {};
72427304}
72437305
72447306o2::framework::LabeledArray<double > o2::aod::dqmlcuts::makeLabeledCutsMl (const std::vector<std::vector<double >>& cuts,
7245- const std::vector<std::string>& labelsPt ,
7307+ const std::vector<std::string>& labelsflatBin ,
72467308 const std::vector<std::string>& labelsClass)
72477309{
72487310 const size_t nRows = cuts.size ();
@@ -7254,5 +7316,19 @@ o2::framework::LabeledArray<double> o2::aod::dqmlcuts::makeLabeledCutsMl(const s
72547316 }
72557317
72567318 o2::framework::Array2D<double > arr (flat.data (), nRows, nCols);
7257- return o2::framework::LabeledArray<double >(arr, labelsPt , labelsClass);
7319+ return o2::framework::LabeledArray<double >(arr, labelsflatBin , labelsClass);
72587320}
7321+
7322+ int o2::aod::dqmlcuts::getMlBinIndex (double cent, double pt,
7323+ const std::vector<std::pair<double , double >>& binsCent,
7324+ const std::vector<std::pair<double , double >>& binsPt)
7325+ {
7326+ LOG (debug) << " Searching for Ml bin index for cent: " << cent << " , pt: " << pt; // here
7327+ for (size_t i = 0 ; i < binsCent.size (); ++i) {
7328+ if (cent >= binsCent[i].first && cent < binsCent[i].second && pt >= binsPt[i].first && pt < binsPt[i].second ) {
7329+ LOG (debug) << " - Found at index: " << i; // here
7330+ return static_cast <int >(i);
7331+ }
7332+ }
7333+ return -1 ; // not found
7334+ }
0 commit comments