Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 186 additions & 18 deletions MC/config/common/external/generator/TPCLoopers.C
Original file line number Diff line number Diff line change
Expand Up @@ -215,21 +215,61 @@ class GenTPCLoopers : public Generator
mGenPairs.clear();
// Clear the vector of compton electrons
mGenElectrons.clear();
// Set number of loopers if poissonian params are available
if (mPoissonSet)
if (mFlatGas)
{
mNLoopersPairs = static_cast<short int>(std::round(mMultiplier[0] * PoissonPairs()));
}
if (mGaussSet)
{
mNLoopersCompton = static_cast<short int>(std::round(mMultiplier[1] * GaussianElectrons()));
}
unsigned int nLoopers, nLoopersPairs, nLoopersCompton;
LOG(debug) << "mCurrentEvent is " << mCurrentEvent;
LOG(debug) << "Current event time: " << ((mCurrentEvent < mInteractionTimeRecords.size() - 1) ? std::to_string(mInteractionTimeRecords[mCurrentEvent + 1].bc2ns() - mInteractionTimeRecords[mCurrentEvent].bc2ns()) : std::to_string(mIntTimeRecMean)) << " ns";
LOG(debug) << "Current time offset wrt BC: " << mInteractionTimeRecords[mCurrentEvent].getTimeOffsetWrtBC() << " ns";
mTimeLimit = (mCurrentEvent < mInteractionTimeRecords.size() - 1) ? mInteractionTimeRecords[mCurrentEvent + 1].bc2ns() - mInteractionTimeRecords[mCurrentEvent].bc2ns() : mIntTimeRecMean;
// With flat gas the number of loopers are adapted based on time interval widths
nLoopers = mFlatGasNumber * (mTimeLimit / mIntTimeRecMean);
nLoopersPairs = static_cast<unsigned int>(std::round(nLoopers * mLoopsFractionPairs));
nLoopersCompton = nLoopers - nLoopersPairs;
SetNLoopers(nLoopersPairs, nLoopersCompton);
LOG(info) << "Flat gas loopers: " << nLoopers << " (pairs: " << nLoopersPairs << ", compton: " << nLoopersCompton << ")";
generateEvent(mTimeLimit);
mCurrentEvent++;
} else {
// Set number of loopers if poissonian params are available
if (mPoissonSet)
{
mNLoopersPairs = static_cast<unsigned int>(std::round(mMultiplier[0] * PoissonPairs()));
}
if (mGaussSet)
{
mNLoopersCompton = static_cast<unsigned int>(std::round(mMultiplier[1] * GaussianElectrons()));
}
// Generate pairs
for (int i = 0; i < mNLoopersPairs; ++i)
{
std::vector<double> pair = mONNX_pair->generate_sample();
// Apply the inverse transformation using the scaler
std::vector<double> transformed_pair = mScaler_pair->inverse_transform(pair);
mGenPairs.push_back(transformed_pair);
}
// Generate compton electrons
for (int i = 0; i < mNLoopersCompton; ++i)
{
std::vector<double> electron = mONNX_compton->generate_sample();
// Apply the inverse transformation using the scaler
std::vector<double> transformed_electron = mScaler_compton->inverse_transform(electron);
mGenElectrons.push_back(transformed_electron);
}
}
return true;
}

Bool_t generateEvent(double &time_limit)
{
LOG(info) << "Time constraint for loopers: " << time_limit << " ns";
// Generate pairs
for (int i = 0; i < mNLoopersPairs; ++i)
{
std::vector<double> pair = mONNX_pair->generate_sample();
// Apply the inverse transformation using the scaler
std::vector<double> transformed_pair = mScaler_pair->inverse_transform(pair);
transformed_pair[9] = gRandom->Uniform(0., time_limit); // Regenerate time, scaling is not needed because time_limit is already in nanoseconds
mGenPairs.push_back(transformed_pair);
}
// Generate compton electrons
Expand All @@ -238,8 +278,10 @@ class GenTPCLoopers : public Generator
std::vector<double> electron = mONNX_compton->generate_sample();
// Apply the inverse transformation using the scaler
std::vector<double> transformed_electron = mScaler_compton->inverse_transform(electron);
transformed_electron[6] = gRandom->Uniform(0., time_limit); // Regenerate time, scaling is not needed because time_limit is already in nanoseconds
mGenElectrons.push_back(transformed_electron);
}
LOG(info) << "Generated Particles with time limit";
return true;
}

Expand Down Expand Up @@ -301,9 +343,9 @@ class GenTPCLoopers : public Generator
return true;
}

