You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: Common/TableProducer/PID/pidTPC.cxx
+35-23Lines changed: 35 additions & 23 deletions
Original file line number
Diff line number
Diff line change
@@ -150,9 +150,11 @@ struct tpcPid {
150
150
Configurable<int> useNetworkHe{"useNetworkHe", 1, {"Switch for applying neural network on the helium3 mass hypothesis (if network enabled) (set to 0 to disable)"}};
151
151
Configurable<int> useNetworkAl{"useNetworkAl", 1, {"Switch for applying neural network on the alpha mass hypothesis (if network enabled) (set to 0 to disable)"}};
152
152
Configurable<float> networkBetaGammaCutoff{"networkBetaGammaCutoff", 0.45, {"Lower value of beta-gamma to override the NN application"}};
153
+
Configurable<float> networkInputBatchedMode{"networkInputBatchedMode", -1, {"-1: Takes all tracks, >0: Takes networkInputBatchedMode number of tracks at once"}};
153
154
154
155
// Parametrization configuration
155
156
bool useCCDBParam = false;
157
+
std::vector<float> track_properties;
156
158
157
159
voidinit(o2::framework::InitContext& initContext)
158
160
{
@@ -298,8 +300,6 @@ struct tpcPid {
298
300
std::vector<float> createNetworkPrediction(C const& collisions, T const& tracks, B const& bcs, constsize_t size)
299
301
{
300
302
301
-
std::vector<float> network_prediction;
302
-
303
303
auto start_network_total = std::chrono::high_resolution_clock::now();
304
304
if (autofetchNetworks) {
305
305
constauto& bc = bcs.begin();
@@ -345,20 +345,24 @@ struct tpcPid {
345
345
// Defining some network parameters
346
346
int input_dimensions = network.getNumInputNodes();
347
347
int output_dimensions = network.getNumOutputNodes();
track_prop_size = size; // If the networkInputBatchedMode is not set, we use all tracks at once
359
+
}
360
+
track_properties.resize(track_prop_size * input_dimensions); // If the networkInputBatchedMode is set, we use the number of tracks specified in the config
361
+
std::vector<float> network_prediction(prediction_size * numSpecies); // For each mass hypotheses
358
362
359
363
// Filling a std::vector<float> to be evaluated by the network
360
364
// Evaluation on single tracks brings huge overhead: Thus evaluation is done on one large vector
361
-
for (inti = 0; i < 9; i++) { // Loop over particle number for which network correction is used
365
+
for (intspecies = 0; species < numSpecies; species++) { // Loop over particle number for which network correction is used
0 commit comments