1+ #include <onnxruntime_cxx_api.h>
2+ #include <iostream>
3+ #include <vector>
4+ #include <fstream>
5+ #include <rapidjson/document.h>
6+ #include <TMatrixT.h>
7+
8+ // This class is responsible for loading the scaler parameters from a JSON file
9+ // and applying the inverse transformation to the generated data.
10+ struct Scaler
11+ {
12+ TVectorD normal_min ;
13+ TVectorD normal_max ;
14+ TVectorD outlier_center ;
15+ TVectorD outlier_scale ;
16+
17+ void load (const std ::string & filename )
18+ {
19+ std ::ifstream file (filename );
20+ if (!file .is_open ())
21+ {
22+ throw std ::runtime_error ("Error: Could not open scaler file!" );
23+ }
24+
25+ std ::string json_str ((std ::istreambuf_iterator < char > (file )), std ::istreambuf_iterator < char > ());
26+ file .close ();
27+
28+ rapidjson ::Document doc ;
29+ doc .Parse (json_str .c_str ());
30+
31+ if (doc .HasParseError ())
32+ {
33+ throw std ::runtime_error ("Error: JSON parsing failed!" );
34+ }
35+
36+ // Convert JSON arrays to TVectorD
37+ normal_min .ResizeTo (8 );
38+ normal_max .ResizeTo (8 );
39+ outlier_center .ResizeTo (2 );
40+ outlier_scale .ResizeTo (2 );
41+
42+ jsonArrayToVector (doc ["normal" ]["min" ], normal_min );
43+ jsonArrayToVector (doc ["normal" ]["max" ], normal_max );
44+ jsonArrayToVector (doc ["outlier" ]["center" ], outlier_center );
45+ jsonArrayToVector (doc ["outlier" ]["scale" ], outlier_scale );
46+ }
47+
48+ TVectorD inverse_transform (const TVectorD & input )
49+ {
50+ TVectorD normal_part (8 );
51+ TVectorD outlier_part (2 );
52+
53+ for (int i = 0 ; i < 8 ; ++ i )
54+ {
55+ normal_part [i ] = normal_min [i ] + input [i ] * (normal_max [i ] - normal_min [i ]);
56+ }
57+
58+ for (int i = 0 ; i < 2 ; ++ i )
59+ {
60+ outlier_part [i ] = input [8 + i ] * outlier_scale [i ] + outlier_center [i ];
61+ }
62+
63+ TVectorD output (10 );
64+ for (int i = 0 ; i < 8 ; ++ i )
65+ output [i ] = normal_part [i ];
66+ for (int i = 0 ; i < 2 ; ++ i )
67+ output [8 + i ] = outlier_part [i ];
68+
69+ return output ;
70+ }
71+
72+ private :
73+ void jsonArrayToVector (const rapidjson ::Value & jsonArray , TVectorD & vec )
74+ {
75+ for (int i = 0 ; i < jsonArray .Size (); ++ i )
76+ {
77+ vec [i ] = jsonArray [i ].GetDouble ();
78+ }
79+ }
80+ };
81+
82+ // This class loads the ONNX model and generates samples using it.
83+ class ONNXGenerator
84+ {
85+ public :
86+ ONNXGenerator (const std ::string & model_path )
87+ : env (ORT_LOGGING_LEVEL_WARNING , "ONNXGenerator" ), session (env , model_path .c_str (), Ort ::SessionOptions {})
88+ {
89+ // Create session options
90+ Ort ::SessionOptions session_options ;
91+ session = Ort ::Session (env , model_path .c_str (), session_options );
92+ }
93+
94+ TVectorD generate_sample ()
95+ {
96+ Ort ::AllocatorWithDefaultOptions allocator ;
97+
98+ // Generate a latent vector (z)
99+ std ::vector < float > z (100 );
100+ for (auto & v : z )
101+ v = rand_gen .Gaus (0.0 , 1.0 );
102+
103+ // Prepare input tensor
104+ std ::vector < int64_t > input_shape = {1 , 100 };
105+ // Get memory information
106+ Ort ::MemoryInfo memory_info = Ort ::MemoryInfo ::CreateCpu (OrtArenaAllocator , OrtMemTypeDefault );
107+
108+ // Create input tensor correctly
109+ Ort ::Value input_tensor = Ort ::Value ::CreateTensor < float > (
110+ memory_info , z .data (), z .size (), input_shape .data (), input_shape .size ());
111+ // Run inference
112+ const char * input_names [ ] = {"z" };
113+ const char * output_names [ ] = {"output" };
114+ auto output_tensors = session .Run (Ort ::RunOptions {nullptr }, input_names , & input_tensor , 1 , output_names , 1 );
115+
116+ // Extract output
117+ float * output_data = output_tensors .front ().GetTensorMutableData < float > ();
118+ TVectorD output (10 );
119+ for (int i = 0 ; i < 10 ; ++ i )
120+ {
121+ output [i ] = output_data [i ];
122+ }
123+
124+ return output ;
125+ }
126+
127+ private :
128+ Ort ::Env env ;
129+ Ort ::Session session ;
130+ TRandom3 rand_gen ;
131+ };
132+
133+ namespace o2
134+ {
135+ namespace eventgen
136+ {
137+
138+ class GenTPCLoopers : public Generator
139+ {
140+ public :
141+ GenTPCLoopers (std ::string model = "tpcloopmodel.onnx" , std ::string poisson = "poisson.csv" , std ::string scaler = "scaler.json" )
142+ {
143+ // Checking if the model file exists and it's not empty
144+ std ::ifstream model_file (model );
145+ if (!model_file .is_open () || model_file .peek () == std ::ifstream ::traits_type ::eof ())
146+ {
147+ LOG (fatal ) << "Error: Model file is empty or does not exist!" ;
148+ }
149+ // Checking if the scaler file exists and it's not empty
150+ std ::ifstream scaler_file (scaler );
151+ if (!scaler_file .is_open () || scaler_file .peek () == std ::ifstream ::traits_type ::eof ())
152+ {
153+ LOG (fatal ) << "Error: Scaler file is empty or does not exist!" ;
154+ }
155+ // Checking if the poisson file exists and it's not empty
156+ if (poisson != "" )
157+ {
158+ std ::ifstream poisson_file (poisson );
159+ if (!poisson_file .is_open () || poisson_file .peek () == std ::ifstream ::traits_type ::eof ())
160+ {
161+ LOG (fatal ) << "Error: Poisson file is empty or does not exist!" ;
162+ exit (1 );
163+ } else {
164+ poisson_file >> mPoisson [0 ] >> mPoisson [1 ] >> mPoisson [2 ];
165+ poisson_file .close ();
166+ mPoissonSet = true;
167+ }
168+
169+ }
170+ mONNX = std ::make_unique < ONNXGenerator > (model );
171+ mScaler = std ::make_unique < Scaler > ( );
172+ mScaler -> load (scaler );
173+ Generator ::setTimeUnit (1.0 );
174+ Generator ::setPositionUnit (1.0 );
175+ }
176+
177+ Bool_t generateEvent () override
178+ {
179+ // Clear the vector of pairs
180+ mGenPairs .clear ();
181+ // Set number of loopers if poissonian params are available
182+ if (mPoissonSet )
183+ {
184+ mNLoopers = PoissonPairs ();
185+ }
186+ // Generate pairs of loopers
187+ for (int i = 0 ; i < mNLoopers ; ++ i )
188+ {
189+ TVectorD pair = mONNX -> generate_sample ();
190+ // Apply the inverse transformation using the scaler
191+ TVectorD transformed_pair = mScaler -> inverse_transform (pair );
192+ mGenPairs .push_back (transformed_pair );
193+ }
194+ return true;
195+ }
196+
197+ Bool_t importParticles () override
198+ {
199+ // Get looper pairs from the event
200+ for (auto & pair : mGenPairs )
201+ {
202+ double px_e , py_e , pz_e , px_p , py_p , pz_p ;
203+ double vx , vy , vz , time ;
204+ double e_etot , p_etot ;
205+ px_e = pair [0 ];
206+ py_e = pair [1 ];
207+ pz_e = pair [2 ];
208+ px_p = pair [3 ];
209+ py_p = pair [4 ];
210+ pz_p = pair [5 ];
211+ vx = pair [6 ];
212+ vy = pair [7 ];
213+ vz = pair [8 ];
214+ time = pair [9 ];
215+ e_etot = TMath ::Sqrt (px_e * px_e + py_e * py_e + pz_e * pz_e + mMass_e * mMass_e );
216+ p_etot = TMath ::Sqrt (px_p * px_p + py_p * py_p + pz_p * pz_p + mMass_p * mMass_p );
217+ // Push the electron
218+ TParticle electron (11 , 1 , -1 , -1 , -1 , -1 , px_e , py_e , pz_e , e_etot , vx , vy , vz , time / 1e9 );
219+ electron .SetStatusCode (o2 ::mcgenstatus ::MCGenStatusEncoding (electron .GetStatusCode (), 0 ).fullEncoding );
220+ electron .SetBit (ParticleStatus ::kToBeDone , //
221+ o2 ::mcgenstatus ::getHepMCStatusCode (electron .GetStatusCode ()) == 1 );
222+ mParticles .push_back (electron );
223+ // Push the positron
224+ TParticle positron (-11 , 1 , -1 , -1 , -1 , -1 , px_p , py_p , pz_p , p_etot , vx , vy , vz , time / 1e9 );
225+ positron .SetStatusCode (o2 ::mcgenstatus ::MCGenStatusEncoding (positron .GetStatusCode (), 0 ).fullEncoding );
226+ positron .SetBit (ParticleStatus ::kToBeDone , //
227+ o2 ::mcgenstatus ::getHepMCStatusCode (positron .GetStatusCode ()) == 1 );
228+ mParticles .push_back (positron );
229+ }
230+ return true;
231+ }
232+
233+ short int PoissonPairs ()
234+ {
235+ short int poissonValue ;
236+ do
237+ {
238+ // Generate a Poisson-distributed random number with mean mPoisson[0]
239+ poissonValue = mRandGen .Poisson (mPoisson [0 ]);
240+ } while (poissonValue < mPoisson [1 ] || poissonValue > mPoisson [2 ]); // Regenerate if out of range
241+
242+ return poissonValue ;
243+ }
244+
245+ void SetNLoopers (short int nsig )
246+ {
247+ if (mPoissonSet ) {
248+ LOG (warn ) << "Poissonian parameters correctly set, ignoring SetNLoopers." ;
249+ } else {
250+ mNLoopers = nsig ;
251+ }
252+ }
253+
254+ private :
255+ std ::unique_ptr < ONNXGenerator > mONNX = nullptr ;
256+ std ::unique_ptr < Scaler > mScaler = nullptr ;
257+ double mPoisson [3 ] = {0.0 , 0.0 , 0.0 }; // Mu, Min and Max of Poissonian
258+ std ::vector < TVectorD > mGenPairs ;
259+ short int mNLoopers = -1 ;
260+ bool mPoissonSet = false;
261+ // Poissonian random number generator
262+ TRandom3 mRandGen ;
263+ // Masses of the electrons and positrons
264+ TDatabasePDG * mPDG = TDatabasePDG ::Instance ();
265+ double mMass_e = mPDG -> GetParticle (11 )-> Mass ();
266+ double mMass_p = mPDG -> GetParticle (-11 )-> Mass ();
267+ };
268+
269+ } // namespace eventgen
270+ } // namespace o2
271+
272+ FairGenerator *
273+ Generator_TPCLoopers (std ::string model = "tpcloopmodel.onnx" , std ::string poisson = "poisson.csv" , std ::string scaler = "scaler.json" , short int nloopers = 1 )
274+ {
275+ // Expand all environment paths
276+ model = gSystem -> ExpandPathName (model .c_str ());
277+ poisson = gSystem -> ExpandPathName (poisson .c_str ());
278+ scaler = gSystem -> ExpandPathName (scaler .c_str ());
279+ auto generator = new o2 ::eventgen ::GenTPCLoopers (model , poisson , scaler );
280+ generator -> SetNLoopers (nloopers );
281+ return generator ;
282+ }
0 commit comments