Skip to content

Commit a3e1777

Browse files
Merge branch 'AliceO2Group:master' into master
2 parents d24e5db + b0c092c commit a3e1777

File tree

124 files changed

+7123
-5744
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

124 files changed

+7123
-5744
lines changed

ALICE3/Tasks/alice3-multicharm.cxx

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,16 @@ struct alice3multicharm {
8484
std::string prefix = "bdt"; // JSON group name
8585
Configurable<std::string> ccdbUrl{"ccdb-url", "http://alice-ccdb.cern.ch", "url of the ccdb repository"};
8686
Configurable<std::string> localPath{"localPath", "MCharm_BDTModel.onnx", "(std::string) Path to the local .onnx file."};
87-
Configurable<std::string> pathCCDB{"btdPathCCDB", "Users/j/jekarlss/MLModels2", "Path on CCDB"};
88-
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB. Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp"};
87+
Configurable<std::string> pathCCDB{"btdPathCCDB", "Users/j/jekarlss/MLModels", "Path on CCDB"};
88+
Configurable<int64_t> timestampCCDB{"timestampCCDB", 1695750420200, "timestamp of the ONNX file for ML model used to query in CCDB. Please use 1695750420200"};
8989
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
9090
Configurable<bool> enableOptimizations{"enableOptimizations", false, "Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)"};
9191
Configurable<bool> enableML{"enableML", false, "Enables bdt model"};
9292
Configurable<std::vector<float>> requiredScores{"requiredScores", {0.5, 0.75, 0.85, 0.9, 0.95, 0.99}, "Vector of different scores to try"};
9393
} bdt;
9494

