Skip to content

Commit 71cfc02

Browse files
authored
[PWGJE] Implementing a utility class for b-jet tagging ML methods (#9580)
1 parent da115a5 commit 71cfc02

File tree

5 files changed

+606
-273
lines changed

5 files changed

+606
-273
lines changed

PWGJE/Core/JetTaggingUtilities.h

Lines changed: 70 additions & 237 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,6 @@
3838
#include "Common/Core/trackUtilities.h"
3939
#include "PWGJE/Core/JetUtilities.h"
4040

41-
#if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
42-
#include <onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>
43-
#else
44-
#include <onnxruntime_cxx_api.h>
45-
#endif
46-
47-
using namespace o2::constants::physics;
48-
4941
enum JetTaggingSpecies {
5042
none = 0,
5143
charm = 1,
@@ -71,203 +63,60 @@ namespace jettaggingutilities
7163
const int cmTomum = 10000; // using cm -> #mum for impact parameter (dca)
7264

7365
struct BJetParams {
74-
float mJetpT = 0.0;
75-
float mJetEta = 0.0;
76-
float mJetPhi = 0.0;
77-
int mNTracks = -1;
78-
int mNSV = -1;
79-
float mJetMass = 0.0;
66+
float JetpT = 0.0;
67+
float JetEta = 0.0;
68+
float JetPhi = 0.0;
69+
int NTracks = -1;
70+
int NSV = -1;
71+
float JetMass = 0.0;
8072
};
8173

8274
struct BJetTrackParams {
83-
double mTrackpT = 0.0;
84-
double mTrackEta = 0.0;
85-
double mDotProdTrackJet = 0.0;
86-
double mDotProdTrackJetOverJet = 0.0;
87-
double mDeltaRJetTrack = 0.0;
88-
double mSignedIP2D = 0.0;
89-
double mSignedIP2DSign = 0.0;
90-
double mSignedIP3D = 0.0;
91-
double mSignedIP3DSign = 0.0;
92-
double mMomFraction = 0.0;
93-
double mDeltaRTrackVertex = 0.0;
75+
double TrackpT = 0.0;
76+
double TrackEta = 0.0;
77+
double DotProdTrackJet = 0.0;
78+
double DotProdTrackJetOverJet = 0.0;
79+
double DeltaRJetTrack = 0.0;
80+
double SignedIP2D = 0.0;
81+
double SignedIP2DSign = 0.0;
82+
double SignedIP3D = 0.0;
83+
double SignedIP3DSign = 0.0;
84+
double MomFraction = 0.0;
85+
double DeltaRTrackVertex = 0.0;
86+
double TrackPhi = 0.0;
87+
double TrackCharge = 0.0;
88+
double TrackITSChi2NCl = 0.0;
89+
double TrackTPCChi2NCl = 0.0;
90+
double TrackITSNCls = 0.0;
91+
double TrackTPCNCls = 0.0;
92+
double TrackTPCNCrossedRows = 0.0;
93+
int TrackOrigin = -1;
94+
int TrackVtxIndex = -1;
9495
};
9596

9697
struct BJetSVParams {
97-
double mSVpT = 0.0;
98-
double mDeltaRSVJet = 0.0;
99-
double mSVMass = 0.0;
100-
double mSVfE = 0.0;
101-
double mIPXY = 0.0;
102-
double mCPA = 0.0;
103-
double mChi2PCA = 0.0;
104-
double mDispersion = 0.0;
105-
double mDecayLength2D = 0.0;
106-
double mDecayLength2DError = 0.0;
107-
double mDecayLength3D = 0.0;
108-
double mDecayLength3DError = 0.0;
109-
};
110-
111-
// ONNX Runtime tensor (Ort::Value) allocator for using customized inputs of ML models.
112-
// Helper that builds Ort::Value tensors from std::vector buffers.
// Two ONNX Runtime flavours are selected at compile time: when
// <onnxruntime/core/session/onnxruntime_cxx_api.h> is found, the
// Ort::Experimental API is used; otherwise the plain C++ API is used together
// with an explicit Ort::MemoryInfo member.
class TensorAllocator
113-
{
114-
protected:
115-
#if !__has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
116-
// Memory descriptor passed to Ort::Value::CreateTensor; it only exists in the non-experimental build.
Ort::MemoryInfo mem_info;
117-
#endif
118-
public:
119-
// Creates the CPU memory descriptor (arena allocator, default memory type) when the member exists.
TensorAllocator()
120-
#if !__has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
121-
: mem_info(Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault))
122-
#endif
123-
{
124-
}
125-
~TensorAllocator() = default;
126-
// Builds a tensor with element type T from `input`, using `inputShape` as its dimensions.
// Only input.data() / input.size() and the shape are handed to ONNX Runtime.
// NOTE(review): presumably the tensor refers to the caller's buffer instead of copying it,
// in which case `input` must outlive the returned Ort::Value - confirm against the ONNX Runtime API.
template <typename T>
127-
Ort::Value createTensor(std::vector<T>& input, std::vector<int64_t>& inputShape)
128-
{
129-
#if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
130-
return Ort::Experimental::Value::CreateTensor<T>(input.data(), input.size(), inputShape);
131-
#else
132-
return Ort::Value::CreateTensor<T>(mem_info, input.data(), input.size(), inputShape.data(), inputShape.size());
133-
#endif
134-
}
135-
};
136-
137-
// TensorAllocator for GNN b-jet tagger
138-
class GNNBjetAllocator : public TensorAllocator
139-
{
140-
private:
141-
int64_t nJetFeat;
142-
int64_t nTrkFeat;
143-
int64_t nFlav;
144-
int64_t nTrkOrigin;
145-
int64_t maxNNodes;
146-
147-
std::vector<float> tfJetMean;
148-
std::vector<float> tfJetStdev;
149-
std::vector<float> tfTrkMean;
150-
std::vector<float> tfTrkStdev;
151-
152-
std::vector<std::vector<int64_t>> edgesList;
153-
154-
// Jet feature normalization
155-
// Standardises one jet feature: subtracts tfJetMean[idx] and divides by tfJetStdev[idx].
// `idx` is not range-checked and a zero entry in tfJetStdev is not guarded against.
template <typename T>
155-
T jetFeatureTransform(T feat, int idx) const
156-
{
157-
return (feat - tfJetMean[idx]) / tfJetStdev[idx];
158-
}
160-
161-
// Track feature normalization
162-
// Standardises one track feature: subtracts tfTrkMean[idx] and divides by tfTrkStdev[idx].
// `idx` is not range-checked and a zero entry in tfTrkStdev is not guarded against.
template <typename T>
162-
T trkFeatureTransform(T feat, int idx) const
163-
{
164-
return (feat - tfTrkMean[idx]) / tfTrkStdev[idx];
165-
}
167-
168-
// Edge input of GNN (fully-connected graph)
169-
// Precomputes the edge index list for every node count from 0 to maxNNodes (inclusive)
// and appends one flattened list per node count to edgesList.
// For k nodes the list holds k*k (i, j) pairs - k*(k-1) pairs with i != j followed by
// k self-loops - stored as all first indices followed by all second indices (2*k*k values).
// NOTE(review): edgesList is only appended to, never cleared. Both the parameterised
// constructor and operator= call this function, so assigning to an object that already
// holds edge lists keeps the old entries at the front, where index lookups still find them.
void setEdgesList(void)
169-
{
170-
for (int64_t nNodes = 0; nNodes <= maxNNodes; ++nNodes) {
171-
std::vector<std::pair<int64_t, int64_t>> edges;
172-
// Generate all ordered pairs (i, j) where i != j
173-
for (int64_t i = 0; i < nNodes; ++i) {
174-
for (int64_t j = 0; j < nNodes; ++j) {
175-
if (i != j) {
176-
edges.emplace_back(i, j);
177-
}
178-
}
179-
}
180-
// Add self-loops (i, i)
181-
for (int64_t i = 0; i < nNodes; ++i) {
182-
edges.emplace_back(i, i);
183-
}
184-
// Flatten: first element of every pair, then second element of every pair
185-
std::vector<int64_t> flattenedEdges;
186-
for (const auto& edge : edges) {
187-
flattenedEdges.push_back(edge.first);
188-
}
189-
for (const auto& edge : edges) {
190-
flattenedEdges.push_back(edge.second);
191-
}
192-
// Position in edgesList equals the node count as long as the list started out empty
edgesList.push_back(flattenedEdges);
193-
}
194-
}
196-
197-
// Replace NaN in a vector into value
198-
// Overwrites every NaN element of `vec` in place with `value`.
// Returns the number of elements that were replaced.
template <typename T>
198-
static int replaceNaN(std::vector<T>& vec, T value)
199-
{
200-
int numNaN = 0;
201-
for (auto& el : vec) {
202-
if (std::isnan(el)) {
203-
el = value;
204-
++numNaN;
205-
}
206-
}
207-
return numNaN;
208-
}
210-
211-
public:
212-
// Default configuration: 4 jet features, 13 track features, 3 flavours, 5 track origins, at most 40 nodes.
// NOTE(review): this constructor neither fills the normalisation vectors nor calls setEdgesList(),
// so edgesList and the tf* vectors stay empty until a fully configured object is assigned
// (operator= copies the vectors and runs setEdgesList()).
GNNBjetAllocator() : TensorAllocator(), nJetFeat(4), nTrkFeat(13), nFlav(3), nTrkOrigin(5), maxNNodes(40) {}
213-
// Full configuration: stores the feature counts, copies the four normalisation vectors
// (taken by non-const reference but only copied from) and precomputes the edge lists.
GNNBjetAllocator(int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector<float>& tfJetMean, std::vector<float>& tfJetStdev, std::vector<float>& tfTrkMean, std::vector<float>& tfTrkStdev, int64_t maxNNodes = 40)
214-
: TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev)
215-
{
216-
setEdgesList();
217-
}
218-
~GNNBjetAllocator() = default;
219-
220-
// Copy operator for initializing GNNBjetAllocator using Configurable values
221-
// Copies the feature counts and normalisation vectors from `other`, then rebuilds the
// edge lists with setEdgesList() instead of copying other.edgesList.
// NOTE(review): setEdgesList() appends to edgesList, so assigning to an object whose
// edgesList is already filled (including self-assignment) keeps the previous entries in front.
GNNBjetAllocator& operator=(const GNNBjetAllocator& other)
222-
{
223-
nJetFeat = other.nJetFeat;
224-
nTrkFeat = other.nTrkFeat;
225-
nFlav = other.nFlav;
226-
nTrkOrigin = other.nTrkOrigin;
227-
maxNNodes = other.maxNNodes;
228-
tfJetMean = other.tfJetMean;
229-
tfJetStdev = other.tfJetStdev;
230-
tfTrkMean = other.tfTrkMean;
231-
tfTrkStdev = other.tfTrkStdev;
232-
setEdgesList();
233-
return *this;
234-
}
235-
236-
// Allocate & Return GNN input tensors (std::vector<Ort::Value>)
237-
// Appends two tensors to `gnnInput`: the edge index tensor and the node feature tensor.
//   jetFeat  - jet-level features; NaNs are replaced by 0 in place before normalisation
//   trkFeat  - one feature vector per track (= graph node); NaNs are replaced by 0 in place
//   feat     - caller-owned buffer that receives, per track, the normalised jet features
//              followed by the normalised track features; the feature tensor is created from it
//   gnnInput - output list, edge tensor first, feature tensor second
// The literal 0.f passed to replaceNaN only deduces consistently when T is float.
// NOTE(review): edgesList[nNodes] is not range-checked, so more than maxNNodes tracks
// (or an allocator whose edge lists were never built) reads out of bounds.
// NOTE(review): featShape assumes `feat` starts empty and that jetFeat / each track vector
// hold exactly nJetFeat / nTrkFeat entries; none of this is verified here.
template <typename T>
238-
void getGNNInput(std::vector<T>& jetFeat, std::vector<std::vector<T>>& trkFeat, std::vector<T>& feat, std::vector<Ort::Value>& gnnInput)
239-
{
240-
int64_t nNodes = trkFeat.size();
241-
242-
// Edge tensor: 2 rows with nNodes * nNodes entries each, taken from the precomputed list
std::vector<int64_t> edgesShape{2, nNodes * nNodes};
243-
gnnInput.emplace_back(createTensor(edgesList[nNodes], edgesShape));
244-
245-
std::vector<int64_t> featShape{nNodes, nJetFeat + nTrkFeat};
246-
247-
int numNaN = replaceNaN(jetFeat, 0.f);
248-
for (auto& aTrkFeat : trkFeat) {
249-
// The same jet features are repeated in front of every track's own features
for (size_t i = 0; i < jetFeat.size(); ++i)
250-
feat.push_back(jetFeatureTransform(jetFeat[i], i));
251-
numNaN += replaceNaN(aTrkFeat, 0.f);
252-
for (size_t i = 0; i < aTrkFeat.size(); ++i)
253-
feat.push_back(trkFeatureTransform(aTrkFeat[i], i));
254-
}
255-
256-
gnnInput.emplace_back(createTensor(feat, featShape));
257-
258-
// Replaced NaNs are only reported, the input is still used
if (numNaN > 0) {
259-
LOGF(info, "NaN found in GNN input feature, number of NaN: %d", numNaN);
260-
}
261-
}
98+
double SVpT = 0.0;
99+
double DeltaRSVJet = 0.0;
100+
double SVMass = 0.0;
101+
double SVfE = 0.0;
102+
double IPxy = 0.0;
103+
double CPA = 0.0;
104+
double Chi2PCA = 0.0;
105+
double Dispersion = 0.0;
106+
double DecayLength2D = 0.0;
107+
double DecayLength2DError = 0.0;
108+
double DecayLength3D = 0.0;
109+
double DecayLength3DError = 0.0;
262110
};
263111

