Skip to content

Commit 7e553e7

Browse files
authored
Merge branch 'master' into centralitystudy1
2 parents 1abd4a4 + febce25 commit 7e553e7

File tree

58 files changed

+2913
-1782
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+2913
-1782
lines changed

ALICE3/Tasks/alice3-multicharm.cxx

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,16 @@ struct alice3multicharm {
8484
std::string prefix = "bdt"; // JSON group name
8585
Configurable<std::string> ccdbUrl{"ccdb-url", "http://alice-ccdb.cern.ch", "url of the ccdb repository"};
8686
Configurable<std::string> localPath{"localPath", "MCharm_BDTModel.onnx", "(std::string) Path to the local .onnx file."};
87-
Configurable<std::string> pathCCDB{"btdPathCCDB", "Users/j/jekarlss/MLModels2", "Path on CCDB"};
88-
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB. Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp"};
87+
Configurable<std::string> pathCCDB{"btdPathCCDB", "Users/j/jekarlss/MLModels", "Path on CCDB"};
88+
Configurable<int64_t> timestampCCDB{"timestampCCDB", 1695750420200, "timestamp of the ONNX file for ML model used to query in CCDB. Please use 1695750420200"};
8989
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
9090
Configurable<bool> enableOptimizations{"enableOptimizations", false, "Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)"};
9191
Configurable<bool> enableML{"enableML", false, "Enables bdt model"};
9292
Configurable<std::vector<float>> requiredScores{"requiredScores", {0.5, 0.75, 0.85, 0.9, 0.95, 0.99}, "Vector of different scores to try"};
9393
} bdt;
9494

