33#include <vector>
44#include <fstream>
55#include <rapidjson/document.h>
6- #include <TMatrixT.h>
6+
7+ // Static Ort::Env instance for multiple onnx model loading
8+ static Ort ::Env global_env (ORT_LOGGING_LEVEL_WARNING , "GlobalEnv ");
79
810// This class is responsible for loading the scaler parameters from a JSON file
911// and applying the inverse transformation to the generated data.
@@ -69,8 +71,8 @@ private:
6971class ONNXGenerator
7072{
7173public :
72- ONNXGenerator (const std ::string & model_path )
73- : env (ORT_LOGGING_LEVEL_WARNING , "ONNXGenerator" ), session (env , model_path .c_str (), Ort ::SessionOptions {})
74+ ONNXGenerator (Ort :: Env & shared_env , const std ::string & model_path )
75+ : env (shared_env ), session (env , model_path .c_str (), Ort ::SessionOptions {})
7476 {
7577 // Create session options
7678 Ort ::SessionOptions session_options ;
@@ -114,7 +116,7 @@ public:
114116 }
115117
116118private :
117- Ort ::Env env ;
119+ Ort ::Env & env ;
118120 Ort ::Session session ;
119121 TRandom3 rand_gen ;
120122};
@@ -195,10 +197,10 @@ class GenTPCLoopers : public Generator
195197 mGaussSet = true;
196198 }
197199 }
198- mONNX_pair = std ::make_unique < ONNXGenerator > (model_pairs );
200+ mONNX_pair = std ::make_unique < ONNXGenerator > (global_env , model_pairs );
199201 mScaler_pair = std ::make_unique < Scaler > ( );
200202 mScaler_pair -> load (scaler_pair );
201- mONNX_compton = std ::make_unique < ONNXGenerator > (model_compton );
203+ mONNX_compton = std ::make_unique < ONNXGenerator > (global_env , model_compton );
202204 mScaler_compton = std ::make_unique < Scaler > ( );
203205 mScaler_compton -> load (scaler_compton );
204206 Generator ::setTimeUnit (1.0 );
@@ -214,11 +216,11 @@ class GenTPCLoopers : public Generator
214216 // Set number of loopers if poissonian params are available
215217 if (mPoissonSet )
216218 {
217- mNLoopersPairs = PoissonPairs ();
219+ mNLoopersPairs = static_cast < short int > ( std :: round ( mMultiplier [ 0 ] * PoissonPairs ()) );
218220 }
219221 if (mGaussSet )
220222 {
221- mNLoopersCompton = GaussianElectrons ();
223+ mNLoopersCompton = static_cast < short int > ( std :: round ( mMultiplier [ 1 ] * GaussianElectrons ()) );
222224 }
223225 // Generate pairs
224226 for (int i = 0 ; i < mNLoopersPairs ; ++ i )
@@ -321,7 +323,7 @@ class GenTPCLoopers : public Generator
321323 return gaussValue ;
322324 }
323325
324- void SetNLoopers (short int nsig_pair , short int nsig_compton )
326+ void SetNLoopers (short int & nsig_pair , short int & nsig_compton )
325327 {
326328 if (mPoissonSet ) {
327329 LOG (info ) << "Poissonian parameters correctly loaded." ;
@@ -335,6 +337,21 @@ class GenTPCLoopers : public Generator
335337 }
336338 }
337339
340+ void SetMultiplier (std ::array < float , 2 > & mult )
341+ {
342+ // Multipliers will work only if the poissonian and gaussian parameters are set
343+ // otherwise they will be ignored
344+ if (mult [0 ] < 0 || mult [1 ] < 0 )
345+ {
346+ LOG (fatal ) << "Error: Multiplier values must be non-negative!" ;
347+ exit (1 );
348+ } else {
349+ LOG (info ) << "Multiplier values set to: Pair = " << mult [0 ] << ", Compton = " << mult [1 ];
350+ mMultiplier [0 ] = mult [0 ];
351+ mMultiplier [1 ] = mult [1 ];
352+ }
353+ }
354+
338355 private :
339356 std ::unique_ptr < ONNXGenerator > mONNX_pair = nullptr ;
340357 std ::unique_ptr < ONNXGenerator > mONNX_compton = nullptr ;
@@ -346,6 +363,7 @@ class GenTPCLoopers : public Generator
346363 std ::vector < std ::vector < double >> mGenElectrons ;
347364 short int mNLoopersPairs = -1 ;
348365 short int mNLoopersCompton = -1 ;
366+ std ::array < float , 2 > mMultiplier = {1. , 1. };
349367 bool mPoissonSet = false;
350368 bool mGaussSet = false;
351369 // Random number generator
@@ -362,7 +380,8 @@ class GenTPCLoopers : public Generator
362380FairGenerator *
363381 Generator_TPCLoopers (std ::string model_pairs = "tpcloopmodel.onnx" , std ::string model_compton = "tpcloopmodelcompton.onnx" ,
364382 std ::string poisson = "poisson.csv" , std ::string gauss = "gauss.csv" , std ::string scaler_pair = "scaler_pair.json" ,
365- std ::string scaler_compton = "scaler_compton.json" , short int nloopers_pairs = 1 , short int nloopers_compton = 1 )
383+ std ::string scaler_compton = "scaler_compton.json" , std ::array < float , 2 > mult = {1. , 1. }, short int nloopers_pairs = 1 ,
384+ short int nloopers_compton = 1 )
366385{
367386 // Expand all environment paths
368387 model_pairs = gSystem -> ExpandPathName (model_pairs .c_str ()) ;
@@ -373,5 +392,6 @@ FairGenerator *
373392 scaler_compton = gSystem -> ExpandPathName (scaler_compton .c_str ());
374393 auto generator = new o2 ::eventgen ::GenTPCLoopers (model_pairs , model_compton , poisson , gauss , scaler_pair , scaler_compton );
375394 generator -> SetNLoopers (nloopers_pairs , nloopers_compton );
395+ generator -> SetMultiplier (mult );
376396 return generator ;
377397}
0 commit comments