Skip to content

Commit 1c885da

Browse files
committed
First implementation of TPC loopers external generator
1 parent 8e6e7ab commit 1c885da

File tree

7 files changed

+370
-0
lines changed

7 files changed

+370
-0
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
{
2+
"normal": {
3+
"min": [
4+
-0.0073022879660129,
5+
-0.0077305701561272,
6+
-0.0076750442385673,
7+
-0.0082916170358657,
8+
-0.0079681202769279,
9+
-0.0077468422241508,
10+
-255.6164093017578,
11+
-252.9441680908203
12+
],
13+
"max": [
14+
0.007688719779253,
15+
0.0077241472899913,
16+
0.0075828479602932,
17+
0.00813714787364,
18+
0.0083825681358575,
19+
0.0073839174583554,
20+
256.2904968261719,
21+
253.4925842285156
22+
]
23+
},
24+
"outlier": {
25+
"center": [
26+
-79.66580963134766,
27+
141535.640625
28+
],
29+
"scale": [
30+
250.8921127319336,
31+
222363.16015625
32+
]
33+
}
34+
}
638 KB
Binary file not shown.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
3.165383056343737511e+00
2+
1.000000000000000000e+00
3+
1.200000000000000000e+01
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
#include <onnxruntime_cxx_api.h>
2+
#include <iostream>
3+
#include <vector>
4+
#include <fstream>
5+
#include <rapidjson/document.h>
6+
#include <TMatrixT.h>
7+
8+
// This class is responsible for loading the scaler parameters from a JSON file
9+
// and applying the inverse transformation to the generated data.
10+
struct Scaler
11+
{
12+
TVectorD normal_min;
13+
TVectorD normal_max;
14+
TVectorD outlier_center;
15+
TVectorD outlier_scale;
16+
17+
void load(const std::string &filename)
18+
{
19+
std::ifstream file(filename);
20+
if (!file.is_open())
21+
{
22+
throw std::runtime_error("Error: Could not open scaler file!");
23+
}
24+
25+
std::string json_str((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
26+
file.close();
27+
28+
rapidjson::Document doc;
29+
doc.Parse(json_str.c_str());
30+
31+
if (doc.HasParseError())
32+
{
33+
throw std::runtime_error("Error: JSON parsing failed!");
34+
}
35+
36+
// Convert JSON arrays to TVectorD
37+
normal_min.ResizeTo(8);
38+
normal_max.ResizeTo(8);
39+
outlier_center.ResizeTo(2);
40+
outlier_scale.ResizeTo(2);
41+
42+
jsonArrayToVector(doc["normal"]["min"], normal_min);
43+
jsonArrayToVector(doc["normal"]["max"], normal_max);
44+
jsonArrayToVector(doc["outlier"]["center"], outlier_center);
45+
jsonArrayToVector(doc["outlier"]["scale"], outlier_scale);
46+
}
47+
48+
TVectorD inverse_transform(const TVectorD &input)
49+
{
50+
TVectorD normal_part(8);
51+
TVectorD outlier_part(2);
52+
53+
for (int i = 0; i < 8; ++i)
54+
{
55+
normal_part[i] = normal_min[i] + input[i] * (normal_max[i] - normal_min[i]);
56+
}
57+
58+
for (int i = 0; i < 2; ++i)
59+
{
60+
outlier_part[i] = input[8 + i] * outlier_scale[i] + outlier_center[i];
61+
}
62+
63+
TVectorD output(10);
64+
for (int i = 0; i < 8; ++i)
65+
output[i] = normal_part[i];
66+
for (int i = 0; i < 2; ++i)
67+
output[8 + i] = outlier_part[i];
68+
69+
return output;
70+
}
71+
72+
private:
73+
void jsonArrayToVector(const rapidjson::Value &jsonArray, TVectorD &vec)
74+
{
75+
for (int i = 0; i < jsonArray.Size(); ++i)
76+
{
77+
vec[i] = jsonArray[i].GetDouble();
78+
}
79+
}
80+
};
81+
82+
// This class loads the ONNX model and generates samples using it.
83+
class ONNXGenerator
84+
{
85+
public:
86+
ONNXGenerator(const std::string &model_path)
87+
: env(ORT_LOGGING_LEVEL_WARNING, "ONNXGenerator"), session(env, model_path.c_str(), Ort::SessionOptions{})
88+
{
89+
// Create session options
90+
Ort::SessionOptions session_options;
91+
session = Ort::Session(env, model_path.c_str(), session_options);
92+
}
93+
94+
TVectorD generate_sample()
95+
{
96+
Ort::AllocatorWithDefaultOptions allocator;
97+
98+
// Generate a latent vector (z)
99+
std::vector<float> z(100);
100+
for (auto &v : z)
101+
v = rand_gen.Gaus(0.0, 1.0);
102+
103+
// Prepare input tensor
104+
std::vector<int64_t> input_shape = {1, 100};
105+
// Get memory information
106+
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
107+
108+
// Create input tensor correctly
109+
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
110+
memory_info, z.data(), z.size(), input_shape.data(), input_shape.size());
111+
// Run inference
112+
const char *input_names[] = {"z"};
113+
const char *output_names[] = {"output"};
114+
auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1);
115+
116+
// Extract output
117+
float *output_data = output_tensors.front().GetTensorMutableData<float>();
118+
TVectorD output(10);
119+
for (int i = 0; i < 10; ++i)
120+
{
121+
output[i] = output_data[i];
122+
}
123+
124+
return output;
125+
}
126+
127+
private:
128+
Ort::Env env;
129+
Ort::Session session;
130+
TRandom3 rand_gen;
131+
};
132+
133+
namespace o2
134+
{
135+
namespace eventgen
136+
{
137+
138+
class GenTPCLoopers : public Generator
139+
{
140+
public:
141+
GenTPCLoopers(std::string model = "tpcloopmodel.onnx", std::string poisson = "poisson.csv", std::string scaler = "scaler.json")
142+
{
143+
// Checking if the model file exists and it's not empty
144+
std::ifstream model_file(model);
145+
if (!model_file.is_open() || model_file.peek() == std::ifstream::traits_type::eof())
146+
{
147+
LOG(fatal) << "Error: Model file is empty or does not exist!";
148+
}
149+
// Checking if the scaler file exists and it's not empty
150+
std::ifstream scaler_file(scaler);
151+
if (!scaler_file.is_open() || scaler_file.peek() == std::ifstream::traits_type::eof())
152+
{
153+
LOG(fatal) << "Error: Scaler file is empty or does not exist!";
154+
}
155+
// Checking if the poisson file exists and it's not empty
156+
if (poisson != "")
157+
{
158+
std::ifstream poisson_file(poisson);
159+
if (!poisson_file.is_open() || poisson_file.peek() == std::ifstream::traits_type::eof())
160+
{
161+
LOG(fatal) << "Error: Poisson file is empty or does not exist!";
162+
exit(1);
163+
} else {
164+
poisson_file >> mPoisson[0] >> mPoisson[1] >> mPoisson[2];
165+
poisson_file.close();
166+
mPoissonSet = true;
167+
}
168+
169+
}
170+
mONNX = std::make_unique<ONNXGenerator>(model);
171+
mScaler = std::make_unique<Scaler>();
172+
mScaler->load(scaler);
173+
Generator::setTimeUnit(1.0);
174+
Generator::setPositionUnit(1.0);
175+
}
176+
177+
Bool_t generateEvent() override
178+
{
179+
// Clear the vector of pairs
180+
mGenPairs.clear();
181+
// Set number of loopers if poissonian params are available
182+
if (mPoissonSet)
183+
{
184+
mNLoopers = PoissonPairs();
185+
}
186+
// Generate pairs of loopers
187+
for (int i = 0; i < mNLoopers; ++i)
188+
{
189+
TVectorD pair = mONNX->generate_sample();
190+
// Apply the inverse transformation using the scaler
191+
TVectorD transformed_pair = mScaler->inverse_transform(pair);
192+
mGenPairs.push_back(transformed_pair);
193+
}
194+
return true;
195+
}
196+
197+
Bool_t importParticles() override
198+
{
199+
// Get looper pairs from the event
200+
for (auto &pair : mGenPairs)
201+
{
202+
double px_e, py_e, pz_e, px_p, py_p, pz_p;
203+
double vx, vy, vz, time;
204+
double e_etot, p_etot;
205+
px_e = pair[0];
206+
py_e = pair[1];
207+
pz_e = pair[2];
208+
px_p = pair[3];
209+
py_p = pair[4];
210+
pz_p = pair[5];
211+
vx = pair[6];
212+
vy = pair[7];
213+
vz = pair[8];
214+
time = pair[9];
215+
e_etot = TMath::Sqrt(px_e * px_e + py_e * py_e + pz_e * pz_e + mMass_e * mMass_e);
216+
p_etot = TMath::Sqrt(px_p * px_p + py_p * py_p + pz_p * pz_p + mMass_p * mMass_p);
217+
// Push the electron
218+
TParticle electron(11, 1, -1, -1, -1, -1, px_e, py_e, pz_e, e_etot, vx, vy, vz, time / 1e9);
219+
electron.SetStatusCode(o2::mcgenstatus::MCGenStatusEncoding(electron.GetStatusCode(), 0).fullEncoding);
220+
electron.SetBit(ParticleStatus::kToBeDone, //
221+
o2::mcgenstatus::getHepMCStatusCode(electron.GetStatusCode()) == 1);
222+
mParticles.push_back(electron);
223+
// Push the positron
224+
TParticle positron(-11, 1, -1, -1, -1, -1, px_p, py_p, pz_p, p_etot, vx, vy, vz, time / 1e9);
225+
positron.SetStatusCode(o2::mcgenstatus::MCGenStatusEncoding(positron.GetStatusCode(), 0).fullEncoding);
226+
positron.SetBit(ParticleStatus::kToBeDone, //
227+
o2::mcgenstatus::getHepMCStatusCode(positron.GetStatusCode()) == 1);
228+
mParticles.push_back(positron);
229+
}
230+
return true;
231+
}
232+
233+
short int PoissonPairs()
234+
{
235+
short int poissonValue;
236+
do
237+
{
238+
// Generate a Poisson-distributed random number with mean mPoisson[0]
239+
poissonValue = mRandGen.Poisson(mPoisson[0]);
240+
} while (poissonValue < mPoisson[1] || poissonValue > mPoisson[2]); // Regenerate if out of range
241+
242+
return poissonValue;
243+
}
244+
245+
void SetNLoopers(short int nsig)
246+
{
247+
if(mPoissonSet) {
248+
LOG(warn) << "Poissonian parameters correctly set, ignoring SetNLoopers.";
249+
} else {
250+
mNLoopers = nsig;
251+
}
252+
}
253+
254+
private:
255+
std::unique_ptr<ONNXGenerator> mONNX = nullptr;
256+
std::unique_ptr<Scaler> mScaler = nullptr;
257+
double mPoisson[3] = {0.0, 0.0, 0.0}; // Mu, Min and Max of Poissonian
258+
std::vector<TVectorD> mGenPairs;
259+
short int mNLoopers = -1;
260+
bool mPoissonSet = false;
261+
// Poissonian random number generator
262+
TRandom3 mRandGen;
263+
// Masses of the electrons and positrons
264+
TDatabasePDG *mPDG = TDatabasePDG::Instance();
265+
double mMass_e = mPDG->GetParticle(11)->Mass();
266+
double mMass_p = mPDG->GetParticle(-11)->Mass();
267+
};
268+
269+
} // namespace eventgen
270+
} // namespace o2
271+
272+
FairGenerator *
273+
Generator_TPCLoopers(std::string model = "tpcloopmodel.onnx", std::string poisson = "poisson.csv", std::string scaler = "scaler.json", short int nloopers = 1)
274+
{
275+
// Expand all environment paths
276+
model = gSystem->ExpandPathName(model.c_str());
277+
poisson = gSystem->ExpandPathName(poisson.c_str());
278+
scaler = gSystem->ExpandPathName(scaler.c_str());
279+
auto generator = new o2::eventgen::GenTPCLoopers(model, poisson, scaler);
280+
generator->SetNLoopers(nloopers);
281+
return generator;
282+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Example of tpc loopers generator with a poisson distribution of pairs
2+
[GeneratorExternal]
3+
fileName = ${O2DPG_MC_CONFIG_ROOT}/MC/config/common/external/generator/TPCLoopers.C
4+
funcName = Generator_TPCLoopers("${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/generatorWGAN_pair.onnx", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/poisson_params.csv", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerPairParams.json")
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Example of tpc loopers generator with a fixed number of pairs (10)
2+
#---> GeneratorTPCloopers
3+
[GeneratorExternal]
4+
fileName = ${O2DPG_MC_CONFIG_ROOT}/MC/config/common/external/generator/TPCLoopers.C
5+
funcName = Generator_TPCLoopers("${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/generatorWGAN_pair.onnx", "", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerPairParams.json",10)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
int External() {
2+
std::string path{"o2sim_Kine.root"};
3+
TFile file(path.c_str(), "READ");
4+
if (file.IsZombie()) {
5+
std::cerr << "Cannot open ROOT file " << path << "\n";
6+
return 1;
7+
}
8+
auto tree = (TTree *)file.Get("o2sim");
9+
if (!tree) {
10+
std::cerr << "Cannot find tree 'o2sim' in file " << path << "\n";
11+
return 1;
12+
}
13+
// Get the MCTrack branch
14+
std::vector<o2::MCTrack> *tracks{};
15+
tree->SetBranchAddress("MCTrack", &tracks);
16+
// Check if only pairs are contained in the simulation
17+
int nEvents = tree->GetEntries();
18+
int count_e = 0;
19+
int count_p = 0;
20+
for (int i = 0; i < nEvents; i++) {
21+
tree->GetEntry(i);
22+
for (auto &track : *tracks)
23+
{
24+
auto pdg = track.GetPdgCode();
25+
if (pdg == 11) {
26+
count_e++;
27+
} else if (pdg == -11) {
28+
count_p++;
29+
} else {
30+
std::cerr << "Found unexpected PDG code: " << pdg << "\n";
31+
return 1;
32+
}
33+
}
34+
}
35+
if (count_e != count_p) {
36+
std::cerr << "Mismatch in number of electrons and positrons: " << count_e << " vs " << count_p << "\n";
37+
return 1;
38+
}
39+
file.Close();
40+
41+
return 0;
42+
}

0 commit comments

Comments
 (0)