9595
ConfigurableAxis axisEta{"axisEta", {80, -4.0f, +4.0f}, "#eta"};
96-
ConfigurableAxis axisXicMass{"axisXicMass", {200, 2.368f, 2.568f}, "XiC Inv Mass (GeV/c^{2})"};
96+
ConfigurableAxis axisXicMass{"axisXicMass", {200, 2.368f, 2.568f}, "Xic Inv Mass (GeV/c^{2})"};
9797
ConfigurableAxis axisXiccMass{"axisXiccMass", {200, 3.521f, 3.721f}, "Xicc Inv Mass (GeV/c^{2})"};
9898
ConfigurableAxis axisDCA{"axisDCA", {400, 0, 400}, "DCA (#mum)"};
9999
ConfigurableAxis axisRadiusLarge{"axisRadiusLarge", {1000, 0, 20}, "Decay radius (cm)"};
@@ -102,6 +102,7 @@ struct alice3multicharm {
102102
ConfigurableAxis axisNSigma{"axisNSigma", {21, -10, 10}, "nsigma"};
103103
ConfigurableAxis axisDecayLength{"axisDecayLength", {2000, 0, 2000}, "Decay lenght (#mum)"};
104104
ConfigurableAxis axisDcaDaughters{"axisDcaDaughters", {200, 0, 100}, "DCA (mum)"};
105+
ConfigurableAxis axisBDTScore{"axisBDTScore", {100, 0, 1}, "BDT Score"};
105106
ConfigurableAxis axisPt{"axisPt", {VARIABLE_WIDTH, 0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f, 2.0f, 2.2f, 2.4f, 2.6f, 2.8f, 3.0f, 3.2f, 3.4f, 3.6f, 3.8f, 4.0f, 4.4f, 4.8f, 5.2f, 5.6f, 6.0f, 6.5f, 7.0f, 7.5f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 17.0f, 19.0f, 21.0f, 23.0f, 25.0f, 30.0f, 35.0f, 40.0f, 50.0f}, "pt axis for QA histograms"};
106107

107108
Configurable<float> xiMinDCAxy{"xiMinDCAxy", -1, "[0] in |DCAxy| > [0]+[1]/pT"};
@@ -133,21 +134,6 @@ struct alice3multicharm {
133134

134135
void init(InitContext&)
135136
{
136-
ccdb->setURL(bdt.ccdbUrl.value);
137-
if (bdt.loadModelsFromCCDB) {
138-
ccdbApi.init(bdt.ccdbUrl);
139-
LOG(info) << "Fetching model for timestamp: " << bdt.timestampCCDB.value;
140-
bool retrieveSuccessMCharm = ccdbApi.retrieveBlob(bdt.pathCCDB.value, ".", metadata, bdt.timestampCCDB.value, false, bdt.localPath.value);
141-
142-
if (retrieveSuccessMCharm) {
143-
bdtMCharm.initModel(bdt.localPath.value, bdt.enableOptimizations.value);
144-
} else {
145-
LOG(fatal) << "Error encountered while fetching/loading the MCharm model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
146-
}
147-
} else {
148-
bdtMCharm.initModel(bdt.localPath.value, bdt.enableOptimizations.value);
149-
}
150-
151137
histos.add("SelectionQA/hDCAXicDaughters", "hDCAXicDaughters; DCA between Xic daughters (#mum)", kTH1D, {axisDcaDaughters});
152138
histos.add("SelectionQA/hDCAXiccDaughters", "hDCAXiccDaughters; DCA between Xicc daughters (#mum)", kTH1D, {axisDcaDaughters});
153139
histos.add("SelectionQA/hDCAxyXi", "hDCAxyXi; Xi DCAxy to PV (#mum)", kTH1D, {axisDCA});
@@ -249,6 +235,24 @@ struct alice3multicharm {
249235
histos.add("h3dXicc", "h3dXicc; Xicc pT (GeV/#it(c)); Xicc #eta; Xicc mass (GeV/#it(c)^{2})", kTH3D, {axisPt, axisEta, axisXiccMass});
250236

251237
if (bdt.enableML) {
238+
ccdb->setURL(bdt.ccdbUrl.value);
239+
if (bdt.loadModelsFromCCDB) {
240+
ccdbApi.init(bdt.ccdbUrl);
241+
LOG(info) << "Fetching model for timestamp: " << bdt.timestampCCDB.value;
242+
bool retrieveSuccessMCharm = ccdbApi.retrieveBlob(bdt.pathCCDB.value, ".", metadata, bdt.timestampCCDB.value, false, bdt.localPath.value);
243+
244+
if (retrieveSuccessMCharm) {
245+
bdtMCharm.initModel(bdt.localPath.value, bdt.enableOptimizations.value);
246+
} else {
247+
LOG(fatal) << "Error encountered while fetching/loading the MCharm model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
248+
}
249+
} else {
250+
bdtMCharm.initModel(bdt.localPath.value, bdt.enableOptimizations.value);
251+
}
252+
253+
histos.add("hBDTScore", "hBDTScore", kTH1D, {axisBDTScore});
254+
histos.add("hBDTScoreVsXiccMass", "hBDTScoreVsXiccMass", kTH2D, {axisXiccMass, axisBDTScore});
255+
// x: Xicc pT, y: BDT score -- the axes must follow the (xiccPt, bdtScore) fill order used in genericProcessXicc
histos.add("hBDTScoreVsXiccPt", "hBDTScoreVsXiccPt", kTH2D, {axisPt, axisBDTScore});
252256
for (const auto& score : bdt.requiredScores.value) {
253257
histPath = std::format("MLQA/RequiredBDTScore_{}/", static_cast<int>(score * 100));
254258
histPointers.insert({histPath + "hDCAXicDaughters", histos.add((histPath + "hDCAXicDaughters").c_str(), "hDCAXicDaughters", {kTH1D, {{axisDcaDaughters}}})});
@@ -292,7 +296,6 @@ struct alice3multicharm {
292296
void genericProcessXicc(TMCharmCands xiccCands)
293297
{
294298
for (const auto& xiccCand : xiccCands) {
295-
296299
if (bdt.enableML) {
297300
std::vector<float> inputFeatures{
298301
xiccCand.xicDauDCA(),
@@ -318,6 +321,10 @@ struct alice3multicharm {
318321
float* probabilityMCharm = bdtMCharm.evalModel(inputFeatures);
319322
float bdtScore = probabilityMCharm[1];
320323

324+
histos.fill(HIST("hBDTScore"), bdtScore);
325+
histos.fill(HIST("hBDTScoreVsXiccMass"), xiccCand.xiccMass(), bdtScore);
326+
histos.fill(HIST("hBDTScoreVsXiccPt"), xiccCand.xiccPt(), bdtScore);
327+
321328
for (const auto& requiredScore : bdt.requiredScores.value) {
322329
if (bdtScore > requiredScore) {
323330
histPath = std::format("MLQA/RequiredBDTScore_{}/", static_cast<int>(requiredScore * 100));

Common/Core/PID/TPCPIDResponse.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,18 @@ class Response
7474
/// Gets the expected resolution of the track
7575
template <typename CollisionType, typename TrackType>
7676
float GetExpectedSigma(const CollisionType& collision, const TrackType& trk, const o2::track::PID::ID id) const;
77+
/// Gets the expected resolution of the track with multTPC explicitly provided
78+
template <typename TrackType>
79+
float GetExpectedSigmaAtMultiplicity(const long multTPC, const TrackType& trk, const o2::track::PID::ID id) const;
7780
/// Gets the number of sigmas with respect to the expected value
7881
template <typename CollisionType, typename TrackType>
7982
float GetNumberOfSigma(const CollisionType& collision, const TrackType& trk, const o2::track::PID::ID id) const;
8083
// Number of sigmas with respect to expected for MC, defining a tune-on-data signal value
8184
template <typename CollisionType, typename TrackType>
8285
float GetNumberOfSigmaMCTuned(const CollisionType& collision, const TrackType& trk, const o2::track::PID::ID id, float mcTunedTPCSignal) const;
86+
// Number of sigmas with respect to expected for MC, defining a tune-on-data signal value, explicit multTPC
87+
template <typename TrackType>
88+
float GetNumberOfSigmaMCTunedAtMultiplicity(const long multTPC, const TrackType& trk, const o2::track::PID::ID id, float mcTunedTPCSignal) const;
8389
/// Gets the deviation to the expected signal
8490
template <typename TrackType>
8591
float GetSignalDelta(const TrackType& trk, const o2::track::PID::ID id) const;
@@ -116,6 +122,14 @@ inline float Response::GetExpectedSignal(const TrackType& track, const o2::track
116122
/// Gets the expected resolution of the measurement
117123
template <typename CollisionType, typename TrackType>
118124
inline float Response::GetExpectedSigma(const CollisionType& collision, const TrackType& track, const o2::track::PID::ID id) const
125+
{
126+
// Legacy behaviour: take multTPC from the collision when it is not provided explicitly
127+
return Response::GetExpectedSigmaAtMultiplicity(collision.multTPC(), track, id);
128+
}
129+
130+
/// Gets the expected resolution of the measurement with multTPC explicitly provided
131+
template <typename TrackType>
132+
inline float Response::GetExpectedSigmaAtMultiplicity(const long multTPC, const TrackType& track, const o2::track::PID::ID id) const
119133
{
120134
if (!track.hasTPC()) {
121135
return -999.f;
@@ -133,7 +147,7 @@ inline float Response::GetExpectedSigma(const CollisionType& collision, const Tr
133147
const double dEdx = o2::tpc::BetheBlochAleph(static_cast<float>(bg), mBetheBlochParams[0], mBetheBlochParams[1], mBetheBlochParams[2], mBetheBlochParams[3], mBetheBlochParams[4]) * std::pow(static_cast<float>(o2::track::pid_constants::sCharges[id]), mChargeFactor);
134148
const double relReso = GetRelativeResolutiondEdx(p, mass, o2::track::pid_constants::sCharges[id], mResolutionParams[3]);
135149

136-
const std::vector<double> values{1.f / dEdx, track.tgl(), std::sqrt(ncl), relReso, track.signed1Pt(), collision.multTPC() / mMultNormalization};
150+
const std::vector<double> values{1.f / dEdx, track.tgl(), std::sqrt(ncl), relReso, track.signed1Pt(), multTPC / mMultNormalization};
137151

138152
const float reso = sqrt(pow(mResolutionParams[0], 2) * values[0] + pow(mResolutionParams[1], 2) * (values[2] * mResolutionParams[5]) * pow(values[0] / sqrt(1 + pow(values[1], 2)), mResolutionParams[2]) + values[2] * pow(values[3], 2) + pow(mResolutionParams[4] * values[4], 2) + pow(values[5] * mResolutionParams[6], 2) + pow(values[5] * (values[0] / sqrt(1 + pow(values[1], 2))) * mResolutionParams[7], 2)) * dEdx * mMIP;
139153
reso >= 0.f ? resolution = reso : resolution = -999.f;
@@ -160,7 +174,13 @@ inline float Response::GetNumberOfSigma(const CollisionType& collision, const Tr
160174
template <typename CollisionType, typename TrackType>
161175
inline float Response::GetNumberOfSigmaMCTuned(const CollisionType& collision, const TrackType& trk, const o2::track::PID::ID id, float mcTunedTPCSignal) const
162176
{
163-
if (GetExpectedSigma(collision, trk, id) < 0.) {
177+
return Response::GetNumberOfSigmaMCTunedAtMultiplicity(collision.multTPC(), trk, id, mcTunedTPCSignal);
178+
}
179+
180+
template <typename TrackType>
181+
inline float Response::GetNumberOfSigmaMCTunedAtMultiplicity(const long multTPC, const TrackType& trk, const o2::track::PID::ID id, float mcTunedTPCSignal) const
182+
{
183+
if (GetExpectedSigmaAtMultiplicity(multTPC, trk, id) < 0.) {
164184
return -999.f;
165185
}
166186
if (GetExpectedSignal(trk, id) < 0.) {
@@ -169,7 +189,7 @@ inline float Response::GetNumberOfSigmaMCTuned(const CollisionType& collision, c
169189
if (!trk.hasTPC()) {
170190
return -999.f;
171191
}
172-
return ((mcTunedTPCSignal - GetExpectedSignal(trk, id)) / GetExpectedSigma(collision, trk, id));
192+
return ((mcTunedTPCSignal - GetExpectedSignal(trk, id)) / GetExpectedSigmaAtMultiplicity(multTPC, trk, id));
173193
}
174194

175195
/// Gets the deviation between the actual signal and the expected signal

Common/TableProducer/PID/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ o2physics_add_dpl_workflow(pid-tof-full
4343

4444
# TPC
4545

46+
o2physics_add_dpl_workflow(pid-tpc-service
47+
SOURCES pidTPCService.cxx
48+
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore O2Physics::MLCore O2Physics::AnalysisCCDB
49+
COMPONENT_NAME Analysis)
50+
4651
o2physics_add_dpl_workflow(pid-tpc-base
4752
SOURCES pidTPCBase.cxx
4853
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore O2Physics::AnalysisCCDB

Common/TableProducer/PID/pidTPC.cxx

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,11 @@ struct tpcPid {
150150
Configurable<int> useNetworkHe{"useNetworkHe", 1, {"Switch for applying neural network on the helium3 mass hypothesis (if network enabled) (set to 0 to disable)"}};
151151
Configurable<int> useNetworkAl{"useNetworkAl", 1, {"Switch for applying neural network on the alpha mass hypothesis (if network enabled) (set to 0 to disable)"}};
152152
Configurable<float> networkBetaGammaCutoff{"networkBetaGammaCutoff", 0.45, {"Lower value of beta-gamma to override the NN application"}};
153+
Configurable<float> networkInputBatchedMode{"networkInputBatchedMode", -1, {"-1: Takes all tracks, >0: Takes networkInputBatchedMode number of tracks at once"}};
153154

154155
// Parametrization configuration
155156
bool useCCDBParam = false;
157+
std::vector<float> track_properties;
156158

157159
void init(o2::framework::InitContext& initContext)
158160
{
@@ -298,8 +300,6 @@ struct tpcPid {
298300
std::vector<float> createNetworkPrediction(C const& collisions, T const& tracks, B const& bcs, const size_t size)
299301
{
300302

301-
std::vector<float> network_prediction;
302-
303303
auto start_network_total = std::chrono::high_resolution_clock::now();
304304
if (autofetchNetworks) {
305305
const auto& bc = bcs.begin();
@@ -345,20 +345,24 @@ struct tpcPid {
345345
// Defining some network parameters
346346
int input_dimensions = network.getNumInputNodes();
347347
int output_dimensions = network.getNumOutputNodes();
348-
const uint64_t track_prop_size = input_dimensions * size;
349348
const uint64_t prediction_size = output_dimensions * size;
349+
const uint8_t numSpecies = 9;
350+
const uint64_t total_eval_size = size * numSpecies; // 9 species
350351

351-
network_prediction = std::vector<float>(prediction_size * 9); // For each mass hypotheses
352352
const float nNclNormalization = response->GetNClNormalization();
353353
float duration_network = 0;
354354

355-
std::vector<float> track_properties(track_prop_size);
356-
uint64_t counter_track_props = 0;
357-
int loop_counter = 0;
355+
uint64_t counter_track_props = 0, exec_counter = 0, in_batch_counter = 0, total_input_count = 0;
356+
// Batch size in tracks. Converting a negative float (the default is -1) to an unsigned integer is undefined behaviour,
// so only values >= 1 are converted; a fractional value in (0, 1) is clamped to one track per batch.
// Non-positive values yield 0 here and are replaced by the full track count right below.
const float requested_batch_size = networkInputBatchedMode.value;
uint64_t track_prop_size = (requested_batch_size >= 1.f) ? static_cast<uint64_t>(requested_batch_size) : ((requested_batch_size > 0.f) ? uint64_t{1} : uint64_t{0});
357+
if (networkInputBatchedMode.value <= 0) {
358+
track_prop_size = size; // If the networkInputBatchedMode is not set, we use all tracks at once
359+
}
360+
track_properties.resize(track_prop_size * input_dimensions); // If the networkInputBatchedMode is set, we use the number of tracks specified in the config
361+
std::vector<float> network_prediction(prediction_size * numSpecies); // For each mass hypotheses
358362

359363
// Filling a std::vector<float> to be evaluated by the network
360364
// Evaluation on single tracks brings huge overhead: thus evaluation is done on batches of tracks (by default all tracks in one large vector)
361-
for (int i = 0; i < 9; i++) { // Loop over particle number for which network correction is used
365+
for (int species = 0; species < numSpecies; species++) { // Loop over particle number for which network correction is used
362366
for (auto const& trk : tracks) {
363367
if (!trk.hasTPC()) {
364368
continue;
@@ -368,30 +372,38 @@ struct tpcPid {
368372
continue;
369373
}
370374
}
371-
track_properties[counter_track_props] = trk.tpcInnerParam();
375+
376+
if ((in_batch_counter == track_prop_size) || (total_input_count == total_eval_size)) { // If the batch size is reached, reset the counter
377+
int32_t fill_shift = (exec_counter * track_prop_size - ((total_input_count == total_eval_size) ? (total_input_count % track_prop_size) : 0)) * output_dimensions;
378+
auto start_network_eval = std::chrono::high_resolution_clock::now();
379+
float* output_network = network.evalModel(track_properties);
380+
auto stop_network_eval = std::chrono::high_resolution_clock::now();
381+
duration_network += std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_eval - start_network_eval).count();
382+
383+
for (uint64_t i = 0; i < (in_batch_counter * output_dimensions); i += output_dimensions) {
384+
for (int j = 0; j < output_dimensions; j++) {
385+
network_prediction[i + j + fill_shift] = output_network[i + j];
386+
}
387+
}
388+
counter_track_props = 0;
389+
in_batch_counter = 0;
390+
exec_counter++;
391+
}
392+
393+
// LOG(info) << "counter_tracks_props: " << counter_track_props << "; in_batch_counter: " << in_batch_counter << "; total_input_count: " << total_input_count << "; exec_counter: " << exec_counter << "; track_prop_size: " << track_prop_size << "; size: " << size << "; track_properties.size(): " << track_properties.size();
394+
track_properties[counter_track_props] = trk.tpcInnerParam(); // (tracks.asArrowTable()->GetColumn<float>("tpcInnerParam")).GetData();
372395
track_properties[counter_track_props + 1] = trk.tgl();
373396
track_properties[counter_track_props + 2] = trk.signed1Pt();
374-
track_properties[counter_track_props + 3] = o2::track::pid_constants::sMasses[i];
397+
track_properties[counter_track_props + 3] = o2::track::pid_constants::sMasses[species];
375398
track_properties[counter_track_props + 4] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).multTPC() / 11000. : 1.;
376399
track_properties[counter_track_props + 5] = std::sqrt(nNclNormalization / trk.tpcNClsFound());
377400
if (input_dimensions == 7 && networkVersion == "2") {
378401
track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.;
379402
}
380403
counter_track_props += input_dimensions;
404+
in_batch_counter++;
405+
total_input_count++;
381406
}
382-
383-
auto start_network_eval = std::chrono::high_resolution_clock::now();
384-
float* output_network = network.evalModel(track_properties);
385-
auto stop_network_eval = std::chrono::high_resolution_clock::now();
386-
duration_network += std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_eval - start_network_eval).count();
387-
for (uint64_t i = 0; i < prediction_size; i += output_dimensions) {
388-
for (int j = 0; j < output_dimensions; j++) {
389-
network_prediction[i + j + prediction_size * loop_counter] = output_network[i + j];
390-
}
391-
}
392-
393-
counter_track_props = 0;
394-
loop_counter += 1;
395407
}
396408
track_properties.clear();
397409

@@ -435,6 +447,11 @@ struct tpcPid {
435447
}
436448

437449
float nSigma = -999.f;
450+
int multTPC = 0;
451+
if (trk.has_collision()) {
452+
auto collision = collisions.rawIteratorAt(trk.collisionId());
453+
multTPC = collision.multTPC();
454+
}
438455
float bg = trk.tpcInnerParam() / o2::track::pid_constants::sMasses[pid]; // estimated beta-gamma for network cutoff
439456
if (useNetworkCorrection && speciesNetworkFlags[pid] && trk.has_collision() && bg > networkBetaGammaCutoff) {
440457

@@ -457,7 +474,7 @@ struct tpcPid {
457474
LOGF(fatal, "Network output-dimensions incompatible!");
458475
}
459476
} else {
460-
nSigma = response->GetNumberOfSigmaMCTuned(collisions.iteratorAt(trk.collisionId()), trk, pid, tpcSignal);
477+
nSigma = response->GetNumberOfSigmaMCTunedAtMultiplicity(multTPC, trk, pid, tpcSignal);
461478
}
462479
if (flagFull)
463480
tableFull(expSigma, nSigma);

0 commit comments

Comments
 (0)