9595
ConfigurableAxis axisEta{"axisEta", {80, -4.0f, +4.0f}, "#eta"};
96-
ConfigurableAxis axisXicMass{"axisXicMass", {200, 2.368f, 2.568f}, "XiC Inv Mass (GeV/c^{2})"};
96+
ConfigurableAxis axisXicMass{"axisXicMass", {200, 2.368f, 2.568f}, "Xic Inv Mass (GeV/c^{2})"};
9797
ConfigurableAxis axisXiccMass{"axisXiccMass", {200, 3.521f, 3.721f}, "Xicc Inv Mass (GeV/c^{2})"};
9898
ConfigurableAxis axisDCA{"axisDCA", {400, 0, 400}, "DCA (#mum)"};
9999
ConfigurableAxis axisRadiusLarge{"axisRadiusLarge", {1000, 0, 20}, "Decay radius (cm)"};
@@ -102,6 +102,7 @@ struct alice3multicharm {
102102
ConfigurableAxis axisNSigma{"axisNSigma", {21, -10, 10}, "nsigma"};
103103
ConfigurableAxis axisDecayLength{"axisDecayLength", {2000, 0, 2000}, "Decay lenght (#mum)"};
104104
ConfigurableAxis axisDcaDaughters{"axisDcaDaughters", {200, 0, 100}, "DCA (mum)"};
105+
ConfigurableAxis axisBDTScore{"axisBDTScore", {100, 0, 1}, "BDT Score"};
105106
ConfigurableAxis axisPt{"axisPt", {VARIABLE_WIDTH, 0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f, 2.0f, 2.2f, 2.4f, 2.6f, 2.8f, 3.0f, 3.2f, 3.4f, 3.6f, 3.8f, 4.0f, 4.4f, 4.8f, 5.2f, 5.6f, 6.0f, 6.5f, 7.0f, 7.5f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 17.0f, 19.0f, 21.0f, 23.0f, 25.0f, 30.0f, 35.0f, 40.0f, 50.0f}, "pt axis for QA histograms"};
106107

107108
Configurable<float> xiMinDCAxy{"xiMinDCAxy", -1, "[0] in |DCAxy| > [0]+[1]/pT"};
@@ -133,21 +134,6 @@ struct alice3multicharm {
133134

134135
void init(InitContext&)
135136
{
136-
ccdb->setURL(bdt.ccdbUrl.value);
137-
if (bdt.loadModelsFromCCDB) {
138-
ccdbApi.init(bdt.ccdbUrl);
139-
LOG(info) << "Fetching model for timestamp: " << bdt.timestampCCDB.value;
140-
bool retrieveSuccessMCharm = ccdbApi.retrieveBlob(bdt.pathCCDB.value, ".", metadata, bdt.timestampCCDB.value, false, bdt.localPath.value);
141-
142-
if (retrieveSuccessMCharm) {
143-
bdtMCharm.initModel(bdt.localPath.value, bdt.enableOptimizations.value);
144-
} else {
145-
LOG(fatal) << "Error encountered while fetching/loading the MCharm model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
146-
}
147-
} else {
148-
bdtMCharm.initModel(bdt.localPath.value, bdt.enableOptimizations.value);
149-
}
150-
151137
histos.add("SelectionQA/hDCAXicDaughters", "hDCAXicDaughters; DCA between Xic daughters (#mum)", kTH1D, {axisDcaDaughters});
152138
histos.add("SelectionQA/hDCAXiccDaughters", "hDCAXiccDaughters; DCA between Xicc daughters (#mum)", kTH1D, {axisDcaDaughters});
153139
histos.add("SelectionQA/hDCAxyXi", "hDCAxyXi; Xi DCAxy to PV (#mum)", kTH1D, {axisDCA});
@@ -249,6 +235,24 @@ struct alice3multicharm {
249235
histos.add("h3dXicc", "h3dXicc; Xicc pT (GeV/#it(c)); Xicc #eta; Xicc mass (GeV/#it(c)^{2})", kTH3D, {axisPt, axisEta, axisXiccMass});
250236

251237
if (bdt.enableML) {
238+
ccdb->setURL(bdt.ccdbUrl.value);
239+
if (bdt.loadModelsFromCCDB) {
240+
ccdbApi.init(bdt.ccdbUrl);
241+
LOG(info) << "Fetching model for timestamp: " << bdt.timestampCCDB.value;
242+
bool retrieveSuccessMCharm = ccdbApi.retrieveBlob(bdt.pathCCDB.value, ".", metadata, bdt.timestampCCDB.value, false, bdt.localPath.value);
243+
244+
if (retrieveSuccessMCharm) {
245+
bdtMCharm.initModel(bdt.localPath.value, bdt.enableOptimizations.value);
246+
} else {
247+
LOG(fatal) << "Error encountered while fetching/loading the MCharm model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
248+
}
249+
} else {
250+
bdtMCharm.initModel(bdt.localPath.value, bdt.enableOptimizations.value);
251+
}
252+
253+
histos.add("hBDTScore", "hBDTScore", kTH1D, {axisBDTScore});
254+
histos.add("hBDTScoreVsXiccMass", "hBDTScoreVsXiccMass", kTH2D, {axisXiccMass, axisBDTScore});
255+
histos.add("hBDTScoreVsXiccPt", "hBDTScoreVsXiccPt", kTH2D, {axisXiccMass, axisPt});
252256
for (const auto& score : bdt.requiredScores.value) {
253257
histPath = std::format("MLQA/RequiredBDTScore_{}/", static_cast<int>(score * 100));
254258
histPointers.insert({histPath + "hDCAXicDaughters", histos.add((histPath + "hDCAXicDaughters").c_str(), "hDCAXicDaughters", {kTH1D, {{axisDcaDaughters}}})});
@@ -292,7 +296,6 @@ struct alice3multicharm {
292296
void genericProcessXicc(TMCharmCands xiccCands)
293297
{
294298
for (const auto& xiccCand : xiccCands) {
295-
296299
if (bdt.enableML) {
297300
std::vector<float> inputFeatures{
298301
xiccCand.xicDauDCA(),
@@ -318,6 +321,10 @@ struct alice3multicharm {
318321
float* probabilityMCharm = bdtMCharm.evalModel(inputFeatures);
319322
float bdtScore = probabilityMCharm[1];
320323

324+
histos.fill(HIST("hBDTScore"), bdtScore);
325+
histos.fill(HIST("hBDTScoreVsXiccMass"), xiccCand.xiccMass(), bdtScore);
326+
histos.fill(HIST("hBDTScoreVsXiccPt"), xiccCand.xiccPt(), bdtScore);
327+
321328
for (const auto& requiredScore : bdt.requiredScores.value) {
322329
if (bdtScore > requiredScore) {
323330
histPath = std::format("MLQA/RequiredBDTScore_{}/", static_cast<int>(requiredScore * 100));

Common/TableProducer/PID/pidTPC.cxx

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,11 @@ struct tpcPid {
150150
Configurable<int> useNetworkHe{"useNetworkHe", 1, {"Switch for applying neural network on the helium3 mass hypothesis (if network enabled) (set to 0 to disable)"}};
151151
Configurable<int> useNetworkAl{"useNetworkAl", 1, {"Switch for applying neural network on the alpha mass hypothesis (if network enabled) (set to 0 to disable)"}};
152152
Configurable<float> networkBetaGammaCutoff{"networkBetaGammaCutoff", 0.45, {"Lower value of beta-gamma to override the NN application"}};
153+
Configurable<float> networkInputBatchedMode{"networkInputBatchedMode", -1, {"-1: Takes all tracks, >0: Takes networkInputBatchedMode number of tracks at once"}};
153154

154155
// Parametrization configuration
155156
bool useCCDBParam = false;
157+
std::vector<float> track_properties;
156158

157159
void init(o2::framework::InitContext& initContext)
158160
{
@@ -298,8 +300,6 @@ struct tpcPid {
298300
std::vector<float> createNetworkPrediction(C const& collisions, T const& tracks, B const& bcs, const size_t size)
299301
{
300302

301-
std::vector<float> network_prediction;
302-
303303
auto start_network_total = std::chrono::high_resolution_clock::now();
304304
if (autofetchNetworks) {
305305
const auto& bc = bcs.begin();
@@ -345,20 +345,24 @@ struct tpcPid {
345345
// Defining some network parameters
346346
int input_dimensions = network.getNumInputNodes();
347347
int output_dimensions = network.getNumOutputNodes();
348-
const uint64_t track_prop_size = input_dimensions * size;
349348
const uint64_t prediction_size = output_dimensions * size;
349+
const uint8_t numSpecies = 9;
350+
const uint64_t total_eval_size = size * numSpecies; // 9 species
350351

351-
network_prediction = std::vector<float>(prediction_size * 9); // For each mass hypotheses
352352
const float nNclNormalization = response->GetNClNormalization();
353353
float duration_network = 0;
354354

355-
std::vector<float> track_properties(track_prop_size);
356-
uint64_t counter_track_props = 0;
357-
int loop_counter = 0;
355+
uint64_t counter_track_props = 0, exec_counter = 0, in_batch_counter = 0, total_input_count = 0;
356+
uint64_t track_prop_size = networkInputBatchedMode.value;
357+
if (networkInputBatchedMode.value <= 0) {
358+
track_prop_size = size; // If the networkInputBatchedMode is not set, we use all tracks at once
359+
}
360+
track_properties.resize(track_prop_size * input_dimensions); // If the networkInputBatchedMode is set, we use the number of tracks specified in the config
361+
std::vector<float> network_prediction(prediction_size * numSpecies); // For each mass hypotheses
358362

359363
// Filling a std::vector<float> to be evaluated by the network
360364
// Evaluation on single tracks brings huge overhead: Thus evaluation is done on one large vector
361-
for (int i = 0; i < 9; i++) { // Loop over particle number for which network correction is used
365+
for (int species = 0; species < numSpecies; species++) { // Loop over particle number for which network correction is used
362366
for (auto const& trk : tracks) {
363367
if (!trk.hasTPC()) {
364368
continue;
@@ -368,30 +372,38 @@ struct tpcPid {
368372
continue;
369373
}
370374
}
371-
track_properties[counter_track_props] = trk.tpcInnerParam();
375+
376+
if ((in_batch_counter == track_prop_size) || (total_input_count == total_eval_size)) { // If the batch size is reached, reset the counter
377+
int32_t fill_shift = (exec_counter * track_prop_size - ((total_input_count == total_eval_size) ? (total_input_count % track_prop_size) : 0)) * output_dimensions;
378+
auto start_network_eval = std::chrono::high_resolution_clock::now();
379+
float* output_network = network.evalModel(track_properties);
380+
auto stop_network_eval = std::chrono::high_resolution_clock::now();
381+
duration_network += std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_eval - start_network_eval).count();
382+
383+
for (uint64_t i = 0; i < (in_batch_counter * output_dimensions); i += output_dimensions) {
384+
for (int j = 0; j < output_dimensions; j++) {
385+
network_prediction[i + j + fill_shift] = output_network[i + j];
386+
}
387+
}
388+
counter_track_props = 0;
389+
in_batch_counter = 0;
390+
exec_counter++;
391+
}
392+
393+
// LOG(info) << "counter_tracks_props: " << counter_track_props << "; in_batch_counter: " << in_batch_counter << "; total_input_count: " << total_input_count << "; exec_counter: " << exec_counter << "; track_prop_size: " << track_prop_size << "; size: " << size << "; track_properties.size(): " << track_properties.size();
394+
track_properties[counter_track_props] = trk.tpcInnerParam(); // (tracks.asArrowTable()->GetColumn<float>("tpcInnerParam")).GetData();
372395
track_properties[counter_track_props + 1] = trk.tgl();
373396
track_properties[counter_track_props + 2] = trk.signed1Pt();
374-
track_properties[counter_track_props + 3] = o2::track::pid_constants::sMasses[i];
397+
track_properties[counter_track_props + 3] = o2::track::pid_constants::sMasses[species];
375398
track_properties[counter_track_props + 4] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).multTPC() / 11000. : 1.;
376399
track_properties[counter_track_props + 5] = std::sqrt(nNclNormalization / trk.tpcNClsFound());
377400
if (input_dimensions == 7 && networkVersion == "2") {
378401
track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.;
379402
}
380403
counter_track_props += input_dimensions;
404+
in_batch_counter++;
405+
total_input_count++;
381406
}
382-
383-
auto start_network_eval = std::chrono::high_resolution_clock::now();
384-
float* output_network = network.evalModel(track_properties);
385-
auto stop_network_eval = std::chrono::high_resolution_clock::now();
386-
duration_network += std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_eval - start_network_eval).count();
387-
for (uint64_t i = 0; i < prediction_size; i += output_dimensions) {
388-
for (int j = 0; j < output_dimensions; j++) {
389-
network_prediction[i + j + prediction_size * loop_counter] = output_network[i + j];
390-
}
391-
}
392-
393-
counter_track_props = 0;
394-
loop_counter += 1;
395407
}
396408
track_properties.clear();
397409

Common/TableProducer/trackPropagationTester.cxx

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,28 @@
2323
//
2424
//===============================================================
2525

26-
#include "Framework/AnalysisDataModel.h"
27-
#include "Framework/AnalysisTask.h"
28-
#include "Framework/runDataProcessing.h"
29-
#include "Framework/RunningWorkflowInfo.h"
30-
#include "Common/DataModel/TrackSelectionTables.h"
3126
#include "Common/Core/trackUtilities.h"
32-
#include "ReconstructionDataFormats/DCA.h"
33-
#include "DetectorsBase/Propagator.h"
34-
#include "DetectorsBase/GeometryManager.h"
35-
#include "CommonUtils/NameConf.h"
27+
#include "Common/DataModel/TrackSelectionTables.h"
28+
#include "Common/Tools/StandardCCDBLoader.h"
29+
#include "Common/Tools/TrackPropagationModule.h"
30+
#include "Common/Tools/TrackTuner.h"
31+
32+
#include "CCDB/BasicCCDBManager.h"
3633
#include "CCDB/CcdbApi.h"
34+
#include "CommonConstants/GeomConstants.h"
35+
#include "CommonUtils/NameConf.h"
36+
#include "DataFormatsCalibration/MeanVertexObject.h"
3737
#include "DataFormatsParameters/GRPMagField.h"
38-
#include "CCDB/BasicCCDBManager.h"
38+
#include "DetectorsBase/GeometryManager.h"
39+
#include "DetectorsBase/Propagator.h"
40+
#include "Framework/AnalysisDataModel.h"
41+
#include "Framework/AnalysisTask.h"
3942
#include "Framework/HistogramRegistry.h"
40-
#include "DataFormatsCalibration/MeanVertexObject.h"
41-
#include "CommonConstants/GeomConstants.h"
42-
#include "Common/Tools/TrackPropagationModule.h"
43-
#include "Common/Tools/StandardCCDBLoader.h"
43+
#include "Framework/RunningWorkflowInfo.h"
44+
#include "Framework/runDataProcessing.h"
45+
#include "ReconstructionDataFormats/DCA.h"
46+
47+
#include <string>
4448

4549
// The Run 3 AO2D stores the tracks at the point of innermost update. For a track with ITS this is the innermost (or second innermost)
4650
// ITS layer. For a track without ITS, this is the TPC inner wall or for loopers in the TPC even a radius beyond that.
@@ -59,6 +63,9 @@ struct TrackPropagationTester {
5963
o2::common::TrackPropagationProducts trackPropagationProducts;
6064
o2::common::TrackPropagationConfigurables trackPropagationConfigurables;
6165

66+
// the track tuner object -> needs to be here as it inherits from ConfigurableGroup (+ has its own copy of ccdbApi)
67+
TrackTuner trackTunerObj;
68+
6269
// CCDB boilerplate declarations
6370
o2::framework::Configurable<std::string> ccdburl{"ccdburl", "http://alice-ccdb.cern.ch", "url of the ccdb repository"};
6471
Service<o2::ccdb::BasicCCDBManager> ccdb;
@@ -76,22 +83,22 @@ struct TrackPropagationTester {
7683
ccdb->setURL(ccdburl.value);
7784

7885
// task-specific
79-
trackPropagation.init(trackPropagationConfigurables, registry, initContext);
86+
trackPropagation.init(trackPropagationConfigurables, trackTunerObj, registry, initContext);
8087
}
8188

8289
void processReal(aod::Collisions const& collisions, soa::Join<aod::StoredTracksIU, aod::TracksCovIU, aod::TracksExtra> const& tracks, aod::Collisions const&, aod::BCs const& bcs)
8390
{
8491
// task-specific
8592
ccdbLoader.initCCDBfromBCs(standardCCDBLoaderConfigurables, ccdb, bcs);
86-
trackPropagation.fillTrackTables<false>(trackPropagationConfigurables, ccdbLoader, collisions, tracks, trackPropagationProducts, registry);
93+
trackPropagation.fillTrackTables<false>(trackPropagationConfigurables, trackTunerObj, ccdbLoader, collisions, tracks, trackPropagationProducts, registry);
8794
}
8895
PROCESS_SWITCH(TrackPropagationTester, processReal, "Process Real Data", true);
8996

9097
// -----------------------
9198
void processMc(aod::Collisions const& collisions, soa::Join<aod::StoredTracksIU, aod::McTrackLabels, aod::TracksCovIU, aod::TracksExtra> const& tracks, aod::McParticles const&, aod::Collisions const&, aod::BCs const& bcs)
9299
{
93100
ccdbLoader.initCCDBfromBCs(standardCCDBLoaderConfigurables, ccdb, bcs);
94-
trackPropagation.fillTrackTables<false>(trackPropagationConfigurables, ccdbLoader, collisions, tracks, trackPropagationProducts, registry);
101+
trackPropagation.fillTrackTables<false>(trackPropagationConfigurables, trackTunerObj, ccdbLoader, collisions, tracks, trackPropagationProducts, registry);
95102
}
96103
PROCESS_SWITCH(TrackPropagationTester, processMc, "Process Monte Carlo", false);
97104
};

0 commit comments

Comments
 (0)