1212// / \brief write relevant information about primary electrons.
1313// / \author daiki.sekihata@cern.ch
1414
15- #include < unordered_map>
16- #include < string>
17- #include < vector>
18- #include < utility>
15+ #include " PWGEM/Dilepton/DataModel/dileptonTables.h"
16+ #include " PWGEM/Dilepton/Utils/MlResponseO2Track.h"
17+ #include " PWGEM/Dilepton/Utils/PairUtilities.h"
1918
20- #include " Math/Vector4D.h"
21- #include " Framework/runDataProcessing.h"
22- #include " Framework/AnalysisTask.h"
23- #include " Framework/AnalysisDataModel.h"
24- #include " DetectorsBase/Propagator.h"
25- #include " DetectorsBase/GeometryManager.h"
26- #include " DataFormatsParameters/GRPObject.h"
27- #include " DataFormatsParameters/GRPMagField.h"
28- #include " DataFormatsCalibration/MeanVertexObject.h"
29- #include " CCDB/BasicCCDBManager.h"
19+ #include " Common/Core/TableHelper.h"
3020#include " Common/Core/trackUtilities.h"
31- #include " CommonConstants/PhysicsConstants.h"
3221#include " Common/DataModel/CollisionAssociationTables.h"
33- #include " Common/Core/TableHelper.h"
3422#include " Common/DataModel/PIDResponse.h"
3523#include " Common/DataModel/PIDResponseITS.h"
24+ #include " Tools/ML/MlResponse.h"
3625
37- #include " PWGEM/Dilepton/DataModel/dileptonTables.h"
38- #include " PWGEM/Dilepton/Utils/PairUtilities.h"
26+ #include " CCDB/BasicCCDBManager.h"
27+ #include " CommonConstants/PhysicsConstants.h"
28+ #include " DataFormatsCalibration/MeanVertexObject.h"
29+ #include " DataFormatsParameters/GRPMagField.h"
30+ #include " DataFormatsParameters/GRPObject.h"
31+ #include " DetectorsBase/GeometryManager.h"
32+ #include " DetectorsBase/Propagator.h"
33+ #include " Framework/AnalysisDataModel.h"
34+ #include " Framework/AnalysisTask.h"
35+ #include " Framework/runDataProcessing.h"
36+
37+ #include " Math/Vector4D.h"
38+
39+ #include < string>
40+ #include < unordered_map>
41+ #include < utility>
42+ #include < vector>
3943
4044using namespace o2 ;
4145using namespace o2 ::soa;
@@ -70,7 +74,7 @@ struct skimmerPrimaryElectron {
7074 // Operation and minimisation criteria
7175 Configurable<bool > fillQAHistogram{" fillQAHistogram" , false , " flag to fill QA histograms" };
7276 Configurable<float > d_bz_input{" d_bz_input" , -999 , " bz field in kG, -999 is automatic" };
73- Configurable<int > min_ncluster_tpc{" min_ncluster_tpc" , 10 , " min ncluster tpc" };
77+ Configurable<int > min_ncluster_tpc{" min_ncluster_tpc" , 0 , " min ncluster tpc" };
7478 Configurable<int > mincrossedrows{" mincrossedrows" , 70 , " min. crossed rows" };
7579 Configurable<float > min_tpc_cr_findable_ratio{" min_tpc_cr_findable_ratio" , 0.8 , " min. TPC Ncr/Nf ratio" };
7680 Configurable<int > min_ncluster_its{" min_ncluster_its" , 4 , " min ncluster its" };
@@ -79,8 +83,8 @@ struct skimmerPrimaryElectron {
7983 Configurable<float > maxchi2its{" maxchi2its" , 6.0 , " max. chi2/NclsITS" };
8084 Configurable<float > minpt{" minpt" , 0.15 , " min pt for track" };
8185 Configurable<float > maxeta{" maxeta" , 0.9 , " eta acceptance" };
82- Configurable<float > dca_xy_max{" dca_xy_max" , 0 . 3f , " max DCAxy in cm" };
83- Configurable<float > dca_z_max{" dca_z_max" , 0 . 3f , " max DCAz in cm" };
86+ Configurable<float > dca_xy_max{" dca_xy_max" , 1.0 , " max DCAxy in cm" };
87+ Configurable<float > dca_z_max{" dca_z_max" , 1.0 , " max DCAz in cm" };
8488 Configurable<float > dca_3d_sigma_max{" dca_3d_sigma_max" , 1e+10 , " max DCA 3D in sigma" };
8589 Configurable<float > minTPCNsigmaEl{" minTPCNsigmaEl" , -2.5 , " min. TPC n sigma for electron inclusion" };
8690 Configurable<float > maxTPCNsigmaEl{" maxTPCNsigmaEl" , 3.5 , " max. TPC n sigma for electron inclusion" };
@@ -96,7 +100,20 @@ struct skimmerPrimaryElectron {
96100 Configurable<float > max_pin_for_pion_rejection{" max_pin_for_pion_rejection" , 0.5 , " pion rejection is applied below this pin" };
97101 Configurable<float > max_frac_shared_clusters_tpc{" max_frac_shared_clusters_tpc" , 999 .f , " max fraction of shared clusters in TPC" };
98102
103+ // configuration for PID ML
104+ Configurable<bool > usePIDML{" usePIDML" , false , " Flag to use PID ML" };
105+ Configurable<std::vector<std::string>> onnxFileNames{" onnxFileNames" , std::vector<std::string>{" filename" }, " ONNX file names for each bin (if not from CCDB full path)" };
106+ Configurable<std::vector<std::string>> onnxPathsCCDB{" onnxPathsCCDB" , std::vector<std::string>{" path" }, " Paths of models on CCDB" };
107+ Configurable<std::vector<double >> binsMl{" binsMl" , std::vector<double >{-999999 ., 999999 .}, " Bin limits for ML application" };
108+ Configurable<std::vector<double >> cutsMl{" cutsMl" , std::vector<double >{0.95 }, " ML cuts per bin" };
109+ Configurable<std::vector<std::string>> namesInputFeatures{" namesInputFeatures" , std::vector<std::string>{" feature" }, " Names of ML model input features" };
110+ Configurable<std::string> nameBinningFeature{" nameBinningFeature" , " pt" , " Names of ML model binning feature" };
111+ Configurable<int64_t > timestampCCDB{" timestampCCDB" , -1 , " timestamp of the ONNX file for ML model used to query in CCDB. Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp" };
112+ Configurable<bool > loadModelsFromCCDB{" loadModelsFromCCDB" , false , " Flag to enable or disable the loading of models from CCDB" };
113+ Configurable<bool > enableOptimizations{" enableOptimizations" , false , " Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)" };
114+
99115 HistogramRegistry fRegistry {" output" , {}, OutputObjHandlingPolicy::AnalysisObject, false , false };
116+ o2::analysis::MlResponseO2Track<float > mlResponseSingleTrack;
100117
101118 int mRunNumber ;
102119 float d_bz;
@@ -106,6 +123,7 @@ struct skimmerPrimaryElectron {
106123 o2::dataformats::VertexBase mVtx ;
107124 const o2::dataformats::MeanVertexObject* mMeanVtx = nullptr ;
108125 o2::base::MatLayerCylSet* lut = nullptr ;
126+ o2::ccdb::CcdbApi ccdbApi;
109127
110128 void init (InitContext&)
111129 {
@@ -116,6 +134,7 @@ struct skimmerPrimaryElectron {
116134 ccdb->setCaching (true );
117135 ccdb->setLocalObjectValidityChecking ();
118136 ccdb->setFatalWhenNull (false );
137+ ccdbApi.init (ccdburl);
119138
120139 if (fillQAHistogram) {
121140 fRegistry .add (" Track/hPt" , " pT;p_{T} (GeV/c)" , kTH1F , {{1000 , 0 .0f , 10 }}, false );
@@ -154,6 +173,31 @@ struct skimmerPrimaryElectron {
154173 fRegistry .add (" Track/hITSNsigmaKa" , " ITS n sigma ka;p_{pv} (GeV/c);n #sigma_{K}^{ITS}" , kTH2F , {{1000 , 0 , 10 }, {100 , -5 , +5 }}, false );
155174 fRegistry .add (" Track/hITSNsigmaPr" , " ITS n sigma pr;p_{pv} (GeV/c);n #sigma_{p}^{ITS}" , kTH2F , {{1000 , 0 , 10 }, {100 , -5 , +5 }}, false );
156175 }
176+
177+ if (usePIDML) {
178+ static constexpr int nClassesMl = 2 ;
179+ const std::vector<int > cutDirMl = {o2::cuts_ml::CutGreater, o2::cuts_ml::CutNot};
180+ const std::vector<std::string> labelsClasses = {" Signal" , " Background" };
181+ const uint32_t nBinsMl = binsMl.value .size () - 1 ;
182+ const std::vector<std::string> labelsBins (nBinsMl, " bin" );
183+ double cutsMlArr[nBinsMl][nClassesMl];
184+ for (uint32_t i = 0 ; i < nBinsMl; i++) {
185+ cutsMlArr[i][0 ] = cutsMl.value [i];
186+ cutsMlArr[i][1 ] = 0 .;
187+ }
188+ o2::framework::LabeledArray<double > cutsMl = {cutsMlArr[0 ], nBinsMl, nClassesMl, labelsBins, labelsClasses};
189+
190+ mlResponseSingleTrack.configure (binsMl.value , cutsMl, cutDirMl, nClassesMl);
191+ if (loadModelsFromCCDB) {
192+ ccdbApi.init (ccdburl);
193+ mlResponseSingleTrack.setModelPathsCCDB (onnxFileNames.value , ccdbApi, onnxPathsCCDB.value , timestampCCDB.value );
194+ } else {
195+ mlResponseSingleTrack.setModelPathsLocal (onnxFileNames.value );
196+ }
197+ mlResponseSingleTrack.cacheInputFeaturesIndices (namesInputFeatures);
198+ mlResponseSingleTrack.cacheBinningIndex (nameBinningFeature);
199+ mlResponseSingleTrack.init (enableOptimizations.value );
200+ } // end of PID ML
157201 }
158202
159203 void initCCDB (aod::BCsWithTimestamps::iterator const & bc)
@@ -299,10 +343,32 @@ struct skimmerPrimaryElectron {
299343 return true ;
300344 }
301345
302- template <typename TTrack>
303- bool isElectron (TTrack const & track)
346+ template <typename TCollision, typename TTrack>
347+ bool isElectron (TCollision const & collision, TTrack const & track)
304348 {
305- return isElectron_TPChadrej (track) || isElectron_TOFreq (track);
349+ if (usePIDML) {
350+ if (track.tpcNSigmaEl () < minTPCNsigmaEl || maxTPCNsigmaEl < track.tpcNSigmaEl ()) {
351+ return false ;
352+ }
353+ if (track.hasTOF () && (maxTOFNsigmaEl < std::fabs (track.tofNSigmaEl ()))) {
354+ return false ;
355+ }
356+
357+ // return false;
358+ o2::dataformats::DCA mDcaInfoCov ;
359+ mDcaInfoCov .set (999 , 999 , 999 , 999 , 999 );
360+ auto trackParCov = getTrackParCov (track);
361+ trackParCov.setPID (o2::track::PID::Electron);
362+ mVtx .setPos ({collision.posX (), collision.posY (), collision.posZ ()});
363+ mVtx .setCov (collision.covXX (), collision.covXY (), collision.covYY (), collision.covXZ (), collision.covYZ (), collision.covZZ ());
364+ o2::base::Propagator::Instance ()->propagateToDCABxByBz (mVtx , trackParCov, 2 .f , matCorr, &mDcaInfoCov );
365+
366+ std::vector<float > inputFeatures = mlResponseSingleTrack.getInputFeatures (track, trackParCov, collision);
367+ float binningFeature = mlResponseSingleTrack.getBinningFeature (track, trackParCov, collision);
368+ return mlResponseSingleTrack.isSelectedMl (inputFeatures, binningFeature);
369+ } else {
370+ return isElectron_TPChadrej (track) || isElectron_TOFreq (track);
371+ }
306372 }
307373
308374 template <typename TTrack>
@@ -460,7 +526,7 @@ struct skimmerPrimaryElectron {
460526
461527 auto tracks_per_coll = tracksWithITSPid.sliceBy (perCol, collision.globalIndex ());
462528 for (const auto & track : tracks_per_coll) {
463- if (!checkTrack<false >(collision, track) || !isElectron (track)) {
529+ if (!checkTrack<false >(collision, track) || !isElectron (collision, track)) {
464530 continue ;
465531 }
466532 fillTrackTable (collision, track);
@@ -491,7 +557,7 @@ struct skimmerPrimaryElectron {
491557 for (const auto & trackId : trackIdsThisCollision) {
492558 // auto track = trackId.template track_as<MyTracks>();
493559 auto track = tracksWithITSPid.rawIteratorAt (trackId.trackId ());
494- if (!checkTrack<false >(collision, track) || !isElectron (track)) {
560+ if (!checkTrack<false >(collision, track) || !isElectron (collision, track)) {
495561 continue ;
496562 }
497563 fillTrackTable (collision, track);
@@ -522,7 +588,7 @@ struct skimmerPrimaryElectron {
522588
523589 auto tracks_per_coll = tracksWithITSPid.sliceBy (perCol, collision.globalIndex ());
524590 for (const auto & track : tracks_per_coll) {
525- if (!checkTrack<false >(collision, track) || !isElectron (track)) {
591+ if (!checkTrack<false >(collision, track) || !isElectron (collision, track)) {
526592 continue ;
527593 }
528594 fillTrackTable (collision, track);
@@ -556,7 +622,7 @@ struct skimmerPrimaryElectron {
556622 for (const auto & trackId : trackIdsThisCollision) {
557623 // auto track = trackId.template track_as<MyTracks>();
558624 auto track = tracksWithITSPid.rawIteratorAt (trackId.trackId ());
559- if (!checkTrack<false >(collision, track) || !isElectron (track)) {
625+ if (!checkTrack<false >(collision, track) || !isElectron (collision, track)) {
560626 continue ;
561627 }
562628 fillTrackTable (collision, track);
@@ -591,7 +657,7 @@ struct skimmerPrimaryElectron {
591657
592658 auto tracks_per_coll = tracksWithITSPid.sliceBy (perCol, collision.globalIndex ());
593659 for (const auto & track : tracks_per_coll) {
594- if (!checkTrack<true >(collision, track) || !isElectron (track)) {
660+ if (!checkTrack<true >(collision, track) || !isElectron (collision, track)) {
595661 continue ;
596662 }
597663 fillTrackTable (collision, track);
@@ -624,7 +690,7 @@ struct skimmerPrimaryElectron {
624690 for (const auto & trackId : trackIdsThisCollision) {
625691 // auto track = trackId.template track_as<MyTracksMC>();
626692 auto track = tracksWithITSPid.rawIteratorAt (trackId.trackId ());
627- if (!checkTrack<true >(collision, track) || !isElectron (track)) {
693+ if (!checkTrack<true >(collision, track) || !isElectron (collision, track)) {
628694 continue ;
629695 }
630696 fillTrackTable (collision, track);
0 commit comments