Skip to content

Commit f144c25

Browse files
author
Changhwan Choi
committed
Configurable GNN input feature transform function
1 parent 6dd28c1 commit f144c25

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

PWGJE/Core/MlResponseHfTagging.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -361,18 +361,20 @@ class GNNBjetAllocator : public TensorAllocator
361361

362362
std::vector<std::vector<int64_t>> edgesList;
363363

364+
std::function<float(float)> tfFunc;
365+
364366
// Jet feature normalization
365367
template <typename T>
366368
T jetFeatureTransform(T feat, int idx) const
367369
{
368-
return std::tanh((feat - tfJetMean[idx]) / tfJetStdev[idx]);
370+
return tfFunc((feat - tfJetMean[idx]) / tfJetStdev[idx]);
369371
}
370372

371373
// Track feature normalization
372374
template <typename T>
373375
T trkFeatureTransform(T feat, int idx) const
374376
{
375-
return std::tanh((feat - tfTrkMean[idx]) / tfTrkStdev[idx]);
377+
return tfFunc((feat - tfTrkMean[idx]) / tfTrkStdev[idx]);
376378
}
377379

378380
// Edge input of GNN (fully-connected graph)
@@ -419,10 +421,17 @@ class GNNBjetAllocator : public TensorAllocator
419421
}
420422

421423
public:
422-
GNNBjetAllocator() : TensorAllocator(), nJetFeat(4), nTrkFeat(13), nFlav(3), nTrkOrigin(5), maxNNodes(40) {}
423-
GNNBjetAllocator(int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector<float>& tfJetMean, std::vector<float>& tfJetStdev, std::vector<float>& tfTrkMean, std::vector<float>& tfTrkStdev, int64_t maxNNodes = 40)
424-
: TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev)
424+
GNNBjetAllocator() : TensorAllocator(), nJetFeat(4), nTrkFeat(13), nFlav(3), nTrkOrigin(5), maxNNodes(40), tfFunc([](float x) { return x; }) {}
425+
GNNBjetAllocator(int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector<float>& tfJetMean, std::vector<float>& tfJetStdev, std::vector<float>& tfTrkMean, std::vector<float>& tfTrkStdev, int64_t maxNNodes = 40, std::string tfFuncType = "linear")
426+
: TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev), tfFunc([](float x) { return x; })
425427
{
428+
if (tfFuncType == "asinh") {
429+
tfFunc = [](float x) { return std::asinh(x); };
430+
} else if (tfFuncType == "tanh") {
431+
tfFunc = [](float x) { return std::tanh(x); };
432+
} else {
433+
tfFunc = [](float x) { return x; };
434+
}
426435
setEdgesList();
427436
}
428437
~GNNBjetAllocator() = default;
@@ -439,6 +448,8 @@ class GNNBjetAllocator : public TensorAllocator
439448
tfJetStdev = other.tfJetStdev;
440449
tfTrkMean = other.tfTrkMean;
441450
tfTrkStdev = other.tfTrkStdev;
451+
tfFunc = other.tfFunc;
452+
edgesList.clear();
442453
setEdgesList();
443454
return *this;
444455
}

PWGJE/TableProducer/jetTaggerHF.cxx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ struct JetTaggerHFTask {
139139
Configurable<std::vector<float>> transformFeatureTrkStdev{"transformFeatureTrkStdev",
140140
std::vector<float>{-999},
141141
"Stdev values for each GNN input feature (track)"};
142+
Configurable<std::string> tfFuncTypeGNN{"tfFuncTypeGNN", "linear", "Transformation function type for GNN"};
142143

143144
// axis spec
144145
ConfigurableAxis binTrackProbability{"binTrackProbability", {100, 0.f, 1.f}, ""};
@@ -525,7 +526,7 @@ struct JetTaggerHFTask {
525526
}
526527

527528
if (doprocessAlgorithmGNN) {
528-
tensorAlloc = o2::analysis::GNNBjetAllocator(nJetFeat.value, nTrkFeat.value, nClassesMl.value, nTrkOrigin.value, transformFeatureJetMean.value, transformFeatureJetStdev.value, transformFeatureTrkMean.value, transformFeatureTrkStdev.value, nJetConst);
529+
tensorAlloc = o2::analysis::GNNBjetAllocator(nJetFeat.value, nTrkFeat.value, nClassesMl.value, nTrkOrigin.value, transformFeatureJetMean.value, transformFeatureJetStdev.value, transformFeatureTrkMean.value, transformFeatureTrkStdev.value, nJetConst, tfFuncTypeGNN.value);
529530

530531
registry.add("h2_count_db", "#it{D}_{b} underflow/overflow;Jet flavour;#it{D}_{b} range", {HistType::kTH2F, {{4, 0., 4.}, {3, 0., 3.}}});
531532
auto h2CountDb = registry.get<TH2>(HIST("h2_count_db"));

0 commit comments

Comments
 (0)