Skip to content

Commit 71b020d

Browse files
authored
Edited Ort::env declaration for multi-models system + Distribution multipliers (#1999)
1 parent 36f0bde commit 71b020d

File tree

1 file changed

+30
-10
lines changed

1 file changed

+30
-10
lines changed

MC/config/common/external/generator/TPCLoopers.C

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
#include <vector>
44
#include <fstream>
55
#include <rapidjson/document.h>
6-
#include <TMatrixT.h>
6+
7+
// Static Ort::Env instance for multiple onnx model loading
8+
static Ort::Env global_env(ORT_LOGGING_LEVEL_WARNING, "GlobalEnv");
79

810
// This class is responsible for loading the scaler parameters from a JSON file
911
// and applying the inverse transformation to the generated data.
@@ -69,8 +71,8 @@ private:
6971
class ONNXGenerator
7072
{
7173
public:
72-
ONNXGenerator(const std::string &model_path)
73-
: env(ORT_LOGGING_LEVEL_WARNING, "ONNXGenerator"), session(env, model_path.c_str(), Ort::SessionOptions{})
74+
ONNXGenerator(Ort::Env &shared_env, const std::string &model_path)
75+
: env(shared_env), session(env, model_path.c_str(), Ort::SessionOptions{})
7476
{
7577
// Create session options
7678
Ort::SessionOptions session_options;
@@ -114,7 +116,7 @@ public:
114116
}
115117

116118
private:
117-
Ort::Env env;
119+
Ort::Env &env;
118120
Ort::Session session;
119121
TRandom3 rand_gen;
120122
};
@@ -195,10 +197,10 @@ class GenTPCLoopers : public Generator
195197
mGaussSet = true;
196198
}
197199
}
198-
mONNX_pair = std::make_unique<ONNXGenerator>(model_pairs);
200+
mONNX_pair = std::make_unique<ONNXGenerator>(global_env, model_pairs);
199201
mScaler_pair = std::make_unique<Scaler>();
200202
mScaler_pair->load(scaler_pair);
201-
mONNX_compton = std::make_unique<ONNXGenerator>(model_compton);
203+
mONNX_compton = std::make_unique<ONNXGenerator>(global_env, model_compton);
202204
mScaler_compton = std::make_unique<Scaler>();
203205
mScaler_compton->load(scaler_compton);
204206
Generator::setTimeUnit(1.0);
@@ -214,11 +216,11 @@ class GenTPCLoopers : public Generator
214216
// Set number of loopers if poissonian params are available
215217
if (mPoissonSet)
216218
{
217-
mNLoopersPairs = PoissonPairs();
219+
mNLoopersPairs = static_cast<short int>(std::round(mMultiplier[0] * PoissonPairs()));
218220
}
219221
if (mGaussSet)
220222
{
221-
mNLoopersCompton = GaussianElectrons();
223+
mNLoopersCompton = static_cast<short int>(std::round(mMultiplier[1] * GaussianElectrons()));
222224
}
223225
// Generate pairs
224226
for (int i = 0; i < mNLoopersPairs; ++i)
@@ -321,7 +323,7 @@ class GenTPCLoopers : public Generator
321323
return gaussValue;
322324
}
323325

324-
void SetNLoopers(short int nsig_pair, short int nsig_compton)
326+
void SetNLoopers(short int &nsig_pair, short int &nsig_compton)
325327
{
326328
if(mPoissonSet) {
327329
LOG(info) << "Poissonian parameters correctly loaded.";
@@ -335,6 +337,21 @@ class GenTPCLoopers : public Generator
335337
}
336338
}
337339

340+
void SetMultiplier(std::array<float, 2> &mult)
341+
{
342+
// Multipliers will work only if the poissonian and gaussian parameters are set
343+
// otherwise they will be ignored
344+
if (mult[0] < 0 || mult[1] < 0)
345+
{
346+
LOG(fatal) << "Error: Multiplier values must be non-negative!";
347+
exit(1);
348+
} else {
349+
LOG(info) << "Multiplier values set to: Pair = " << mult[0] << ", Compton = " << mult[1];
350+
mMultiplier[0] = mult[0];
351+
mMultiplier[1] = mult[1];
352+
}
353+
}
354+
338355
private:
339356
std::unique_ptr<ONNXGenerator> mONNX_pair = nullptr;
340357
std::unique_ptr<ONNXGenerator> mONNX_compton = nullptr;
@@ -346,6 +363,7 @@ class GenTPCLoopers : public Generator
346363
std::vector<std::vector<double>> mGenElectrons;
347364
short int mNLoopersPairs = -1;
348365
short int mNLoopersCompton = -1;
366+
std::array<float, 2> mMultiplier = {1., 1.};
349367
bool mPoissonSet = false;
350368
bool mGaussSet = false;
351369
// Random number generator
@@ -362,7 +380,8 @@ class GenTPCLoopers : public Generator
362380
FairGenerator *
363381
Generator_TPCLoopers(std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
364382
std::string poisson = "poisson.csv", std::string gauss = "gauss.csv", std::string scaler_pair = "scaler_pair.json",
365-
std::string scaler_compton = "scaler_compton.json", short int nloopers_pairs = 1, short int nloopers_compton = 1)
383+
std::string scaler_compton = "scaler_compton.json", std::array<float, 2> mult = {1., 1.}, short int nloopers_pairs = 1,
384+
short int nloopers_compton = 1)
366385
{
367386
// Expand all environment paths
368387
model_pairs = gSystem->ExpandPathName(model_pairs.c_str());
@@ -373,5 +392,6 @@ FairGenerator *
373392
scaler_compton = gSystem->ExpandPathName(scaler_compton.c_str());
374393
auto generator = new o2::eventgen::GenTPCLoopers(model_pairs, model_compton, poisson, gauss, scaler_pair, scaler_compton);
375394
generator->SetNLoopers(nloopers_pairs, nloopers_compton);
395+
generator->SetMultiplier(mult);
376396
return generator;
377397
}

0 commit comments

Comments
 (0)