short int PoissonPairs()
unsigned int PoissonPairs()
{
short int poissonValue;
unsigned int poissonValue;
do
{
// Generate a Poisson-distributed random number with mean mPoisson[0]
Expand All @@ -313,9 +355,9 @@ class GenTPCLoopers : public Generator
return poissonValue;
}

short int GaussianElectrons()
unsigned int GaussianElectrons()
{
short int gaussValue;
unsigned int gaussValue;
do
{
// Generate a Normal-distributed random number with mean mGass[0] and stddev mGauss[1]
Expand All @@ -325,7 +367,7 @@ class GenTPCLoopers : public Generator
return gaussValue;
}

void SetNLoopers(short int &nsig_pair, short int &nsig_compton)
void SetNLoopers(unsigned int &nsig_pair, unsigned int &nsig_compton)
{
if(mPoissonSet) {
LOG(info) << "Poissonian parameters correctly loaded.";
Expand Down Expand Up @@ -354,6 +396,52 @@ class GenTPCLoopers : public Generator
}
}

void setFlatGas(Bool_t &flat, const Int_t &number = -1)
{
mFlatGas = flat;
if (mFlatGas)
{
if (number < 0)
{
LOG(warn) << "Warning: Number of loopers per event must be non-negative! Switching option off.";
mFlatGas = false;
mFlatGasNumber = -1;
} else {
mFlatGasNumber = number;
mContextFile = std::filesystem::exists("collisioncontext.root") ? TFile::Open("collisioncontext.root") : nullptr;
mCollisionContext = mContextFile ? (o2::steer::DigitizationContext *)mContextFile->Get("DigitizationContext") : nullptr;
mInteractionTimeRecords = mCollisionContext ? mCollisionContext->getEventRecords() : std::vector<o2::InteractionTimeRecord>{};
if (mInteractionTimeRecords.empty())
{
LOG(error) << "Error: No interaction time records found in the collision context!";
exit(1);
} else {
LOG(info) << "Interaction Time records has " << mInteractionTimeRecords.size() << " entries.";
mCollisionContext->printCollisionSummary();
}
for (int c = 0; c < mInteractionTimeRecords.size() - 1; c++)
{
mIntTimeRecMean += mInteractionTimeRecords[c + 1].bc2ns() - mInteractionTimeRecords[c].bc2ns();
}
mIntTimeRecMean /= (mInteractionTimeRecords.size() - 1); // Average interaction time record used as reference
}
} else {
mFlatGasNumber = -1;
}
LOG(info) << "Flat gas loopers: " << (mFlatGas ? "ON" : "OFF") << ", Reference loopers number per event: " << mFlatGasNumber;
}

void setFractionPairs(float &fractionPairs)
{
if (fractionPairs < 0 || fractionPairs > 1)
{
LOG(fatal) << "Error: Loops fraction for pairs must be in the range [0, 1].";
exit(1);
}
mLoopsFractionPairs = fractionPairs;
LOG(info) << "Pairs fraction set to: " << mLoopsFractionPairs;
}

private:
std::unique_ptr<ONNXGenerator> mONNX_pair = nullptr;
std::unique_ptr<ONNXGenerator> mONNX_compton = nullptr;
Expand All @@ -363,8 +451,8 @@ class GenTPCLoopers : public Generator
double mGauss[4] = {0.0, 0.0, 0.0, 0.0}; // Mean, Std, Min, Max
std::vector<std::vector<double>> mGenPairs;
std::vector<std::vector<double>> mGenElectrons;
short int mNLoopersPairs = -1;
short int mNLoopersCompton = -1;
unsigned int mNLoopersPairs = -1;
unsigned int mNLoopersCompton = -1;
std::array<float, 2> mMultiplier = {1., 1.};
bool mPoissonSet = false;
bool mGaussSet = false;
Expand All @@ -374,6 +462,15 @@ class GenTPCLoopers : public Generator
TDatabasePDG *mPDG = TDatabasePDG::Instance();
double mMass_e = mPDG->GetParticle(11)->Mass();
double mMass_p = mPDG->GetParticle(-11)->Mass();
int mCurrentEvent = 0; // Current event number, used for adaptive loopers
TFile *mContextFile = nullptr; // Input collision context file
o2::steer::DigitizationContext *mCollisionContext = nullptr; // Pointer to the digitization context
std::vector<o2::InteractionTimeRecord> mInteractionTimeRecords; // Interaction time records from collision context
Bool_t mFlatGas = false; // Flag to indicate if flat gas loopers are used
Int_t mFlatGasNumber = -1; // Number of flat gas loopers per event
double mIntTimeRecMean = 1.0; // Average interaction time record used for the reference
double mTimeLimit = 0.0; // Time limit for the current event
float mLoopsFractionPairs = 0.08; // Fraction of loopers from Pairs
};

} // namespace eventgen
Expand All @@ -387,8 +484,8 @@ class GenTPCLoopers : public Generator
FairGenerator *
Generator_TPCLoopers(std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
std::string poisson = "poisson.csv", std::string gauss = "gauss.csv", std::string scaler_pair = "scaler_pair.json",
std::string scaler_compton = "scaler_compton.json", std::array<float, 2> mult = {1., 1.}, short int nloopers_pairs = 1,
short int nloopers_compton = 1)
std::string scaler_compton = "scaler_compton.json", std::array<float, 2> mult = {1., 1.}, unsigned int nloopers_pairs = 1,
unsigned int nloopers_compton = 1)
{
// Expand all environment paths
model_pairs = gSystem->ExpandPathName(model_pairs.c_str());
Expand Down Expand Up @@ -450,4 +547,75 @@ FairGenerator *
generator->SetNLoopers(nloopers_pairs, nloopers_compton);
generator->SetMultiplier(mult);
return generator;
}
}

