@@ -73,6 +73,9 @@ struct cascademlselection {
7373 o2::ml::OnnxModel mlModelOmegaMinus;
7474 o2::ml::OnnxModel mlModelOmegaPlus;
7575
76+ // Custom grouping
77+ std::vector<std::vector<int >> cascadesGrouped;
78+
7679 std::map<std::string, std::string> metadata;
7780
7881 Produces<aod::CascXiMLScores> xiMLSelections; // optionally aggregate information from ML output for posterior analysis (derived data)
@@ -261,30 +264,52 @@ struct cascademlselection {
261264 }
262265 }
263266
264- void processDerivedData (soa::Join<aod::StraCollisions, aod::StraStamps>::iterator const & collision , CascDerivedDatas const & cascades)
267+ void processDerivedData (soa::Join<aod::StraCollisions, aod::StraStamps> const & collisions , CascDerivedDatas const & cascades)
265268 {
266- initCCDB (collision);
269+ // Custom grouping
270+ cascadesGrouped.clear ();
271+ cascadesGrouped.resize (collisions.size ());
272+
273+ for (const auto & cascade : cascades) {
274+ cascadesGrouped[cascade.straCollisionId ()].push_back (cascade.globalIndex ());
275+ }
267276
268- histos.fill (HIST (" hEventVertexZ" ), collision.posZ ());
269- for (auto & casc : cascades) {
270- nCandidates++;
271- if (nCandidates % 50000 == 0 ) {
272- LOG (info) << " Candidates processed: " << nCandidates;
277+ for (const auto & collision : collisions) {
278+ initCCDB (collision);
279+
280+ histos.fill (HIST (" hEventVertexZ" ), collision.posZ ());
281+ for (std::size_t i = 0 ; i < cascadesGrouped[collision.globalIndex ()].size (); i++) {
282+ auto casc = cascades.rawIteratorAt (cascadesGrouped[collision.globalIndex ()][i]);
283+ nCandidates++;
284+ if (nCandidates % 50000 == 0 ) {
285+ LOG (info) << " Candidates processed: " << nCandidates;
286+ }
287+ processCandidate (casc);
273288 }
274- processCandidate (casc);
275289 }
276290 }
277- void processStandardData (aod::Collision const & collision , CascOriginalDatas const & cascades)
291+ void processStandardData (aod::Collisions const & collisions , CascOriginalDatas const & cascades)
278292 {
279- initCCDB (collision);
293+ // Custom grouping
294+ cascadesGrouped.clear ();
295+ cascadesGrouped.resize (collisions.size ());
280296
281- histos.fill (HIST (" hEventVertexZ" ), collision.posZ ());
282- for (auto & casc : cascades) {
283- nCandidates++;
284- if (nCandidates % 50000 == 0 ) {
285- LOG (info) << " Candidates processed: " << nCandidates;
297+ for (const auto & cascade : cascades) {
298+ cascadesGrouped[cascade.collisionId ()].push_back (cascade.globalIndex ());
299+ }
300+
301+ for (const auto & collision : collisions) {
302+ initCCDB (collision);
303+
304+ histos.fill (HIST (" hEventVertexZ" ), collision.posZ ());
305+ for (std::size_t i = 0 ; i < cascadesGrouped[collision.globalIndex ()].size (); i++) {
306+ auto casc = cascades.rawIteratorAt (cascadesGrouped[collision.globalIndex ()][i]);
307+ nCandidates++;
308+ if (nCandidates % 50000 == 0 ) {
309+ LOG (info) << " Candidates processed: " << nCandidates;
310+ }
311+ processCandidate (casc);
286312 }
287- processCandidate (casc);
288313 }
289314 }
290315
0 commit comments