Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions MC/config/common/external/generator/TPCLoopers.C
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
#include <vector>
#include <fstream>
#include <rapidjson/document.h>
#include <TMatrixT.h>

// Static Ort::Env instance for multiple onnx model loading
static Ort::Env global_env(ORT_LOGGING_LEVEL_WARNING, "GlobalEnv");

// This class is responsible for loading the scaler parameters from a JSON file
// and applying the inverse transformation to the generated data.
Expand Down Expand Up @@ -69,8 +71,8 @@ private:
class ONNXGenerator
{
public:
ONNXGenerator(const std::string &model_path)
: env(ORT_LOGGING_LEVEL_WARNING, "ONNXGenerator"), session(env, model_path.c_str(), Ort::SessionOptions{})
ONNXGenerator(Ort::Env &shared_env, const std::string &model_path)
: env(shared_env), session(env, model_path.c_str(), Ort::SessionOptions{})
{
// Create session options
Ort::SessionOptions session_options;
Expand Down Expand Up @@ -114,7 +116,7 @@ public:
}

private:
Ort::Env env;
Ort::Env &env;
Ort::Session session;
TRandom3 rand_gen;
};
Expand Down Expand Up @@ -195,10 +197,10 @@ class GenTPCLoopers : public Generator
mGaussSet = true;
}
}
mONNX_pair = std::make_unique<ONNXGenerator>(model_pairs);
mONNX_pair = std::make_unique<ONNXGenerator>(global_env, model_pairs);
mScaler_pair = std::make_unique<Scaler>();
mScaler_pair->load(scaler_pair);
mONNX_compton = std::make_unique<ONNXGenerator>(model_compton);
mONNX_compton = std::make_unique<ONNXGenerator>(global_env, model_compton);
mScaler_compton = std::make_unique<Scaler>();
mScaler_compton->load(scaler_compton);
Generator::setTimeUnit(1.0);
Expand All @@ -214,11 +216,11 @@ class GenTPCLoopers : public Generator
// Set number of loopers if poissonian params are available
if (mPoissonSet)
{
mNLoopersPairs = PoissonPairs();
mNLoopersPairs = static_cast<short int>(std::round(mMultiplier[0] * PoissonPairs()));
}
if (mGaussSet)
{
mNLoopersCompton = GaussianElectrons();
mNLoopersCompton = static_cast<short int>(std::round(mMultiplier[1] * GaussianElectrons()));
}
// Generate pairs
for (int i = 0; i < mNLoopersPairs; ++i)
Expand Down Expand Up @@ -321,7 +323,7 @@ class GenTPCLoopers : public Generator
return gaussValue;
}

void SetNLoopers(short int nsig_pair, short int nsig_compton)
void SetNLoopers(short int &nsig_pair, short int &nsig_compton)
{
if(mPoissonSet) {
LOG(info) << "Poissonian parameters correctly loaded.";
Expand All @@ -335,6 +337,21 @@ class GenTPCLoopers : public Generator
}
}

void SetMultiplier(std::array<float, 2> &mult)
{
// Multipliers will work only if the poissonian and gaussian parameters are set
// otherwise they will be ignored
if (mult[0] < 0 || mult[1] < 0)
{
LOG(fatal) << "Error: Multiplier values must be non-negative!";
exit(1);
} else {
LOG(info) << "Multiplier values set to: Pair = " << mult[0] << ", Compton = " << mult[1];
mMultiplier[0] = mult[0];
mMultiplier[1] = mult[1];
}
}

private:
std::unique_ptr<ONNXGenerator> mONNX_pair = nullptr;
std::unique_ptr<ONNXGenerator> mONNX_compton = nullptr;
Expand All @@ -346,6 +363,7 @@ class GenTPCLoopers : public Generator
std::vector<std::vector<double>> mGenElectrons;
short int mNLoopersPairs = -1;
short int mNLoopersCompton = -1;
std::array<float, 2> mMultiplier = {1., 1.};
bool mPoissonSet = false;
bool mGaussSet = false;
// Random number generator
Expand All @@ -362,7 +380,8 @@ class GenTPCLoopers : public Generator
FairGenerator *
Generator_TPCLoopers(std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
std::string poisson = "poisson.csv", std::string gauss = "gauss.csv", std::string scaler_pair = "scaler_pair.json",
std::string scaler_compton = "scaler_compton.json", short int nloopers_pairs = 1, short int nloopers_compton = 1)
std::string scaler_compton = "scaler_compton.json", std::array<float, 2> mult = {1., 1.}, short int nloopers_pairs = 1,
short int nloopers_compton = 1)
{
// Expand all environment paths
model_pairs = gSystem->ExpandPathName(model_pairs.c_str());
Expand All @@ -373,5 +392,6 @@ FairGenerator *
scaler_compton = gSystem->ExpandPathName(scaler_compton.c_str());
auto generator = new o2::eventgen::GenTPCLoopers(model_pairs, model_compton, poisson, gauss, scaler_pair, scaler_compton);
generator->SetNLoopers(nloopers_pairs, nloopers_compton);
generator->SetMultiplier(mult);
return generator;
}