@@ -713,30 +713,42 @@ struct Dilepton {
713713 // fDielectronCut.SetPRangeForITSNsigmaPr(dielectroncuts.cfg_min_p_ITSNsigmaPr, dielectroncuts.cfg_max_p_ITSNsigmaPr);
714714
715715 if (dielectroncuts.cfg_pid_scheme == static_cast <int >(DielectronCut::PIDSchemes::kPIDML )) { // please call this at the end of DefineDileptonCut
716- static constexpr int nClassesMl = 2 ;
717- const std::vector<int > cutDirMl = {o2::cuts_ml::CutNot, o2::cuts_ml::CutSmaller};
718- const std::vector<std::string> labelsClasses = {" Background" , " Signal" };
719- const uint32_t nBinsMl = dielectroncuts.binsMl .value .size () - 1 ;
720- const std::vector<std::string> labelsBins (nBinsMl, " bin" );
721- double cutsMlArr[nBinsMl][nClassesMl];
722- for (uint32_t i = 0 ; i < nBinsMl; i++) {
723- cutsMlArr[i][0 ] = 0 .;
724- cutsMlArr[i][1 ] = dielectroncuts.cutsMl .value [i];
725- }
726- o2::framework::LabeledArray<double > cutsMl = {cutsMlArr[0 ], nBinsMl, nClassesMl, labelsBins, labelsClasses};
727-
728- mlResponseSingleTrack.configure (dielectroncuts.binsMl .value , cutsMl, cutDirMl, nClassesMl);
729- if (dielectroncuts.loadModelsFromCCDB ) {
730- ccdbApi.init (ccdburl);
731- mlResponseSingleTrack.setModelPathsCCDB (dielectroncuts.onnxFileNames .value , ccdbApi, dielectroncuts.onnxPathsCCDB .value , dielectroncuts.timestampCCDB .value );
732- } else {
733- mlResponseSingleTrack.setModelPathsLocal (dielectroncuts.onnxFileNames .value );
734- }
735- mlResponseSingleTrack.cacheInputFeaturesIndices (dielectroncuts.namesInputFeatures );
736- mlResponseSingleTrack.cacheBinningIndex (dielectroncuts.nameBinningFeature );
737- mlResponseSingleTrack.init (dielectroncuts.enableOptimizations .value );
738-
739- fDielectronCut .SetPIDMlResponse (&mlResponseSingleTrack);
716+ std::vector<float > binsML{};
717+ binsML.reserve (dielectroncuts.binsMl .value .size ());
718+ for (size_t i = 0 ; i < dielectroncuts.binsMl .value .size (); i++) {
719+ binsML.emplace_back (dielectroncuts.binsMl .value [i]);
720+ }
721+ std::vector<float > thresholdsML{};
722+ thresholdsML.reserve (dielectroncuts.cutsMl .value .size ());
723+ for (size_t i = 0 ; i < dielectroncuts.cutsMl .value .size (); i++) {
724+ thresholdsML.emplace_back (dielectroncuts.cutsMl .value [i]);
725+ }
726+ fDielectronCut .SetMLThresholds (binsML, thresholdsML);
727+
728+ // static constexpr int nClassesMl = 2;
729+ // const std::vector<int> cutDirMl = {o2::cuts_ml::CutNot, o2::cuts_ml::CutSmaller};
730+ // const std::vector<std::string> labelsClasses = {"Background", "Signal"};
731+ // const uint32_t nBinsMl = dielectroncuts.binsMl.value.size() - 1;
732+ // const std::vector<std::string> labelsBins(nBinsMl, "bin");
733+ // double cutsMlArr[nBinsMl][nClassesMl];
734+ // for (uint32_t i = 0; i < nBinsMl; i++) {
735+ // cutsMlArr[i][0] = 0.;
736+ // cutsMlArr[i][1] = dielectroncuts.cutsMl.value[i];
737+ // }
738+ // o2::framework::LabeledArray<double> cutsMl = {cutsMlArr[0], nBinsMl, nClassesMl, labelsBins, labelsClasses};
739+
740+ // mlResponseSingleTrack.configure(dielectroncuts.binsMl.value, cutsMl, cutDirMl, nClassesMl);
741+ // if (dielectroncuts.loadModelsFromCCDB) {
742+ // ccdbApi.init(ccdburl);
743+ // mlResponseSingleTrack.setModelPathsCCDB(dielectroncuts.onnxFileNames.value, ccdbApi, dielectroncuts.onnxPathsCCDB.value, dielectroncuts.timestampCCDB.value);
744+ // } else {
745+ // mlResponseSingleTrack.setModelPathsLocal(dielectroncuts.onnxFileNames.value);
746+ // }
747+ // mlResponseSingleTrack.cacheInputFeaturesIndices(dielectroncuts.namesInputFeatures);
748+ // mlResponseSingleTrack.cacheBinningIndex(dielectroncuts.nameBinningFeature);
749+ // mlResponseSingleTrack.init(dielectroncuts.enableOptimizations.value);
750+
751+ // fDielectronCut.SetPIDMlResponse(&mlResponseSingleTrack);
740752 } // end of PID ML
741753 }
742754
@@ -836,11 +848,11 @@ struct Dilepton {
836848 if constexpr (ev_id == 0 ) {
837849 if constexpr (pairtype == o2::aod::pwgem::dilepton::utils::pairutil::DileptonPairType::kDielectron ) {
838850 if (dielectroncuts.cfg_pid_scheme == static_cast <int >(DielectronCut::PIDSchemes::kPIDML )) {
839- if (!cut.template IsSelectedTrack <false , true >(t1, collision ) || !cut.template IsSelectedTrack <false , true >(t2, collision )) {
851+ if (!cut.template IsSelectedTrack <false >(t1) || !cut.template IsSelectedTrack <false >(t2)) {
840852 return false ;
841853 }
842854 } else { // cut-based
843- if (!cut.template IsSelectedTrack <false , false >(t1) || !cut.template IsSelectedTrack <false , false >(t2)) {
855+ if (!cut.template IsSelectedTrack <false >(t1) || !cut.template IsSelectedTrack <false >(t2)) {
844856 return false ;
845857 }
846858 }
@@ -1377,16 +1389,16 @@ struct Dilepton {
13771389
13781390 } // end of DF
13791391
1380- template <typename TCollision, typename TTrack1, typename TTrack2, typename TCut, typename TAllTracks>
1381- bool isPairOK (TCollision const & collision, TTrack1 const & t1, TTrack2 const & t2, TCut const & cut, TAllTracks const & tracks)
1392+ template <typename TTrack1, typename TTrack2, typename TCut, typename TAllTracks>
1393+ bool isPairOK (TTrack1 const & t1, TTrack2 const & t2, TCut const & cut, TAllTracks const & tracks)
13821394 {
13831395 if constexpr (pairtype == o2::aod::pwgem::dilepton::utils::pairutil::DileptonPairType::kDielectron ) {
13841396 if (dielectroncuts.cfg_pid_scheme == static_cast <int >(DielectronCut::PIDSchemes::kPIDML )) {
1385- if (!cut.template IsSelectedTrack <false , true >(t1, collision ) || !cut.template IsSelectedTrack <false , true >(t2, collision )) {
1397+ if (!cut.template IsSelectedTrack <false >(t1) || !cut.template IsSelectedTrack <false >(t2)) {
13861398 return false ;
13871399 }
13881400 } else { // cut-based
1389- if (!cut.template IsSelectedTrack <false , false >(t1) || !cut.template IsSelectedTrack <false , false >(t2)) {
1401+ if (!cut.template IsSelectedTrack <false >(t1) || !cut.template IsSelectedTrack <false >(t2)) {
13901402 return false ;
13911403 }
13921404 }
@@ -1470,17 +1482,17 @@ struct Dilepton {
14701482 auto negTracks_per_coll = negTracks.sliceByCached (perCollision, collision.globalIndex (), cache);
14711483
14721484 for (const auto & [pos, neg] : combinations (CombinationsFullIndexPolicy (posTracks_per_coll, negTracks_per_coll))) { // ULS
1473- if (isPairOK (collision, pos, neg, cut, tracks)) {
1485+ if (isPairOK (pos, neg, cut, tracks)) {
14741486 passed_pairIds.emplace_back (std::make_pair (pos.globalIndex (), neg.globalIndex ()));
14751487 }
14761488 }
14771489 for (const auto & [pos1, pos2] : combinations (CombinationsStrictlyUpperIndexPolicy (posTracks_per_coll, posTracks_per_coll))) { // LS++
1478- if (isPairOK (collision, pos1, pos2, cut, tracks)) {
1490+ if (isPairOK (pos1, pos2, cut, tracks)) {
14791491 passed_pairIds.emplace_back (std::make_pair (pos1.globalIndex (), pos2.globalIndex ()));
14801492 }
14811493 }
14821494 for (const auto & [neg1, neg2] : combinations (CombinationsStrictlyUpperIndexPolicy (negTracks_per_coll, negTracks_per_coll))) { // LS--
1483- if (isPairOK (collision, neg1, neg2, cut, tracks)) {
1495+ if (isPairOK (neg1, neg2, cut, tracks)) {
14841496 passed_pairIds.emplace_back (std::make_pair (neg1.globalIndex (), neg2.globalIndex ()));
14851497 }
14861498 }
0 commit comments