2222#define COMMON_TOOLS_PID_PIDTPCMODULE_H_
2323
2424#include " Common/CCDB/ctpRateFetcher.h"
25+ #include " Common/Core/CollisionTypeHelper.h"
2526#include " Common/Core/PID/TPCPIDResponse.h"
2627#include " Common/Core/TableHelper.h"
2728#include " Common/DataModel/PIDResponseTPC.h"
2829#include " Common/TableProducer/PID/pidTPCBase.h"
2930#include " Tools/ML/model.h"
3031
32+ #include < DataFormatsParameters/GRPLHCIFData.h>
3133#include < Framework/AnalysisDataModel.h>
3234#include < Framework/AnalysisHelpers.h>
3335#include < Framework/Configurable.h>
@@ -127,13 +129,13 @@ struct pidTPCConfigurables : o2::framework::ConfigurableGroup {
127129 o2::framework::Configurable<int > useNetworkHe{" useNetworkHe" , 1 , {" Switch for applying neural network on the helium3 mass hypothesis (if network enabled) (set to 0 to disable)" }};
128130 o2::framework::Configurable<int > useNetworkAl{" useNetworkAl" , 1 , {" Switch for applying neural network on the alpha mass hypothesis (if network enabled) (set to 0 to disable)" }};
129131 o2::framework::Configurable<float > networkBetaGammaCutoff{" networkBetaGammaCutoff" , 0.45 , {" Lower value of beta-gamma to override the NN application" }};
130- o2::framework::Configurable<std::string> irSource{ " irSource " , " ZNC hadronic " , " Estimator of the interaction rate (Recommended: pp --> T0VTX, Pb-Pb --> ZNC hadronic) " };
132+ o2::framework::Configurable<std::string> cfgPathGrpLhcIf{ " ccdb-path-grplhcif " , " GLO/Config/GRPLHCIF " , " Path on the CCDB for the GRPLHCIF object " };
131133};
132134
133135// helper getter - FIXME should be separate
134136int getPIDIndex (const int pdgCode) // Get O2 PID index corresponding to MC PDG code
135137{
136- switch (abs (pdgCode)) {
138+ switch (std:: abs (pdgCode)) {
137139 case 11 :
138140 return o2::track::PID::Electron;
139141 case 13 :
@@ -214,6 +216,10 @@ class pidTPCModule
214216 std::vector<int > speciesNetworkFlags = std::vector<int >(9 );
215217 std::string networkVersion;
216218
219+ // To get automatically the proper Hadronic Rate
220+ std::string irSource = " " ;
221+ o2::common::core::CollisionSystemType::collType collsys = o2::common::core::CollisionSystemType::kCollSysUndef ;
222+
217223 // Parametrization configuration
218224 bool useCCDBParam = false ;
219225
@@ -323,6 +329,14 @@ class pidTPCModule
323329 }
324330 LOG (info) << " Successfully retrieved TPC PID object from CCDB for timestamp " << time << " , period " << headers[" LPMProductionTag" ] << " , recoPass " << headers[" RecoPassName" ];
325331 metadata[" RecoPassName" ] = headers[" RecoPassName" ]; // Force pass number for NN request to match retrieved BB
332+ o2::parameters::GRPLHCIFData* grpo = ccdb->template getForTimeStamp <o2::parameters::GRPLHCIFData>(pidTPCopts.cfgPathGrpLhcIf .value , time);
333+ LOG (info) << " collision type::" << CollisionSystemType::getCollisionTypeFromGrp (grpo);
334+ collsys = CollisionSystemType::getCollisionTypeFromGrp (grpo);
335+ if (collsys == CollisionSystemType::kCollSyspp ) {
336+ irSource = std::string (" T0VTX" );
337+ } else {
338+ irSource = std::string (" ZNC hadronic" );
339+ }
326340 response->PrintAll ();
327341 }
328342 }
@@ -368,8 +382,8 @@ class pidTPCModule
368382 } // end init
369383
370384 // __________________________________________________
371- template <typename TCCDB, typename TCCDBApi, typename C, typename M, typename T, typename B>
372- std::vector<float > createNetworkPrediction (TCCDB& ccdb, TCCDBApi& ccdbApi, C const & collisions, M const & mults, T const & tracks, B const & bcs, const size_t size)
385+ template <typename TCCDB, typename TCCDBApi, typename M, typename T, typename B>
386+ std::vector<float > createNetworkPrediction (TCCDB& ccdb, TCCDBApi& ccdbApi, soa::Join<aod::Collisions, aod::EvSels> const & collisions, M const & mults, T const & tracks, B const & bcs, const size_t size)
373387 {
374388
375389 std::vector<float > network_prediction;
@@ -397,6 +411,14 @@ class pidTPCModule
397411 }
398412 LOG (info) << " Successfully retrieved TPC PID object from CCDB for timestamp " << bc.timestamp () << " , period " << headers[" LPMProductionTag" ] << " , recoPass " << headers[" RecoPassName" ];
399413 metadata[" RecoPassName" ] = headers[" RecoPassName" ]; // Force pass number for NN request to match retrieved BB
414+ o2::parameters::GRPLHCIFData* grpo = ccdb->template getForTimeStamp <o2::parameters::GRPLHCIFData>(pidTPCopts.cfgPathGrpLhcIf .value , bc.timestamp ());
415+ LOG (info) << " Collision type::" << CollisionSystemType::getCollisionTypeFromGrp (grpo);
416+ collsys = CollisionSystemType::getCollisionTypeFromGrp (grpo);
417+ if (collsys == CollisionSystemType::kCollSyspp ) {
418+ irSource = std::string (" T0VTX" );
419+ } else {
420+ irSource = std::string (" ZNC hadronic" );
421+ }
400422 response->PrintAll ();
401423 }
402424
@@ -430,11 +452,26 @@ class pidTPCModule
430452 uint64_t counter_track_props = 0 ;
431453 int loop_counter = 0 ;
432454
455+ // To load the Hadronic rate once for each collision
456+ float hadronicRateBegin = 0 .;
457+ std::vector<float > hadronicRateForCollision (collisions.size (), 0 .0f );
458+ size_t i = 0 ;
459+ for (const auto & collision : collisions) {
460+ const auto & bc = collision.template bc_as <B>();
461+ hadronicRateForCollision[i] = mRateFetcher .fetch (ccdb.service , bc.timestamp (), bc.runNumber (), irSource) * 1 .e -3 ;
462+ i++;
463+ }
464+ auto bc = bcs.begin ();
465+ hadronicRateBegin = mRateFetcher .fetch (ccdb.service , bc.timestamp (), bc.runNumber (), irSource) * 1 .e -3 ; // kHz
466+
433467 // Filling a std::vector<float> to be evaluated by the network
434468 // Evaluation on single tracks brings huge overhead: Thus evaluation is done on one large vector
435- for (int i = 0 ; i < 9 ; i++) { // Loop over particle number for which network correction is used
436- float hadronicRate = 0 .;
437- uint64_t timeStamp_bcOld = 0 ;
469+ static constexpr int NParticleTypes = 9 ;
470+ constexpr int kExpectedInputDimensionsNNV2 = 7 ;
471+ constexpr int kExpectedInputDimensionsNNV3 = 8 ;
472+ constexpr auto kNetworkVersionV2 = " 2" ;
473+ constexpr auto kNetworkVersionV3 = " 3" ;
474+ for (int i = 0 ; i < NParticleTypes; i++) { // Loop over particle number for which network correction is used
438475 for (auto const & trk : tracks) {
439476 if (!trk.hasTPC ()) {
440477 continue ;
@@ -450,20 +487,24 @@ class pidTPCModule
450487 track_properties[counter_track_props + 3 ] = o2::track::pid_constants::sMasses [i];
451488 track_properties[counter_track_props + 4 ] = trk.has_collision () ? mults[trk.collisionId ()] / 11000 . : 1 .;
452489 track_properties[counter_track_props + 5 ] = std::sqrt (nNclNormalization / trk.tpcNClsFound ());
453- if (input_dimensions == 7 && networkVersion == " 2 " ) {
490+ if (input_dimensions == kExpectedInputDimensionsNNV2 && networkVersion == kNetworkVersionV2 ) {
454491 track_properties[counter_track_props + 6 ] = trk.has_collision () ? collisions.iteratorAt (trk.collisionId ()).ft0cOccupancyInTimeRange () / 60000 . : 1 .;
455492 }
456- if (input_dimensions == 8 && networkVersion == " 3 " ) {
493+ if (input_dimensions == kExpectedInputDimensionsNNV3 && networkVersion == kNetworkVersionV3 ) {
457494 track_properties[counter_track_props + 6 ] = trk.has_collision () ? collisions.iteratorAt (trk.collisionId ()).ft0cOccupancyInTimeRange () / 60000 . : 1 .;
458495 if (trk.has_collision ()) {
459- auto trk_bc = (collisions.iteratorAt (trk.collisionId ())).template bc_as <B>();
460- if (trk_bc.timestamp () != timeStamp_bcOld) {
461- hadronicRate = mRateFetcher .fetch (ccdb.service , trk_bc.timestamp (), trk_bc.runNumber (), pidTPCopts.irSource .value ) * 1 .e -3 ;
496+ if (collsys == CollisionSystemType::kCollSyspp ) {
497+ track_properties[counter_track_props + 7 ] = hadronicRateForCollision[trk.collisionId ()] / 1500 .;
498+ } else {
499+ track_properties[counter_track_props + 7 ] = hadronicRateForCollision[trk.collisionId ()] / 50 .;
462500 }
463- timeStamp_bcOld = trk_bc.timestamp ();
464- track_properties[counter_track_props + 7 ] = hadronicRate / 50 .;
465501 } else {
466- track_properties[counter_track_props + 7 ] = 1 ;
502+ // asign Hadronic Rate at beginning of run if track does not belong to a collision
503+ if (collsys == CollisionSystemType::kCollSyspp ) {
504+ track_properties[counter_track_props + 7 ] = hadronicRateBegin / 1500 .;
505+ } else {
506+ track_properties[counter_track_props + 7 ] = hadronicRateBegin / 50 .;
507+ }
467508 }
468509 }
469510 counter_track_props += input_dimensions;
@@ -526,22 +567,24 @@ class pidTPCModule
526567
527568 float nSigma = -999 .f ;
528569 float bg = trk.tpcInnerParam () / o2::track::pid_constants::sMasses [pid]; // estimated beta-gamma for network cutoff
570+ constexpr int kNumOutputNodesSymmetricSigma = 2 ;
571+ constexpr int kNumOutputNodesAsymmetricSigma = 3 ;
529572 if (pidTPCopts.useNetworkCorrection && speciesNetworkFlags[pid] && trk.has_collision () && bg > pidTPCopts.networkBetaGammaCutoff ) {
530573
531574 // Here comes the application of the network. The output--dimensions of the network determine the application: 1: mean, 2: sigma, 3: sigma asymmetric
532575 // For now only the option 2: sigma will be used. The other options are kept if there would be demand later on
533576 if (network.getNumOutputNodes () == 1 ) { // Expected mean correction; no sigma correction
534577 nSigma = (tpcSignal - network_prediction[count_tracks + tracksForNet_size * pid] * expSignal) / expSigma;
535- } else if (network.getNumOutputNodes () == 2 ) { // Symmetric sigma correction
536- expSigma = (network_prediction[2 * (count_tracks + tracksForNet_size * pid) + 1 ] - network_prediction[2 * (count_tracks + tracksForNet_size * pid)]) * expSignal;
537- nSigma = (tpcSignal / expSignal - network_prediction[2 * (count_tracks + tracksForNet_size * pid)]) / (network_prediction[2 * (count_tracks + tracksForNet_size * pid) + 1 ] - network_prediction[2 * (count_tracks + tracksForNet_size * pid)]);
538- } else if (network.getNumOutputNodes () == 3 ) { // Asymmetric sigma corection
539- if (tpcSignal / expSignal >= network_prediction[3 * (count_tracks + tracksForNet_size * pid)]) {
540- expSigma = (network_prediction[3 * (count_tracks + tracksForNet_size * pid) + 1 ] - network_prediction[3 * (count_tracks + tracksForNet_size * pid)]) * expSignal;
541- nSigma = (tpcSignal / expSignal - network_prediction[3 * (count_tracks + tracksForNet_size * pid)]) / (network_prediction[3 * (count_tracks + tracksForNet_size * pid) + 1 ] - network_prediction[3 * (count_tracks + tracksForNet_size * pid)]);
578+ } else if (network.getNumOutputNodes () == kNumOutputNodesSymmetricSigma ) { // Symmetric sigma correction
579+ expSigma = (network_prediction[kNumOutputNodesSymmetricSigma * (count_tracks + tracksForNet_size * pid) + 1 ] - network_prediction[kNumOutputNodesSymmetricSigma * (count_tracks + tracksForNet_size * pid)]) * expSignal;
580+ nSigma = (tpcSignal / expSignal - network_prediction[kNumOutputNodesSymmetricSigma * (count_tracks + tracksForNet_size * pid)]) / (network_prediction[kNumOutputNodesSymmetricSigma * (count_tracks + tracksForNet_size * pid) + 1 ] - network_prediction[kNumOutputNodesSymmetricSigma * (count_tracks + tracksForNet_size * pid)]);
581+ } else if (network.getNumOutputNodes () == kNumOutputNodesAsymmetricSigma ) { // Asymmetric sigma corection
582+ if (tpcSignal / expSignal >= network_prediction[kNumOutputNodesAsymmetricSigma * (count_tracks + tracksForNet_size * pid)]) {
583+ expSigma = (network_prediction[kNumOutputNodesAsymmetricSigma * (count_tracks + tracksForNet_size * pid) + 1 ] - network_prediction[kNumOutputNodesAsymmetricSigma * (count_tracks + tracksForNet_size * pid)]) * expSignal;
584+ nSigma = (tpcSignal / expSignal - network_prediction[kNumOutputNodesAsymmetricSigma * (count_tracks + tracksForNet_size * pid)]) / (network_prediction[kNumOutputNodesAsymmetricSigma * (count_tracks + tracksForNet_size * pid) + 1 ] - network_prediction[kNumOutputNodesAsymmetricSigma * (count_tracks + tracksForNet_size * pid)]);
542585 } else {
543- expSigma = (network_prediction[3 * (count_tracks + tracksForNet_size * pid)] - network_prediction[3 * (count_tracks + tracksForNet_size * pid) + 2 ]) * expSignal;
544- nSigma = (tpcSignal / expSignal - network_prediction[3 * (count_tracks + tracksForNet_size * pid)]) / (network_prediction[3 * (count_tracks + tracksForNet_size * pid)] - network_prediction[3 * (count_tracks + tracksForNet_size * pid) + 2 ]);
586+ expSigma = (network_prediction[kNumOutputNodesAsymmetricSigma * (count_tracks + tracksForNet_size * pid)] - network_prediction[kNumOutputNodesAsymmetricSigma * (count_tracks + tracksForNet_size * pid) + 2 ]) * expSignal;
587+ nSigma = (tpcSignal / expSignal - network_prediction[kNumOutputNodesAsymmetricSigma * (count_tracks + tracksForNet_size * pid)]) / (network_prediction[kNumOutputNodesAsymmetricSigma * (count_tracks + tracksForNet_size * pid)] - network_prediction[kNumOutputNodesAsymmetricSigma * (count_tracks + tracksForNet_size * pid) + 2 ]);
545588 }
546589 } else {
547590 LOGF (fatal, " Network output-dimensions incompatible!" );
@@ -556,8 +599,8 @@ class pidTPCModule
556599 };
557600
558601 // __________________________________________________
559- template <typename TCCDB, typename TCCDBApi, typename TBCs, typename TCollisions, typename TTracks, typename TTracksQA, typename TProducts>
560- void process (TCCDB& ccdb, TCCDBApi& ccdbApi, TBCs const & bcs, TCollisions const & cols, TTracks const & tracks, TTracksQA const & tracksQA, TProducts& products)
602+ template <typename TCCDB, typename TCCDBApi, typename TBCs, typename TTracks, typename TTracksQA, typename TProducts>
603+ void process (TCCDB& ccdb, TCCDBApi& ccdbApi, TBCs const & bcs, soa::Join<aod::Collisions, aod::EvSels> const & cols, TTracks const & tracks, TTracksQA const & tracksQA, TProducts& products)
561604 {
562605 if (tracks.size () == 0 ) {
563606 return ; // empty protection
@@ -628,13 +671,29 @@ class pidTPCModule
628671 // _______________________________________
629672 // process tracksQA in case present
630673 std::vector<int64_t > indexTrack2TrackQA;
674+ indexTrack2TrackQA.clear ();
675+ indexTrack2TrackQA.resize (outTable_size,-1 );
631676 if constexpr (soa::is_table<TTracksQA>) {
632677 for (const auto & trackQA : tracksQA) {
633678 indexTrack2TrackQA[trackQA.trackId ()] = trackQA.globalIndex ();
634679 }
635680 }
636681 // _______________________________________
637682
683+ // Fill Hadronic rate per collision in case CorrectedDEdx is requested
684+ std::vector<float > hadronicRateForCollision (cols.size (), 0 .0f );
685+ float hadronicRateBegin = 0 .0f ;
686+ if (pidTPCopts.useCorrecteddEdx ) {
687+ size_t i = 0 ;
688+ for (const auto & collision : cols) {
689+ const auto & bc = collision.template bc_as <aod::BCsWithTimestamps>();
690+ hadronicRateForCollision[i] = mRateFetcher .fetch (ccdb.service , bc.timestamp (), bc.runNumber (), irSource) * 1 .e -3 ;
691+ i++;
692+ }
693+ auto bc = bcs.begin ();
694+ hadronicRateBegin = mRateFetcher .fetch (ccdb.service , bc.timestamp (), bc.runNumber (), irSource) * 1 .e -3 ; // kHz
695+ }
696+
638697 for (auto const & trk : tracks) {
639698 // get the TPC signal to be used in the PID
640699 float tpcSignalToEvaluatePID = trk.tpcSignal ();
@@ -662,24 +721,25 @@ class pidTPCModule
662721 int occupancy;
663722 if (trk.has_collision ()) {
664723 auto collision = cols.iteratorAt (trk.collisionId ());
665- auto bc = collision.template bc_as <aod::BCsWithTimestamps>();
666- const int runnumber = bc.runNumber ();
667- hadronicRate = mRateFetcher .fetch (ccdb.service , bc.timestamp (), runnumber, " ZNC hadronic" ) * 1 .e -3 ; // kHz
724+ hadronicRate = hadronicRateForCollision[trk.collisionId ()];
668725 occupancy = collision.trackOccupancyInTimeRange ();
669726 } else {
670- auto bc = bcs.begin ();
671- const int runnumber = bc.runNumber ();
672- hadronicRate = mRateFetcher .fetch (ccdb.service , bc.timestamp (), runnumber, " ZNC hadronic" ) * 1 .e -3 ; // kHz
727+ hadronicRate = hadronicRateBegin;
673728 occupancy = 0 ;
674729 }
675730
731+ constexpr float kExpectedTPCSignalMIP = 50 .0f ;
732+ constexpr float kMaxAllowedRatio = 1 .05f ;
733+ constexpr float kMinAllowedRatio = 0 .05f ;
734+ constexpr float kMaxAllowedOcc = 12 .0f ;
735+
676736 float fTPCSignal = tpcSignalToEvaluatePID;
677737 float fNormMultTPC = multTPC / 11000 .;
678738
679739 float fTrackOccN = occupancy / 1000 .;
680740 float fOccTPCN = fNormMultTPC * 10 ; // (fNormMultTPC*10).clip(0,12)
681- if (fOccTPCN > 12 )
682- fOccTPCN = 12 ;
741+ if (fOccTPCN > kMaxAllowedOcc )
742+ fOccTPCN = kMaxAllowedOcc ;
683743 else if (fOccTPCN < 0 )
684744 fOccTPCN = 0 ;
685745
@@ -688,11 +748,11 @@ class pidTPCModule
688748 float a1pt = std::abs (trk.signed1Pt ());
689749 float a1pt2 = a1pt * a1pt;
690750 float atgl = std::abs (trk.tgl ());
691- float mbb0R = 50 / fTPCSignal ;
692- if (mbb0R > 1.05 )
693- mbb0R = 1.05 ;
694- else if (mbb0R < 0.05 )
695- mbb0R = 0.05 ;
751+ float mbb0R = kExpectedTPCSignalMIP / fTPCSignal ;
752+ if (mbb0R > kMaxAllowedRatio )
753+ mbb0R = kMaxAllowedRatio ;
754+ else if (mbb0R < kMinAllowedRatio )
755+ mbb0R = kMinAllowedRatio ;
696756 // float mbb0R = max(0.05, min(50 / fTPCSignal, 1.05));
697757 float a1ptmbb0R = a1pt * mbb0R;
698758 float atglmbb0R = atgl * mbb0R;
@@ -702,11 +762,11 @@ class pidTPCModule
702762
703763 float fTPCSignalN_CR0 = str_dedx_correction.fReal_fTPCSignalN (vec_occu, vec_track);
704764
705- float mbb0R1 = 50 / (fTPCSignal / fTPCSignalN_CR0 );
706- if (mbb0R1 > 1.05 )
707- mbb0R1 = 1.05 ;
708- else if (mbb0R1 < 0.05 )
709- mbb0R1 = 0.05 ;
765+ float mbb0R1 = kExpectedTPCSignalMIP / (fTPCSignal / fTPCSignalN_CR0 );
766+ if (mbb0R1 > kMaxAllowedRatio )
767+ mbb0R1 = kMaxAllowedRatio ;
768+ else if (mbb0R1 < kMinAllowedRatio )
769+ mbb0R1 = kMinAllowedRatio ;
710770
711771 std::vector<float > vec_track1 = {mbb0R1, a1pt, atgl, atgl * mbb0R1, a1pt * mbb0R1, side, a1pt2};
712772 float fTPCSignalN_CR1 = str_dedx_correction.fReal_fTPCSignalN (vec_occu, vec_track1);
@@ -738,6 +798,14 @@ class pidTPCModule
738798 }
739799 }
740800 LOG (info) << " Successfully retrieved TPC PID object from CCDB for timestamp " << bc.timestamp () << " , period " << headers[" LPMProductionTag" ] << " , recoPass " << headers[" RecoPassName" ];
801+ o2::parameters::GRPLHCIFData* grpo = ccdb->template getForTimeStamp <o2::parameters::GRPLHCIFData>(pidTPCopts.cfgPathGrpLhcIf .value , bc.timestamp ());
802+ LOG (info) << " Collisions type::" << CollisionSystemType::getCollisionTypeFromGrp (grpo);
803+ collsys = CollisionSystemType::getCollisionTypeFromGrp (grpo);
804+ if (collsys == CollisionSystemType::kCollSyspp ) {
805+ irSource = std::string (" T0VTX" );
806+ } else {
807+ irSource = std::string (" ZNC hadronic" );
808+ }
741809 response->PrintAll ();
742810 }
743811
0 commit comments