264112
//________________________________________________________________________
265113
bool isBHadron(int pc)
266114
{
115+
using o2::constants::physics::Pdg;
267116
std::vector<int> bPdG = {Pdg::kB0, Pdg::kBPlus, 10511, 10521, 513, 523, 10513, 10523, 20513, 20523, 20513, 20523, 515, 525, Pdg::kBS, 10531, 533, 10533,
268117
20533, 535, 541, 10541, 543, 10543, 20543, 545, 551, 10551, 100551, 110551, 200551, 210551, 553, 10553, 20553,
269118
30553, 100553, 110553, 120553, 130553, 200553, 210553, 220553, 300553, 9000533, 9010553, 555, 10555, 20555,
270-
100555, 110555, 120555, 200555, 557, 100557, Pdg::kLambdaB0, 5112, 5212, 5222, 5114, 5214, 5224, 5132, kXiB0, 5312, 5322,
119+
100555, 110555, 120555, 200555, 557, 100557, Pdg::kLambdaB0, 5112, 5212, 5222, 5114, 5214, 5224, 5132, Pdg::kXiB0, 5312, 5322,
271120
5314, 5324, 5332, 5334, 5142, 5242, 5412, 5422, 5414, 5424, 5342, 5432, 5434, 5442, 5444, 5512, 5522, 5514, 5524,
272121
5532, 5534, 5542, 5544, 5554};
273122

@@ -276,6 +125,7 @@ bool isBHadron(int pc)
276125
//________________________________________________________________________
277126
bool isCHadron(int pc)
278127
{
128+
using o2::constants::physics::Pdg;
279129
std::vector<int> bPdG = {Pdg::kDPlus, Pdg::kD0, Pdg::kD0StarPlus, Pdg::kD0Star0, 413, 423, 10413, 10423, 20431, 20423, Pdg::kD2StarPlus, Pdg::kD2Star0, Pdg::kDS, 10431, Pdg::kDSStar, Pdg::kDS1, 20433, Pdg::kDS2Star, 441,
280130
10441, 100441, Pdg::kJPsi, 10443, Pdg::kChiC1, 100443, 30443, 9000443, 9010443, 9020443, 445, 100445, Pdg::kLambdaCPlus, Pdg::kSigmaCPlusPlus, 4212, Pdg::kSigmaC0,
281131
4224, 4214, 4114, Pdg::kXiCPlus, Pdg::kXiC0, 4322, 4312, 4324, 4314, Pdg::kOmegaC0, 4334, 4412, Pdg::kXiCCPlusPlus, 4414, 4424, 4432, 4434, 4444};
@@ -1106,48 +956,6 @@ int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyT
1106956
return nVertices;
1107957
}
1108958

1109-
// Packs the jet, track and secondary-vertex parameters into three flat float vectors.
// Returns {jetInput, tracksInputFlat, svsInputFlat}:
//   jetInput        - 6 values (pT, eta, phi, number of tracks, number of SVs, mass)
//   tracksInputFlat - 11 values for each of the first maxJetConst entries of tracksParams
//   svsInputFlat    - 12 values for each of the first maxJetConst entries of svsParams
// The track and SV members are double and are narrowed to float when pushed.
// NOTE(review): both vectors are indexed from 0 to maxJetConst - 1 without any size check,
// so the caller has to make sure each holds at least maxJetConst entries.
std::vector<std::vector<float>> getInputsForML(BJetParams jetparams, std::vector<BJetTrackParams>& tracksParams, std::vector<BJetSVParams>& svsParams, int maxJetConst = 10)
1110-
{
1111-
std::vector<float> jetInput = {jetparams.mJetpT, jetparams.mJetEta, jetparams.mJetPhi, static_cast<float>(jetparams.mNTracks), static_cast<float>(jetparams.mNSV), jetparams.mJetMass};
1112-
std::vector<float> tracksInputFlat;
1113-
std::vector<float> svsInputFlat;
1114-
1115-
for (int iconstit = 0; iconstit < maxJetConst; iconstit++) {
1116-
1117-
// Track slot iconstit: 11 consecutive values
tracksInputFlat.push_back(tracksParams[iconstit].mTrackpT);
1118-
tracksInputFlat.push_back(tracksParams[iconstit].mTrackEta);
1119-
tracksInputFlat.push_back(tracksParams[iconstit].mDotProdTrackJet);
1120-
tracksInputFlat.push_back(tracksParams[iconstit].mDotProdTrackJetOverJet);
1121-
tracksInputFlat.push_back(tracksParams[iconstit].mDeltaRJetTrack);
1122-
tracksInputFlat.push_back(tracksParams[iconstit].mSignedIP2D);
1123-
tracksInputFlat.push_back(tracksParams[iconstit].mSignedIP2DSign);
1124-
tracksInputFlat.push_back(tracksParams[iconstit].mSignedIP3D);
1125-
tracksInputFlat.push_back(tracksParams[iconstit].mSignedIP3DSign);
1126-
tracksInputFlat.push_back(tracksParams[iconstit].mMomFraction);
1127-
tracksInputFlat.push_back(tracksParams[iconstit].mDeltaRTrackVertex);
1128-
1129-
// Secondary-vertex slot iconstit: 12 consecutive values
svsInputFlat.push_back(svsParams[iconstit].mSVpT);
1130-
svsInputFlat.push_back(svsParams[iconstit].mDeltaRSVJet);
1131-
svsInputFlat.push_back(svsParams[iconstit].mSVMass);
1132-
svsInputFlat.push_back(svsParams[iconstit].mSVfE);
1133-
svsInputFlat.push_back(svsParams[iconstit].mIPXY);
1134-
svsInputFlat.push_back(svsParams[iconstit].mCPA);
1135-
svsInputFlat.push_back(svsParams[iconstit].mChi2PCA);
1136-
svsInputFlat.push_back(svsParams[iconstit].mDispersion);
1137-
svsInputFlat.push_back(svsParams[iconstit].mDecayLength2D);
1138-
svsInputFlat.push_back(svsParams[iconstit].mDecayLength2DError);
1139-
svsInputFlat.push_back(svsParams[iconstit].mDecayLength3D);
1140-
svsInputFlat.push_back(svsParams[iconstit].mDecayLength3DError);
1141-
}
1142-
1143-
// Fixed output order: jet, tracks, secondary vertices
std::vector<std::vector<float>> totalInput;
1144-
totalInput.push_back(jetInput);
1145-
totalInput.push_back(tracksInputFlat);
1146-
totalInput.push_back(svsInputFlat);
1147-
1148-
return totalInput;
1149-
}
1150-
1151959
// Looping over the SV info and putting them in the input vector
1152960
template <typename AnalysisJet, typename AnyTracks, typename SecondaryVertices>
1153961
void analyzeJetSVInfo4ML(AnalysisJet const& myJet, AnyTracks const& /*allTracks*/, SecondaryVertices const& /*allSVs*/, std::vector<BJetSVParams>& svsParams, float svPtMin = 1.0, int svReductionFactor = 3)
@@ -1193,7 +1001,7 @@ void analyzeJetTrackInfo4ML(AnalysisJet const& analysisJet, AnyTracks const& /*a
11931001

11941002
double deltaRJetTrack = jetutilities::deltaR(analysisJet, constituent);
11951003
double dotProduct = RecoDecay::dotProd(std::array<float, 3>{analysisJet.px(), analysisJet.py(), analysisJet.pz()}, std::array<float, 3>{constituent.px(), constituent.py(), constituent.pz()});
1196-
int sign = jettaggingutilities::getGeoSign(analysisJet, constituent);
1004+
int sign = getGeoSign(analysisJet, constituent);
11971005

11981006
float rClosestSV = 10.;
11991007
for (const auto& candSV : analysisJet.template secondaryVertices_as<SecondaryVertices>()) {
@@ -1207,7 +1015,32 @@ void analyzeJetTrackInfo4ML(AnalysisJet const& analysisJet, AnyTracks const& /*a
12071015
}
12081016

12091017
auto compare = [](BJetTrackParams& tr1, BJetTrackParams& tr2) {
1210-
return (tr1.mSignedIP2D / tr1.mSignedIP2DSign) > (tr2.mSignedIP2D / tr2.mSignedIP2DSign);
1018+
return (tr1.SignedIP2D / tr1.SignedIP2DSign) > (tr2.SignedIP2D / tr2.SignedIP2DSign);
1019+
};
1020+
1021+
// Sort the tracks based on their IP significance in descending order
1022+
std::sort(tracksParams.begin(), tracksParams.end(), compare);
1023+
}
1024+
1025+
// Looping over the track info and putting them in the input vector without using any SV info
1026+
template <typename AnalysisJet, typename AnyTracks>
1027+
void analyzeJetTrackInfo4MLnoSV(AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, std::vector<BJetTrackParams>& tracksParams, float trackPtMin = 0.5)
1028+
{
1029+
for (const auto& constituent : analysisJet.template tracks_as<AnyTracks>()) {
1030+
1031+
if (constituent.pt() < trackPtMin) {
1032+
continue;
1033+
}
1034+
1035+
double deltaRJetTrack = jetutilities::deltaR(analysisJet, constituent);
1036+
double dotProduct = RecoDecay::dotProd(std::array<float, 3>{analysisJet.px(), analysisJet.py(), analysisJet.pz()}, std::array<float, 3>{constituent.px(), constituent.py(), constituent.pz()});
1037+
int sign = getGeoSign(analysisJet, constituent);
1038+
1039+
tracksParams.emplace_back(BJetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), 0.0});
1040+
}
1041+
1042+
auto compare = [](BJetTrackParams& tr1, BJetTrackParams& tr2) {
1043+
return (tr1.SignedIP2D / tr1.SignedIP2DSign) > (tr2.SignedIP2D / tr2.SignedIP2DSign);
12111044
};
12121045

12131046
// Sort the tracks based on their IP significance in descending order
@@ -1224,7 +1057,7 @@ void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*
12241057
continue;
12251058
}
12261059

1227-
int sign = jettaggingutilities::getGeoSign(analysisJet, constituent);
1060+
int sign = getGeoSign(analysisJet, constituent);
12281061

12291062
auto origConstit = constituent.template track_as<AnyOriginalTracks>();
12301063

@@ -1245,7 +1078,7 @@ void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*
12451078

12461079
// Discriminant value for GNN b-jet tagging
12471080
template <typename T>
1248-
T Db(const std::vector<T>& logits, double fC = 0.018)
1081+
T getDb(const std::vector<T>& logits, double fC = 0.018)
12491082
{
12501083
auto softmax = [](const std::vector<T>& logits) {
12511084
std::vector<T> res;

0 commit comments

Comments
 (0)