Skip to content

Commit 5622cef

Browse files
authored
Flat Gas external generator (#2094)
* Flat Gas external generator
1 parent 6bbcd05 commit 5622cef

File tree

2 files changed

+191
-18
lines changed

2 files changed

+191
-18
lines changed

MC/config/common/external/generator/TPCLoopers.C

Lines changed: 186 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -215,21 +215,61 @@ class GenTPCLoopers : public Generator
215215
mGenPairs.clear();
216216
// Clear the vector of compton electrons
217217
mGenElectrons.clear();
218-
// Set number of loopers if poissonian params are available
219-
if (mPoissonSet)
218+
if (mFlatGas)
220219
{
221-
mNLoopersPairs = static_cast<short int>(std::round(mMultiplier[0] * PoissonPairs()));
222-
}
223-
if (mGaussSet)
224-
{
225-
mNLoopersCompton = static_cast<short int>(std::round(mMultiplier[1] * GaussianElectrons()));
226-
}
220+
unsigned int nLoopers, nLoopersPairs, nLoopersCompton;
221+
LOG(debug) << "mCurrentEvent is " << mCurrentEvent;
222+
LOG(debug) << "Current event time: " << ((mCurrentEvent < mInteractionTimeRecords.size() - 1) ? std::to_string(mInteractionTimeRecords[mCurrentEvent + 1].bc2ns() - mInteractionTimeRecords[mCurrentEvent].bc2ns()) : std::to_string(mIntTimeRecMean)) << " ns";
223+
LOG(debug) << "Current time offset wrt BC: " << mInteractionTimeRecords[mCurrentEvent].getTimeOffsetWrtBC() << " ns";
224+
mTimeLimit = (mCurrentEvent < mInteractionTimeRecords.size() - 1) ? mInteractionTimeRecords[mCurrentEvent + 1].bc2ns() - mInteractionTimeRecords[mCurrentEvent].bc2ns() : mIntTimeRecMean;
225+
// With flat gas the number of loopers are adapted based on time interval widths
226+
nLoopers = mFlatGasNumber * (mTimeLimit / mIntTimeRecMean);
227+
nLoopersPairs = static_cast<unsigned int>(std::round(nLoopers * mLoopsFractionPairs));
228+
nLoopersCompton = nLoopers - nLoopersPairs;
229+
SetNLoopers(nLoopersPairs, nLoopersCompton);
230+
LOG(info) << "Flat gas loopers: " << nLoopers << " (pairs: " << nLoopersPairs << ", compton: " << nLoopersCompton << ")";
231+
generateEvent(mTimeLimit);
232+
mCurrentEvent++;
233+
} else {
234+
// Set number of loopers if poissonian params are available
235+
if (mPoissonSet)
236+
{
237+
mNLoopersPairs = static_cast<unsigned int>(std::round(mMultiplier[0] * PoissonPairs()));
238+
}
239+
if (mGaussSet)
240+
{
241+
mNLoopersCompton = static_cast<unsigned int>(std::round(mMultiplier[1] * GaussianElectrons()));
242+
}
243+
// Generate pairs
244+
for (int i = 0; i < mNLoopersPairs; ++i)
245+
{
246+
std::vector<double> pair = mONNX_pair->generate_sample();
247+
// Apply the inverse transformation using the scaler
248+
std::vector<double> transformed_pair = mScaler_pair->inverse_transform(pair);
249+
mGenPairs.push_back(transformed_pair);
250+
}
251+
// Generate compton electrons
252+
for (int i = 0; i < mNLoopersCompton; ++i)
253+
{
254+
std::vector<double> electron = mONNX_compton->generate_sample();
255+
// Apply the inverse transformation using the scaler
256+
std::vector<double> transformed_electron = mScaler_compton->inverse_transform(electron);
257+
mGenElectrons.push_back(transformed_electron);
258+
}
259+
}
260+
return true;
261+
}
262+
263+
Bool_t generateEvent(double &time_limit)
264+
{
265+
LOG(info) << "Time constraint for loopers: " << time_limit << " ns";
227266
// Generate pairs
228267
for (int i = 0; i < mNLoopersPairs; ++i)
229268
{
230269
std::vector<double> pair = mONNX_pair->generate_sample();
231270
// Apply the inverse transformation using the scaler
232271
std::vector<double> transformed_pair = mScaler_pair->inverse_transform(pair);
272+
transformed_pair[9] = gRandom->Uniform(0., time_limit); // Regenerate time, scaling is not needed because time_limit is already in nanoseconds
233273
mGenPairs.push_back(transformed_pair);
234274
}
235275
// Generate compton electrons
@@ -238,8 +278,10 @@ class GenTPCLoopers : public Generator
238278
std::vector<double> electron = mONNX_compton->generate_sample();
239279
// Apply the inverse transformation using the scaler
240280
std::vector<double> transformed_electron = mScaler_compton->inverse_transform(electron);
281+
transformed_electron[6] = gRandom->Uniform(0., time_limit); // Regenerate time, scaling is not needed because time_limit is already in nanoseconds
241282
mGenElectrons.push_back(transformed_electron);
242283
}
284+
LOG(info) << "Generated Particles with time limit";
243285
return true;
244286
}
245287

