3838#include " Common/Core/trackUtilities.h"
3939#include " PWGJE/Core/JetUtilities.h"
4040
41- #if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
42- #include < onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>
43- #else
44- #include < onnxruntime_cxx_api.h>
45- #endif
46-
47- using namespace o2 ::constants::physics;
48-
4941enum JetTaggingSpecies {
5042 none = 0 ,
5143 charm = 1 ,
@@ -71,203 +63,60 @@ namespace jettaggingutilities
7163const int cmTomum = 10000 ; // using cm -> #mum for impact parameter (dca)
7264
7365struct BJetParams {
74- float mJetpT = 0.0 ;
75- float mJetEta = 0.0 ;
76- float mJetPhi = 0.0 ;
77- int mNTracks = -1 ;
78- int mNSV = -1 ;
79- float mJetMass = 0.0 ;
66+ float JetpT = 0.0 ;
67+ float JetEta = 0.0 ;
68+ float JetPhi = 0.0 ;
69+ int NTracks = -1 ;
70+ int NSV = -1 ;
71+ float JetMass = 0.0 ;
8072};
8173
8274struct BJetTrackParams {
83- double mTrackpT = 0.0 ;
84- double mTrackEta = 0.0 ;
85- double mDotProdTrackJet = 0.0 ;
86- double mDotProdTrackJetOverJet = 0.0 ;
87- double mDeltaRJetTrack = 0.0 ;
88- double mSignedIP2D = 0.0 ;
89- double mSignedIP2DSign = 0.0 ;
90- double mSignedIP3D = 0.0 ;
91- double mSignedIP3DSign = 0.0 ;
92- double mMomFraction = 0.0 ;
93- double mDeltaRTrackVertex = 0.0 ;
75+ double TrackpT = 0.0 ;
76+ double TrackEta = 0.0 ;
77+ double DotProdTrackJet = 0.0 ;
78+ double DotProdTrackJetOverJet = 0.0 ;
79+ double DeltaRJetTrack = 0.0 ;
80+ double SignedIP2D = 0.0 ;
81+ double SignedIP2DSign = 0.0 ;
82+ double SignedIP3D = 0.0 ;
83+ double SignedIP3DSign = 0.0 ;
84+ double MomFraction = 0.0 ;
85+ double DeltaRTrackVertex = 0.0 ;
86+ double TrackPhi = 0.0 ;
87+ double TrackCharge = 0.0 ;
88+ double TrackITSChi2NCl = 0.0 ;
89+ double TrackTPCChi2NCl = 0.0 ;
90+ double TrackITSNCls = 0.0 ;
91+ double TrackTPCNCls = 0.0 ;
92+ double TrackTPCNCrossedRows = 0.0 ;
93+ int TrackOrigin = -1 ;
94+ int TrackVtxIndex = -1 ;
9495};
9596
9697struct BJetSVParams {
97- double mSVpT = 0.0 ;
98- double mDeltaRSVJet = 0.0 ;
99- double mSVMass = 0.0 ;
100- double mSVfE = 0.0 ;
101- double mIPXY = 0.0 ;
102- double mCPA = 0.0 ;
103- double mChi2PCA = 0.0 ;
104- double mDispersion = 0.0 ;
105- double mDecayLength2D = 0.0 ;
106- double mDecayLength2DError = 0.0 ;
107- double mDecayLength3D = 0.0 ;
108- double mDecayLength3DError = 0.0 ;
109- };
110-
111- // ONNX Runtime tensor (Ort::Value) allocator for using customized inputs of ML models.
112- class TensorAllocator
113- {
114- protected:
115- #if !__has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
116- Ort::MemoryInfo mem_info;
117- #endif
118- public:
119- TensorAllocator ()
120- #if !__has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
121- : mem_info(Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault))
122- #endif
123- {
124- }
125- ~TensorAllocator () = default ;
126- template <typename T>
127- Ort::Value createTensor (std::vector<T>& input, std::vector<int64_t >& inputShape)
128- {
129- #if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
130- return Ort::Experimental::Value::CreateTensor<T>(input.data (), input.size (), inputShape);
131- #else
132- return Ort::Value::CreateTensor<T>(mem_info, input.data (), input.size (), inputShape.data (), inputShape.size ());
133- #endif
134- }
135- };
136-
137- // TensorAllocator for GNN b-jet tagger
138- class GNNBjetAllocator : public TensorAllocator
139- {
140- private:
141- int64_t nJetFeat;
142- int64_t nTrkFeat;
143- int64_t nFlav;
144- int64_t nTrkOrigin;
145- int64_t maxNNodes;
146-
147- std::vector<float > tfJetMean;
148- std::vector<float > tfJetStdev;
149- std::vector<float > tfTrkMean;
150- std::vector<float > tfTrkStdev;
151-
152- std::vector<std::vector<int64_t >> edgesList;
153-
154- // Jet feature normalization
155- template <typename T>
156- T jetFeatureTransform (T feat, int idx) const
157- {
158- return (feat - tfJetMean[idx]) / tfJetStdev[idx];
159- }
160-
161- // Track feature normalization
162- template <typename T>
163- T trkFeatureTransform (T feat, int idx) const
164- {
165- return (feat - tfTrkMean[idx]) / tfTrkStdev[idx];
166- }
167-
168- // Edge input of GNN (fully-connected graph)
169- void setEdgesList (void )
170- {
171- for (int64_t nNodes = 0 ; nNodes <= maxNNodes; ++nNodes) {
172- std::vector<std::pair<int64_t , int64_t >> edges;
173- // Generate all permutations of (i, j) where i != j
174- for (int64_t i = 0 ; i < nNodes; ++i) {
175- for (int64_t j = 0 ; j < nNodes; ++j) {
176- if (i != j) {
177- edges.emplace_back (i, j);
178- }
179- }
180- }
181- // Add self-loops (i, i)
182- for (int64_t i = 0 ; i < nNodes; ++i) {
183- edges.emplace_back (i, i);
184- }
185- // Flatten
186- std::vector<int64_t > flattenedEdges;
187- for (const auto & edge : edges) {
188- flattenedEdges.push_back (edge.first );
189- }
190- for (const auto & edge : edges) {
191- flattenedEdges.push_back (edge.second );
192- }
193- edgesList.push_back (flattenedEdges);
194- }
195- }
196-
197- // Replace NaN in a vector into value
198- template <typename T>
199- static int replaceNaN (std::vector<T>& vec, T value)
200- {
201- int numNaN = 0 ;
202- for (auto & el : vec) {
203- if (std::isnan (el)) {
204- el = value;
205- ++numNaN;
206- }
207- }
208- return numNaN;
209- }
210-
211- public:
212- GNNBjetAllocator () : TensorAllocator(), nJetFeat(4 ), nTrkFeat(13 ), nFlav(3 ), nTrkOrigin(5 ), maxNNodes(40 ) {}
213- GNNBjetAllocator (int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector<float >& tfJetMean, std::vector<float >& tfJetStdev, std::vector<float >& tfTrkMean, std::vector<float >& tfTrkStdev, int64_t maxNNodes = 40 )
214- : TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev)
215- {
216- setEdgesList ();
217- }
218- ~GNNBjetAllocator () = default ;
219-
220- // Copy operator for initializing GNNBjetAllocator using Configurable values
221- GNNBjetAllocator& operator =(const GNNBjetAllocator& other)
222- {
223- nJetFeat = other.nJetFeat ;
224- nTrkFeat = other.nTrkFeat ;
225- nFlav = other.nFlav ;
226- nTrkOrigin = other.nTrkOrigin ;
227- maxNNodes = other.maxNNodes ;
228- tfJetMean = other.tfJetMean ;
229- tfJetStdev = other.tfJetStdev ;
230- tfTrkMean = other.tfTrkMean ;
231- tfTrkStdev = other.tfTrkStdev ;
232- setEdgesList ();
233- return *this ;
234- }
235-
236- // Allocate & Return GNN input tensors (std::vector<Ort::Value>)
237- template <typename T>
238- void getGNNInput (std::vector<T>& jetFeat, std::vector<std::vector<T>>& trkFeat, std::vector<T>& feat, std::vector<Ort::Value>& gnnInput)
239- {
240- int64_t nNodes = trkFeat.size ();
241-
242- std::vector<int64_t > edgesShape{2 , nNodes * nNodes};
243- gnnInput.emplace_back (createTensor (edgesList[nNodes], edgesShape));
244-
245- std::vector<int64_t > featShape{nNodes, nJetFeat + nTrkFeat};
246-
247- int numNaN = replaceNaN (jetFeat, 0 .f );
248- for (auto & aTrkFeat : trkFeat) {
249- for (size_t i = 0 ; i < jetFeat.size (); ++i)
250- feat.push_back (jetFeatureTransform (jetFeat[i], i));
251- numNaN += replaceNaN (aTrkFeat, 0 .f );
252- for (size_t i = 0 ; i < aTrkFeat.size (); ++i)
253- feat.push_back (trkFeatureTransform (aTrkFeat[i], i));
254- }
255-
256- gnnInput.emplace_back (createTensor (feat, featShape));
257-
258- if (numNaN > 0 ) {
259- LOGF (info, " NaN found in GNN input feature, number of NaN: %d" , numNaN);
260- }
261- }
98+ double SVpT = 0.0 ;
99+ double DeltaRSVJet = 0.0 ;
100+ double SVMass = 0.0 ;
101+ double SVfE = 0.0 ;
102+ double IPxy = 0.0 ;
103+ double CPA = 0.0 ;
104+ double Chi2PCA = 0.0 ;
105+ double Dispersion = 0.0 ;
106+ double DecayLength2D = 0.0 ;
107+ double DecayLength2DError = 0.0 ;
108+ double DecayLength3D = 0.0 ;
109+ double DecayLength3DError = 0.0 ;
262110};
263111
264112// ________________________________________________________________________
265113bool isBHadron (int pc)
266114{
115+ using o2::constants::physics::Pdg;
267116 std::vector<int > bPdG = {Pdg::kB0 , Pdg::kBPlus , 10511 , 10521 , 513 , 523 , 10513 , 10523 , 20513 , 20523 , 20513 , 20523 , 515 , 525 , Pdg::kBS , 10531 , 533 , 10533 ,
268117 20533 , 535 , 541 , 10541 , 543 , 10543 , 20543 , 545 , 551 , 10551 , 100551 , 110551 , 200551 , 210551 , 553 , 10553 , 20553 ,
269118 30553 , 100553 , 110553 , 120553 , 130553 , 200553 , 210553 , 220553 , 300553 , 9000533 , 9010553 , 555 , 10555 , 20555 ,
270- 100555 , 110555 , 120555 , 200555 , 557 , 100557 , Pdg::kLambdaB0 , 5112 , 5212 , 5222 , 5114 , 5214 , 5224 , 5132 , kXiB0 , 5312 , 5322 ,
119+ 100555 , 110555 , 120555 , 200555 , 557 , 100557 , Pdg::kLambdaB0 , 5112 , 5212 , 5222 , 5114 , 5214 , 5224 , 5132 , Pdg:: kXiB0 , 5312 , 5322 ,
271120 5314 , 5324 , 5332 , 5334 , 5142 , 5242 , 5412 , 5422 , 5414 , 5424 , 5342 , 5432 , 5434 , 5442 , 5444 , 5512 , 5522 , 5514 , 5524 ,
272121 5532 , 5534 , 5542 , 5544 , 5554 };
273122
@@ -276,6 +125,7 @@ bool isBHadron(int pc)
276125// ________________________________________________________________________
277126bool isCHadron (int pc)
278127{
128+ using o2::constants::physics::Pdg;
279129 std::vector<int > bPdG = {Pdg::kDPlus , Pdg::kD0 , Pdg::kD0StarPlus , Pdg::kD0Star0 , 413 , 423 , 10413 , 10423 , 20431 , 20423 , Pdg::kD2StarPlus , Pdg::kD2Star0 , Pdg::kDS , 10431 , Pdg::kDSStar , Pdg::kDS1 , 20433 , Pdg::kDS2Star , 441 ,
280130 10441 , 100441 , Pdg::kJPsi , 10443 , Pdg::kChiC1 , 100443 , 30443 , 9000443 , 9010443 , 9020443 , 445 , 100445 , Pdg::kLambdaCPlus , Pdg::kSigmaCPlusPlus , 4212 , Pdg::kSigmaC0 ,
281131 4224 , 4214 , 4114 , Pdg::kXiCPlus , Pdg::kXiC0 , 4322 , 4312 , 4324 , 4314 , Pdg::kOmegaC0 , 4334 , 4412 , Pdg::kXiCCPlusPlus , 4414 , 4424 , 4432 , 4434 , 4444 };
@@ -1106,48 +956,6 @@ int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyT
1106956 return nVertices;
1107957}
1108958
1109- std::vector<std::vector<float >> getInputsForML (BJetParams jetparams, std::vector<BJetTrackParams>& tracksParams, std::vector<BJetSVParams>& svsParams, int maxJetConst = 10 )
1110- {
1111- std::vector<float > jetInput = {jetparams.mJetpT , jetparams.mJetEta , jetparams.mJetPhi , static_cast <float >(jetparams.mNTracks ), static_cast <float >(jetparams.mNSV ), jetparams.mJetMass };
1112- std::vector<float > tracksInputFlat;
1113- std::vector<float > svsInputFlat;
1114-
1115- for (int iconstit = 0 ; iconstit < maxJetConst; iconstit++) {
1116-
1117- tracksInputFlat.push_back (tracksParams[iconstit].mTrackpT );
1118- tracksInputFlat.push_back (tracksParams[iconstit].mTrackEta );
1119- tracksInputFlat.push_back (tracksParams[iconstit].mDotProdTrackJet );
1120- tracksInputFlat.push_back (tracksParams[iconstit].mDotProdTrackJetOverJet );
1121- tracksInputFlat.push_back (tracksParams[iconstit].mDeltaRJetTrack );
1122- tracksInputFlat.push_back (tracksParams[iconstit].mSignedIP2D );
1123- tracksInputFlat.push_back (tracksParams[iconstit].mSignedIP2DSign );
1124- tracksInputFlat.push_back (tracksParams[iconstit].mSignedIP3D );
1125- tracksInputFlat.push_back (tracksParams[iconstit].mSignedIP3DSign );
1126- tracksInputFlat.push_back (tracksParams[iconstit].mMomFraction );
1127- tracksInputFlat.push_back (tracksParams[iconstit].mDeltaRTrackVertex );
1128-
1129- svsInputFlat.push_back (svsParams[iconstit].mSVpT );
1130- svsInputFlat.push_back (svsParams[iconstit].mDeltaRSVJet );
1131- svsInputFlat.push_back (svsParams[iconstit].mSVMass );
1132- svsInputFlat.push_back (svsParams[iconstit].mSVfE );
1133- svsInputFlat.push_back (svsParams[iconstit].mIPXY );
1134- svsInputFlat.push_back (svsParams[iconstit].mCPA );
1135- svsInputFlat.push_back (svsParams[iconstit].mChi2PCA );
1136- svsInputFlat.push_back (svsParams[iconstit].mDispersion );
1137- svsInputFlat.push_back (svsParams[iconstit].mDecayLength2D );
1138- svsInputFlat.push_back (svsParams[iconstit].mDecayLength2DError );
1139- svsInputFlat.push_back (svsParams[iconstit].mDecayLength3D );
1140- svsInputFlat.push_back (svsParams[iconstit].mDecayLength3DError );
1141- }
1142-
1143- std::vector<std::vector<float >> totalInput;
1144- totalInput.push_back (jetInput);
1145- totalInput.push_back (tracksInputFlat);
1146- totalInput.push_back (svsInputFlat);
1147-
1148- return totalInput;
1149- }
1150-
1151959// Looping over the SV info and putting them in the input vector
1152960template <typename AnalysisJet, typename AnyTracks, typename SecondaryVertices>
1153961void analyzeJetSVInfo4ML (AnalysisJet const & myJet, AnyTracks const & /* allTracks*/ , SecondaryVertices const & /* allSVs*/ , std::vector<BJetSVParams>& svsParams, float svPtMin = 1.0 , int svReductionFactor = 3 )
@@ -1193,7 +1001,7 @@ void analyzeJetTrackInfo4ML(AnalysisJet const& analysisJet, AnyTracks const& /*a
11931001
11941002 double deltaRJetTrack = jetutilities::deltaR (analysisJet, constituent);
11951003 double dotProduct = RecoDecay::dotProd (std::array<float , 3 >{analysisJet.px (), analysisJet.py (), analysisJet.pz ()}, std::array<float , 3 >{constituent.px (), constituent.py (), constituent.pz ()});
1196- int sign = jettaggingutilities:: getGeoSign (analysisJet, constituent);
1004+ int sign = getGeoSign (analysisJet, constituent);
11971005
11981006 float rClosestSV = 10 .;
11991007 for (const auto & candSV : analysisJet.template secondaryVertices_as <SecondaryVertices>()) {
@@ -1207,7 +1015,32 @@ void analyzeJetTrackInfo4ML(AnalysisJet const& analysisJet, AnyTracks const& /*a
12071015 }
12081016
12091017 auto compare = [](BJetTrackParams& tr1, BJetTrackParams& tr2) {
1210- return (tr1.mSignedIP2D / tr1.mSignedIP2DSign ) > (tr2.mSignedIP2D / tr2.mSignedIP2DSign );
1018+ return (tr1.SignedIP2D / tr1.SignedIP2DSign ) > (tr2.SignedIP2D / tr2.SignedIP2DSign );
1019+ };
1020+
1021+ // Sort the tracks based on their IP significance in descending order
1022+ std::sort (tracksParams.begin (), tracksParams.end (), compare);
1023+ }
1024+
1025+ // Looping over the track info and putting them in the input vector without using any SV info
1026+ template <typename AnalysisJet, typename AnyTracks>
1027+ void analyzeJetTrackInfo4MLnoSV (AnalysisJet const & analysisJet, AnyTracks const & /* allTracks*/ , std::vector<BJetTrackParams>& tracksParams, float trackPtMin = 0.5 )
1028+ {
1029+ for (const auto & constituent : analysisJet.template tracks_as <AnyTracks>()) {
1030+
1031+ if (constituent.pt () < trackPtMin) {
1032+ continue ;
1033+ }
1034+
1035+ double deltaRJetTrack = jetutilities::deltaR (analysisJet, constituent);
1036+ double dotProduct = RecoDecay::dotProd (std::array<float , 3 >{analysisJet.px (), analysisJet.py (), analysisJet.pz ()}, std::array<float , 3 >{constituent.px (), constituent.py (), constituent.pz ()});
1037+ int sign = getGeoSign (analysisJet, constituent);
1038+
1039+ tracksParams.emplace_back (BJetTrackParams{constituent.pt (), constituent.eta (), dotProduct, dotProduct / analysisJet.p (), deltaRJetTrack, std::abs (constituent.dcaXY ()) * sign, constituent.sigmadcaXY (), std::abs (constituent.dcaXYZ ()) * sign, constituent.sigmadcaXYZ (), constituent.p () / analysisJet.p (), 0.0 });
1040+ }
1041+
1042+ auto compare = [](BJetTrackParams& tr1, BJetTrackParams& tr2) {
1043+ return (tr1.SignedIP2D / tr1.SignedIP2DSign ) > (tr2.SignedIP2D / tr2.SignedIP2DSign );
12111044 };
12121045
12131046 // Sort the tracks based on their IP significance in descending order
@@ -1224,7 +1057,7 @@ void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*
12241057 continue ;
12251058 }
12261059
1227- int sign = jettaggingutilities:: getGeoSign (analysisJet, constituent);
1060+ int sign = getGeoSign (analysisJet, constituent);
12281061
12291062 auto origConstit = constituent.template track_as <AnyOriginalTracks>();
12301063
@@ -1245,7 +1078,7 @@ void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*
12451078
12461079// Discriminant value for GNN b-jet tagging
12471080template <typename T>
1248- T Db (const std::vector<T>& logits, double fC = 0.018 )
1081+ T getDb (const std::vector<T>& logits, double fC = 0.018 )
12491082{
12501083 auto softmax = [](const std::vector<T>& logits) {
12511084 std::vector<T> res;
0 commit comments