@@ -361,18 +361,20 @@ class GNNBjetAllocator : public TensorAllocator
361361
362362 std::vector<std::vector<int64_t >> edgesList;
363363
364+ std::function<float (float )> tfFunc;
365+
364366 // Jet feature normalization
365367 template <typename T>
366368 T jetFeatureTransform (T feat, int idx) const
367369 {
368- return std::tanh ((feat - tfJetMean[idx]) / tfJetStdev[idx]);
370+ return tfFunc ((feat - tfJetMean[idx]) / tfJetStdev[idx]);
369371 }
370372
371373 // Track feature normalization
372374 template <typename T>
373375 T trkFeatureTransform (T feat, int idx) const
374376 {
375- return std::tanh ((feat - tfTrkMean[idx]) / tfTrkStdev[idx]);
377+ return tfFunc ((feat - tfTrkMean[idx]) / tfTrkStdev[idx]);
376378 }
377379
378380 // Edge input of GNN (fully-connected graph)
@@ -419,10 +421,17 @@ class GNNBjetAllocator : public TensorAllocator
419421 }
420422
421423 public:
422- GNNBjetAllocator () : TensorAllocator(), nJetFeat(4 ), nTrkFeat(13 ), nFlav(3 ), nTrkOrigin(5 ), maxNNodes(40 ) {}
423- GNNBjetAllocator (int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector<float >& tfJetMean, std::vector<float >& tfJetStdev, std::vector<float >& tfTrkMean, std::vector<float >& tfTrkStdev, int64_t maxNNodes = 40 )
424- : TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev)
424+ GNNBjetAllocator () : TensorAllocator(), nJetFeat(4 ), nTrkFeat(13 ), nFlav(3 ), nTrkOrigin(5 ), maxNNodes(40 ), tfFunc([]( float x) { return x; }) {}
425+ GNNBjetAllocator (int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector<float >& tfJetMean, std::vector<float >& tfJetStdev, std::vector<float >& tfTrkMean, std::vector<float >& tfTrkStdev, int64_t maxNNodes = 40 , std::string tfFuncType = " linear " )
426+ : TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev), tfFunc([]( float x) { return x; })
425427 {
428+ if (tfFuncType == " asinh" ) {
429+ tfFunc = [](float x) { return std::asinh (x); };
430+ } else if (tfFuncType == " tanh" ) {
431+ tfFunc = [](float x) { return std::tanh (x); };
432+ } else {
433+ tfFunc = [](float x) { return x; };
434+ }
426435 setEdgesList ();
427436 }
428437 ~GNNBjetAllocator () = default ;
@@ -439,6 +448,8 @@ class GNNBjetAllocator : public TensorAllocator
439448 tfJetStdev = other.tfJetStdev ;
440449 tfTrkMean = other.tfTrkMean ;
441450 tfTrkStdev = other.tfTrkStdev ;
451+ tfFunc = other.tfFunc ;
452+ edgesList.clear ();
442453 setEdgesList ();
443454 return *this ;
444455 }
0 commit comments