// Generator with flat gas loopers. Number of loopers starts from a reference value and changes
// based on the BC time intervals in each event.
FairGenerator *
Generator_TPCLoopersFlat(std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
std::string scaler_pair = "scaler_pair.json", std::string scaler_compton = "scaler_compton.json",
bool flat_gas = true, const int loops_num = 500, float fraction_pairs = 0.08)
{
// Expand all environment paths
model_pairs = gSystem->ExpandPathName(model_pairs.c_str());
model_compton = gSystem->ExpandPathName(model_compton.c_str());
scaler_pair = gSystem->ExpandPathName(scaler_pair.c_str());
scaler_compton = gSystem->ExpandPathName(scaler_compton.c_str());
const std::array<std::string, 2> models = {model_pairs, model_compton};
const std::array<std::string, 2> local_names = {"WGANpair.onnx", "WGANcompton.onnx"};
const std::array<bool, 2> isAlien = {models[0].starts_with("alien://"), models[1].starts_with("alien://")};
const std::array<bool, 2> isCCDB = {models[0].starts_with("ccdb://"), models[1].starts_with("ccdb://")};
if (std::any_of(isAlien.begin(), isAlien.end(), [](bool v)
{ return v; }))
{
if (!gGrid)
{
TGrid::Connect("alien://");
if (!gGrid)
{
LOG(fatal) << "AliEn connection failed, check token.";
exit(1);
}
}
for (size_t i = 0; i < models.size(); ++i)
{
if (isAlien[i] && !TFile::Cp(models[i].c_str(), local_names[i].c_str()))
{
LOG(fatal) << "Error: Model file " << models[i] << " does not exist!";
exit(1);
}
}
}
if (std::any_of(isCCDB.begin(), isCCDB.end(), [](bool v)
{ return v; }))
{
o2::ccdb::CcdbApi ccdb_api;
ccdb_api.init("http://alice-ccdb.cern.ch");
for (size_t i = 0; i < models.size(); ++i)
{
if (isCCDB[i])
{
auto model_path = models[i].substr(7); // Remove "ccdb://"
// Treat filename if provided in the CCDB path
auto extension = model_path.find(".onnx");
if (extension != std::string::npos)
{
auto last_slash = model_path.find_last_of('/');
model_path = model_path.substr(0, last_slash);
}
std::map<std::string, std::string> filter;
if (!ccdb_api.retrieveBlob(model_path, "./", filter, o2::ccdb::getCurrentTimestamp(), false, local_names[i].c_str()))
{
LOG(fatal) << "Error: issues in retrieving " << model_path << " from CCDB!";
exit(1);
}
}
}
}
model_pairs = isAlien[0] || isCCDB[0] ? local_names[0] : model_pairs;
model_compton = isAlien[1] || isCCDB[1] ? local_names[1] : model_compton;
auto generator = new o2::eventgen::GenTPCLoopers(model_pairs, model_compton, "", "", scaler_pair, scaler_compton);
generator->setFractionPairs(fraction_pairs);
generator->setFlatGas(flat_gas, loops_num);
return generator;
}
5 changes: 5 additions & 0 deletions MC/config/common/ini/GeneratorLoopersFlatGas.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# TPC loopers injector
#---> GeneratorTPCloopers
[GeneratorExternal]
fileName = ${O2DPG_MC_CONFIG_ROOT}/MC/config/common/external/generator/TPCLoopers.C
funcName = Generator_TPCLoopersFlat("ccdb://Users/m/mgiacalo/WGAN_ExtGenPair", "ccdb://Users/m/mgiacalo/WGAN_ExtGenCompton", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerPairParams.json", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerComptonParams.json")
Loading