@@ -301,9 +343,9 @@ class GenTPCLoopers : public Generator
301343
return true;
302344
}
303345

304-
short int PoissonPairs()
346+
unsigned int PoissonPairs()
305347
{
306-
short int poissonValue;
348+
unsigned int poissonValue;
307349
do
308350
{
309351
// Generate a Poisson-distributed random number with mean mPoisson[0]
@@ -313,9 +355,9 @@ class GenTPCLoopers : public Generator
313355
return poissonValue;
314356
}
315357

316-
short int GaussianElectrons()
358+
unsigned int GaussianElectrons()
317359
{
318-
short int gaussValue;
360+
unsigned int gaussValue;
319361
do
320362
{
321363
// Generate a Normal-distributed random number with mean mGass[0] and stddev mGauss[1]
@@ -325,7 +367,7 @@ class GenTPCLoopers : public Generator
325367
return gaussValue;
326368
}
327369

328-
void SetNLoopers(short int &nsig_pair, short int &nsig_compton)
370+
void SetNLoopers(unsigned int &nsig_pair, unsigned int &nsig_compton)
329371
{
330372
if(mPoissonSet) {
331373
LOG(info) << "Poissonian parameters correctly loaded.";
@@ -354,6 +396,52 @@ class GenTPCLoopers : public Generator
354396
}
355397
}
356398

399+
void setFlatGas(Bool_t &flat, const Int_t &number = -1)
400+
{
401+
mFlatGas = flat;
402+
if (mFlatGas)
403+
{
404+
if (number < 0)
405+
{
406+
LOG(warn) << "Warning: Number of loopers per event must be non-negative! Switching option off.";
407+
mFlatGas = false;
408+
mFlatGasNumber = -1;
409+
} else {
410+
mFlatGasNumber = number;
411+
mContextFile = std::filesystem::exists("collisioncontext.root") ? TFile::Open("collisioncontext.root") : nullptr;
412+
mCollisionContext = mContextFile ? (o2::steer::DigitizationContext *)mContextFile->Get("DigitizationContext") : nullptr;
413+
mInteractionTimeRecords = mCollisionContext ? mCollisionContext->getEventRecords() : std::vector<o2::InteractionTimeRecord>{};
414+
if (mInteractionTimeRecords.empty())
415+
{
416+
LOG(error) << "Error: No interaction time records found in the collision context!";
417+
exit(1);
418+
} else {
419+
LOG(info) << "Interaction Time records has " << mInteractionTimeRecords.size() << " entries.";
420+
mCollisionContext->printCollisionSummary();
421+
}
422+
for (int c = 0; c < mInteractionTimeRecords.size() - 1; c++)
423+
{
424+
mIntTimeRecMean += mInteractionTimeRecords[c + 1].bc2ns() - mInteractionTimeRecords[c].bc2ns();
425+
}
426+
mIntTimeRecMean /= (mInteractionTimeRecords.size() - 1); // Average interaction time record used as reference
427+
}
428+
} else {
429+
mFlatGasNumber = -1;
430+
}
431+
LOG(info) << "Flat gas loopers: " << (mFlatGas ? "ON" : "OFF") << ", Reference loopers number per event: " << mFlatGasNumber;
432+
}
433+
434+
void setFractionPairs(float &fractionPairs)
435+
{
436+
if (fractionPairs < 0 || fractionPairs > 1)
437+
{
438+
LOG(fatal) << "Error: Loops fraction for pairs must be in the range [0, 1].";
439+
exit(1);
440+
}
441+
mLoopsFractionPairs = fractionPairs;
442+
LOG(info) << "Pairs fraction set to: " << mLoopsFractionPairs;
443+
}
444+
357445
private:
358446
std::unique_ptr<ONNXGenerator> mONNX_pair = nullptr;
359447
std::unique_ptr<ONNXGenerator> mONNX_compton = nullptr;
@@ -363,8 +451,8 @@ class GenTPCLoopers : public Generator
363451
double mGauss[4] = {0.0, 0.0, 0.0, 0.0}; // Mean, Std, Min, Max
364452
std::vector<std::vector<double>> mGenPairs;
365453
std::vector<std::vector<double>> mGenElectrons;
366-
short int mNLoopersPairs = -1;
367-
short int mNLoopersCompton = -1;
454+
unsigned int mNLoopersPairs = -1;
455+
unsigned int mNLoopersCompton = -1;
368456
std::array<float, 2> mMultiplier = {1., 1.};
369457
bool mPoissonSet = false;
370458
bool mGaussSet = false;
@@ -374,6 +462,15 @@ class GenTPCLoopers : public Generator
374462
TDatabasePDG *mPDG = TDatabasePDG::Instance();
375463
double mMass_e = mPDG->GetParticle(11)->Mass();
376464
double mMass_p = mPDG->GetParticle(-11)->Mass();
465+
int mCurrentEvent = 0; // Current event number, used for adaptive loopers
466+
TFile *mContextFile = nullptr; // Input collision context file
467+
o2::steer::DigitizationContext *mCollisionContext = nullptr; // Pointer to the digitization context
468+
std::vector<o2::InteractionTimeRecord> mInteractionTimeRecords; // Interaction time records from collision context
469+
Bool_t mFlatGas = false; // Flag to indicate if flat gas loopers are used
470+
Int_t mFlatGasNumber = -1; // Number of flat gas loopers per event
471+
double mIntTimeRecMean = 1.0; // Average interaction time record used for the reference
472+
double mTimeLimit = 0.0; // Time limit for the current event
473+
float mLoopsFractionPairs = 0.08; // Fraction of loopers from Pairs
377474
};
378475

379476
} // namespace eventgen
@@ -387,8 +484,8 @@ class GenTPCLoopers : public Generator
387484
FairGenerator *
388485
Generator_TPCLoopers(std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
389486
std::string poisson = "poisson.csv", std::string gauss = "gauss.csv", std::string scaler_pair = "scaler_pair.json",
390-
std::string scaler_compton = "scaler_compton.json", std::array<float, 2> mult = {1., 1.}, short int nloopers_pairs = 1,
391-
short int nloopers_compton = 1)
487+
std::string scaler_compton = "scaler_compton.json", std::array<float, 2> mult = {1., 1.}, unsigned int nloopers_pairs = 1,
488+
unsigned int nloopers_compton = 1)
392489
{
393490
// Expand all environment paths
394491
model_pairs = gSystem->ExpandPathName(model_pairs.c_str());
@@ -450,4 +547,75 @@ FairGenerator *
450547
generator->SetNLoopers(nloopers_pairs, nloopers_compton);
451548
generator->SetMultiplier(mult);
452549
return generator;
453-
}
550+
}
551+
552+
// Generator with flat gas loopers. Number of loopers starts from a reference value and changes
553+
// based on the BC time intervals in each event.
554+
FairGenerator *
555+
Generator_TPCLoopersFlat(std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
556+
std::string scaler_pair = "scaler_pair.json", std::string scaler_compton = "scaler_compton.json",
557+
bool flat_gas = true, const int loops_num = 500, float fraction_pairs = 0.08)
558+
{
559+
// Expand all environment paths
560+
model_pairs = gSystem->ExpandPathName(model_pairs.c_str());
561+
model_compton = gSystem->ExpandPathName(model_compton.c_str());
562+
scaler_pair = gSystem->ExpandPathName(scaler_pair.c_str());
563+
scaler_compton = gSystem->ExpandPathName(scaler_compton.c_str());
564+
const std::array<std::string, 2> models = {model_pairs, model_compton};
565+
const std::array<std::string, 2> local_names = {"WGANpair.onnx", "WGANcompton.onnx"};
566+
const std::array<bool, 2> isAlien = {models[0].starts_with("alien://"), models[1].starts_with("alien://")};
567+
const std::array<bool, 2> isCCDB = {models[0].starts_with("ccdb://"), models[1].starts_with("ccdb://")};
568+
if (std::any_of(isAlien.begin(), isAlien.end(), [](bool v)
569+
{ return v; }))
570+
{
571+
if (!gGrid)
572+
{
573+
TGrid::Connect("alien://");
574+
if (!gGrid)
575+
{
576+
LOG(fatal) << "AliEn connection failed, check token.";
577+
exit(1);
578+
}
579+
}
580+
for (size_t i = 0; i < models.size(); ++i)
581+
{
582+
if (isAlien[i] && !TFile::Cp(models[i].c_str(), local_names[i].c_str()))
583+
{
584+
LOG(fatal) << "Error: Model file " << models[i] << " does not exist!";
585+
exit(1);
586+
}
587+
}
588+
}
589+
if (std::any_of(isCCDB.begin(), isCCDB.end(), [](bool v)
590+
{ return v; }))
591+
{
592+
o2::ccdb::CcdbApi ccdb_api;
593+
ccdb_api.init("http://alice-ccdb.cern.ch");
594+
for (size_t i = 0; i < models.size(); ++i)
595+
{
596+
if (isCCDB[i])
597+
{
598+
auto model_path = models[i].substr(7); // Remove "ccdb://"
599+
// Treat filename if provided in the CCDB path
600+
auto extension = model_path.find(".onnx");
601+
if (extension != std::string::npos)
602+
{
603+
auto last_slash = model_path.find_last_of('/');
604+
model_path = model_path.substr(0, last_slash);
605+
}
606+
std::map<std::string, std::string> filter;
607+
if (!ccdb_api.retrieveBlob(model_path, "./", filter, o2::ccdb::getCurrentTimestamp(), false, local_names[i].c_str()))
608+
{
609+
LOG(fatal) << "Error: issues in retrieving " << model_path << " from CCDB!";
610+
exit(1);
611+
}
612+
}
613+
}
614+
}
615+
model_pairs = isAlien[0] || isCCDB[0] ? local_names[0] : model_pairs;
616+
model_compton = isAlien[1] || isCCDB[1] ? local_names[1] : model_compton;
617+
auto generator = new o2::eventgen::GenTPCLoopers(model_pairs, model_compton, "", "", scaler_pair, scaler_compton);
618+
generator->setFractionPairs(fraction_pairs);
619+
generator->setFlatGas(flat_gas, loops_num);
620+
return generator;
621+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# TPC loopers injector
2+
#---> GeneratorTPCloopers
3+
[GeneratorExternal]
4+
fileName = ${O2DPG_MC_CONFIG_ROOT}/MC/config/common/external/generator/TPCLoopers.C
5+
funcName = Generator_TPCLoopersFlat("ccdb://Users/m/mgiacalo/WGAN_ExtGenPair", "ccdb://Users/m/mgiacalo/WGAN_ExtGenCompton", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerPairParams.json", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerComptonParams.json")

0 commit comments

Comments
 (0)