Skip to content

Commit a71be96

Browse files
choich08365Changhwan Choialibuild
authored
[PWGJE] GNN b-jet updated (#14334)
Co-authored-by: Changhwan Choi <changhwan.choi@cern.ch> Co-authored-by: ALICE Action Bot <alibuild@cern.ch>
1 parent 3faeaa7 commit a71be96

File tree

6 files changed

+677
-270
lines changed

6 files changed

+677
-270
lines changed

PWGJE/Core/JetTaggingUtilities.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,11 +1097,11 @@ void analyzeJetTrackInfo4MLnoSV(AnalysisJet const& analysisJet, AnyTracks const&
10971097

10981098
// Looping over the track info and putting them in the input vector (for GNN b-jet tagging)
10991099
template <typename AnalysisJet, typename AnyTracks, typename AnyOriginalTracks>
1100-
void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, AnyOriginalTracks const& /*origTracks*/, std::vector<std::vector<float>>& tracksParams, float trackPtMin = 0.5, int64_t nMaxConstit = 40)
1100+
void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, AnyOriginalTracks const& /*origTracks*/, std::vector<std::vector<float>>& tracksParams, float trackPtMin = 0.5, float trackDcaXYMax = 10.0, float trackDcaZMax = 10.0, int64_t nMaxConstit = 40)
11011101
{
11021102
for (const auto& constituent : analysisJet.template tracks_as<AnyTracks>()) {
11031103

1104-
if (constituent.pt() < trackPtMin) {
1104+
if (constituent.pt() < trackPtMin || !trackAcceptanceWithDca(constituent, trackDcaXYMax, trackDcaZMax)) {
11051105
continue;
11061106
}
11071107

PWGJE/Core/MlResponseHfTagging.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <onnxruntime_c_api.h>
2424
#include <onnxruntime_cxx_api.h>
2525

26+
#include <cmath>
2627
#include <cstddef>
2728
#include <cstdint>
2829
#include <utility>
@@ -208,7 +209,7 @@ class MlResponseHfTagging : public MlResponse<TypeOutputScore>
208209
static int replaceNaN(std::vector<T>& vec, T value)
209210
{
210211
int numNaN = 0;
211-
for (auto& el : vec) {
212+
for (auto& el : vec) { // o2-linter: disable=const-ref-in-for-loop
212213
if (std::isnan(el) || std::isinf(el)) {
213214
el = value;
214215
++numNaN;
@@ -364,14 +365,14 @@ class GNNBjetAllocator : public TensorAllocator
364365
template <typename T>
365366
T jetFeatureTransform(T feat, int idx) const
366367
{
367-
return (feat - tfJetMean[idx]) / tfJetStdev[idx];
368+
return std::tanh((feat - tfJetMean[idx]) / tfJetStdev[idx]);
368369
}
369370

370371
// Track feature normalization
371372
template <typename T>
372373
T trkFeatureTransform(T feat, int idx) const
373374
{
374-
return (feat - tfTrkMean[idx]) / tfTrkStdev[idx];
375+
return std::tanh((feat - tfTrkMean[idx]) / tfTrkStdev[idx]);
375376
}
376377

377378
// Edge input of GNN (fully-connected graph)

PWGJE/TableProducer/jetTaggerHF.cxx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ struct JetTaggerHFTask {
615615
{
616616
for (const auto& jet : jets) {
617617
std::vector<std::vector<float>> trkFeat;
618-
jettaggingutilities::analyzeJetTrackInfo4GNN(jet, tracks, origTracks, trkFeat, trackPtMin, nJetConst);
618+
jettaggingutilities::analyzeJetTrackInfo4GNN(jet, tracks, origTracks, trkFeat, trackPtMin, trackDcaXYMax, trackDcaZMax, nJetConst);
619619

620620
std::vector<float> jetFeat{jet.pt(), jet.phi(), jet.eta(), jet.mass()};
621621

@@ -625,7 +625,13 @@ struct JetTaggerHFTask {
625625
tensorAlloc.getGNNInput(jetFeat, trkFeat, feat, gnnInput);
626626

627627
auto modelOutput = bMlResponse.getModelOutput(gnnInput, 0);
628-
scoreML[jet.globalIndex()] = jettaggingutilities::getDb(modelOutput, fC);
628+
float db = jettaggingutilities::getDb(modelOutput, fC);
629+
if (!std::isnan(db)) {
630+
scoreML[jet.globalIndex()] = db;
631+
} else {
632+
scoreML[jet.globalIndex()] = 999.;
633+
LOGF(debug, "doprocessAlgorithmGNN, Db is NaN (%d)", jet.globalIndex());
634+
}
629635
} else {
630636
scoreML[jet.globalIndex()] = -999.;
631637
LOGF(debug, "doprocessAlgorithmGNN, trkFeat.size() <= 0 (%d)", jet.globalIndex());

PWGJE/Tasks/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,8 @@ if(FastJet_FOUND)
336336
COMPONENT_NAME Analysis)
337337
o2physics_add_dpl_workflow(bjet-tagging-gnn
338338
SOURCES bjetTaggingGnn.cxx
339-
PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::PWGJECore O2Physics::AnalysisCore
339+
PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::PWGJECore O2Physics::AnalysisCore O2Physics::EventFilteringUtils
340+
O2Physics::AnalysisCCDB
340341
COMPONENT_NAME Analysis)
341342
o2physics_add_dpl_workflow(jet-shape
342343
SOURCES jetShape.cxx

0 commit comments

Comments
 (0)