Skip to content

Commit f061710

Browse files
[Common] Batched evaluation for ONNX models to reduce memory consumption (#11839)
Co-authored-by: ALICE Action Bot <alibuild@cern.ch>
1 parent 4db5971 commit f061710

File tree

1 file changed

+35
-23
lines changed

1 file changed

+35
-23
lines changed

Common/TableProducer/PID/pidTPC.cxx

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,11 @@ struct tpcPid {
150150
Configurable<int> useNetworkHe{"useNetworkHe", 1, {"Switch for applying neural network on the helium3 mass hypothesis (if network enabled) (set to 0 to disable)"}};
151151
Configurable<int> useNetworkAl{"useNetworkAl", 1, {"Switch for applying neural network on the alpha mass hypothesis (if network enabled) (set to 0 to disable)"}};
152152
Configurable<float> networkBetaGammaCutoff{"networkBetaGammaCutoff", 0.45, {"Lower value of beta-gamma to override the NN application"}};
153+
Configurable<float> networkInputBatchedMode{"networkInputBatchedMode", -1, {"-1: Takes all tracks, >0: Takes networkInputBatchedMode number of tracks at once"}};
153154

154155
// Parametrization configuration
155156
bool useCCDBParam = false;
157+
std::vector<float> track_properties;
156158

157159
void init(o2::framework::InitContext& initContext)
158160
{
@@ -298,8 +300,6 @@ struct tpcPid {
298300
std::vector<float> createNetworkPrediction(C const& collisions, T const& tracks, B const& bcs, const size_t size)
299301
{
300302

301-
std::vector<float> network_prediction;
302-
303303
auto start_network_total = std::chrono::high_resolution_clock::now();
304304
if (autofetchNetworks) {
305305
const auto& bc = bcs.begin();
@@ -345,20 +345,24 @@ struct tpcPid {
345345
// Defining some network parameters
346346
int input_dimensions = network.getNumInputNodes();
347347
int output_dimensions = network.getNumOutputNodes();
348-
const uint64_t track_prop_size = input_dimensions * size;
349348
const uint64_t prediction_size = output_dimensions * size;
349+
const uint8_t numSpecies = 9;
350+
const uint64_t total_eval_size = size * numSpecies; // 9 species
350351

351-
network_prediction = std::vector<float>(prediction_size * 9); // For each mass hypotheses
352352
const float nNclNormalization = response->GetNClNormalization();
353353
float duration_network = 0;
354354

355-
std::vector<float> track_properties(track_prop_size);
356-
uint64_t counter_track_props = 0;
357-
int loop_counter = 0;
355+
uint64_t counter_track_props = 0, exec_counter = 0, in_batch_counter = 0, total_input_count = 0;
356+
uint64_t track_prop_size = networkInputBatchedMode.value;
357+
if (networkInputBatchedMode.value <= 0) {
358+
track_prop_size = size; // If the networkInputBatchedMode is not set, we use all tracks at once
359+
}
360+
track_properties.resize(track_prop_size * input_dimensions); // If the networkInputBatchedMode is set, we use the number of tracks specified in the config
361+
std::vector<float> network_prediction(prediction_size * numSpecies); // Holds the network outputs of all tracks for each mass hypothesis
358362

359363
// Filling a std::vector<float> to be evaluated by the network
360364
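// Evaluating single tracks brings a huge overhead, so the inputs are packed into one flat vector and evaluated in batches of track_prop_size tracks (all tracks at once if networkInputBatchedMode <= 0)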
361-
for (int i = 0; i < 9; i++) { // Loop over particle number for which network correction is used
365+
for (int species = 0; species < numSpecies; species++) { // Loop over particle number for which network correction is used
362366
for (auto const& trk : tracks) {
363367
if (!trk.hasTPC()) {
364368
continue;
@@ -368,30 +372,38 @@ struct tpcPid {
368372
continue;
369373
}
370374
}
371-
track_properties[counter_track_props] = trk.tpcInnerParam();
375+
376+
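if ((in_batch_counter == track_prop_size) || (total_input_count == total_eval_size)) { // If the batch is full (or all inputs have been filled): evaluate the network, copy its outputs into network_prediction, then reset the batch counters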
377+
int32_t fill_shift = (exec_counter * track_prop_size - ((total_input_count == total_eval_size) ? (total_input_count % track_prop_size) : 0)) * output_dimensions;
378+
auto start_network_eval = std::chrono::high_resolution_clock::now();
379+
float* output_network = network.evalModel(track_properties);
380+
auto stop_network_eval = std::chrono::high_resolution_clock::now();
381+
duration_network += std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_eval - start_network_eval).count();
382+
383+
for (uint64_t i = 0; i < (in_batch_counter * output_dimensions); i += output_dimensions) {
384+
for (int j = 0; j < output_dimensions; j++) {
385+
network_prediction[i + j + fill_shift] = output_network[i + j];
386+
}
387+
}
388+
counter_track_props = 0;
389+
in_batch_counter = 0;
390+
exec_counter++;
391+
}
392+
393+
// LOG(info) << "counter_tracks_props: " << counter_track_props << "; in_batch_counter: " << in_batch_counter << "; total_input_count: " << total_input_count << "; exec_counter: " << exec_counter << "; track_prop_size: " << track_prop_size << "; size: " << size << "; track_properties.size(): " << track_properties.size();
394+
track_properties[counter_track_props] = trk.tpcInnerParam(); // (tracks.asArrowTable()->GetColumn<float>("tpcInnerParam")).GetData();
372395
track_properties[counter_track_props + 1] = trk.tgl();
373396
track_properties[counter_track_props + 2] = trk.signed1Pt();
374-
track_properties[counter_track_props + 3] = o2::track::pid_constants::sMasses[i];
397+
track_properties[counter_track_props + 3] = o2::track::pid_constants::sMasses[species];
375398
track_properties[counter_track_props + 4] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).multTPC() / 11000. : 1.;
376399
track_properties[counter_track_props + 5] = std::sqrt(nNclNormalization / trk.tpcNClsFound());
377400
if (input_dimensions == 7 && networkVersion == "2") {
378401
track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.;
379402
}
380403
counter_track_props += input_dimensions;
404+
in_batch_counter++;
405+
total_input_count++;
381406
}
382-
383-
auto start_network_eval = std::chrono::high_resolution_clock::now();
384-
float* output_network = network.evalModel(track_properties);
385-
auto stop_network_eval = std::chrono::high_resolution_clock::now();
386-
duration_network += std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_eval - start_network_eval).count();
387-
for (uint64_t i = 0; i < prediction_size; i += output_dimensions) {
388-
for (int j = 0; j < output_dimensions; j++) {
389-
network_prediction[i + j + prediction_size * loop_counter] = output_network[i + j];
390-
}
391-
}
392-
393-
counter_track_props = 0;
394-
loop_counter += 1;
395407
}
396408
track_properties.clear();
397409

0 commit comments

Comments
 (0)