Skip to content

Commit 4f83bca

Browse files
authored
Switch from vector to unordered_map for MC particles enumeration (#5799)
* Switch from vector to unordered_map for MC particles enumeration * Fix for MC particles marking and clean-up for typenames.
1 parent 814cb33 commit 4f83bca

File tree

2 files changed

+115
-73
lines changed

2 files changed

+115
-73
lines changed

Detectors/AOD/include/AODProducerWorkflow/AODProducerWorkflowSpec.h

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@
2323
#include "CCDB/BasicCCDBManager.h"
2424
#include "Steer/MCKinematicsReader.h"
2525
#include "SimulationDataFormat/MCCompLabel.h"
26+
2627
#include <string>
2728
#include <vector>
29+
#include <boost/tuple/tuple.hpp>
30+
#include <boost/unordered_map.hpp>
31+
#include <boost/functional/hash.hpp>
2832

2933
using namespace o2::framework;
3034

@@ -95,6 +99,30 @@ using MCParticlesTable = o2::soa::Table<o2::aod::mcparticle::McCollisionId,
9599
o2::aod::mcparticle::Vz,
96100
o2::aod::mcparticle::Vt>;
97101

102+
typedef boost::tuple<int, int, int> Triplet_t;
103+
104+
struct TripletHash : std::unary_function<Triplet_t, std::size_t> {
105+
std::size_t operator()(Triplet_t const& e) const
106+
{
107+
std::size_t seed = 0;
108+
boost::hash_combine(seed, e.get<0>());
109+
boost::hash_combine(seed, e.get<1>());
110+
boost::hash_combine(seed, e.get<2>());
111+
return seed;
112+
}
113+
};
114+
115+
struct TripletEqualTo : std::binary_function<Triplet_t, Triplet_t, bool> {
116+
bool operator()(Triplet_t const& x, Triplet_t const& y) const
117+
{
118+
return (x.get<0>() == y.get<0>() &&
119+
x.get<1>() == y.get<1>() &&
120+
x.get<2>() == y.get<2>());
121+
}
122+
};
123+
124+
typedef boost::unordered_map<Triplet_t, int, TripletHash, TripletEqualTo> TripletsMap_t;
125+
98126
class AODProducerWorkflowDPL : public Task
99127
{
100128
public:
@@ -161,7 +189,7 @@ class AODProducerWorkflowDPL : public Task
161189
template <typename MCParticlesCursorType>
162190
void fillMCParticlesTable(o2::steer::MCKinematicsReader& mcReader, const MCParticlesCursorType& mcParticlesCursor,
163191
gsl::span<const o2::MCCompLabel>& mcTruthITS, gsl::span<const o2::MCCompLabel>& mcTruthTPC,
164-
std::vector<std::vector<std::vector<int>>>& toStore);
192+
TripletsMap_t& toStore);
165193

166194
void writeTableToFile(TFile* outfile, std::shared_ptr<arrow::Table>& table, const std::string& tableName, uint64_t tfNumber);
167195
};

Detectors/AOD/src/AODProducerWorkflowSpec.cxx

Lines changed: 86 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include <CCDB/BasicCCDBManager.h>
1818
#include "CommonDataFormat/InteractionRecord.h"
1919
#include "Framework/AnalysisDataModel.h"
20-
#include "Framework/AnalysisHelpers.h"
2120
#include "Framework/ConfigParamRegistry.h"
2221
#include "Framework/DataTypes.h"
2322
#include "Framework/InputRecordWalker.h"
@@ -214,101 +213,124 @@ void AODProducerWorkflowDPL::fillTracksTable(const TTracks& tracks, std::vector<
214213
template <typename MCParticlesCursorType>
215214
void AODProducerWorkflowDPL::fillMCParticlesTable(o2::steer::MCKinematicsReader& mcReader, const MCParticlesCursorType& mcParticlesCursor,
216215
gsl::span<const o2::MCCompLabel>& mcTruthITS, gsl::span<const o2::MCCompLabel>& mcTruthTPC,
217-
std::vector<std::vector<std::vector<int>>>& toStore)
216+
TripletsMap_t& toStore)
218217
{
219-
// mark reconstructed MC tracks to store them into the table
220-
if (mFillTracksITS) {
221-
for (auto& mcTruth : mcTruthITS) {
222-
if (!mcTruth.isValid()) {
223-
continue;
224-
}
225-
int source = mcTruth.getSourceID();
226-
int event = mcTruth.getEventID();
227-
int track = mcTruth.getTrackID();
228-
toStore[source][event][track] = 1;
218+
// mark reconstructed MC particles to store them into the table
219+
for (auto& mcTruth : mcTruthITS) {
220+
if (!mcTruth.isValid()) {
221+
continue;
229222
}
223+
int source = mcTruth.getSourceID();
224+
int event = mcTruth.getEventID();
225+
int particle = mcTruth.getTrackID();
226+
toStore[Triplet_t(source, event, particle)] = 1;
230227
}
231-
if (mFillTracksTPC) {
232-
for (auto& mcTruth : mcTruthTPC) {
233-
if (!mcTruth.isValid()) {
234-
continue;
235-
}
236-
int source = mcTruth.getSourceID();
237-
int event = mcTruth.getEventID();
238-
int track = mcTruth.getTrackID();
239-
toStore[source][event][track] = 1;
228+
for (auto& mcTruth : mcTruthTPC) {
229+
if (!mcTruth.isValid()) {
230+
continue;
240231
}
232+
int source = mcTruth.getSourceID();
233+
int event = mcTruth.getEventID();
234+
int particle = mcTruth.getTrackID();
235+
toStore[Triplet_t(source, event, particle)] = 1;
241236
}
242237
int tableIndex = 1;
243238
for (int source = 0; source < mcReader.getNSources(); source++) {
244239
for (int event = 0; event < mcReader.getNEvents(source); event++) {
245240
std::vector<MCTrack> const& mcParticles = mcReader.getTracks(source, event);
246241
// mark tracks to be stored per event
247-
// loop over stack of MC tracks from end to beginning: daughters are stored after mothers
242+
// loop over stack of MC particles from end to beginning: daughters are stored after mothers
248243
if (mRecoOnly) {
249-
for (int track = mcParticles.size() - 1; track <= 0; track--) {
250-
int mother0 = mcParticles[track].getMotherTrackId();
251-
int mother1 = mcParticles[track].getSecondMotherTrackId();
252-
if (mother0 == -1 || mother1 == -1) {
253-
toStore[source][event][track] = 1;
244+
for (int particle = mcParticles.size() - 1; particle >= 0; particle--) {
245+
int mother0 = mcParticles[particle].getMotherTrackId();
246+
if (mother0 == -1) {
247+
toStore[Triplet_t(source, event, particle)] = 1;
254248
}
255-
if (toStore[source][event][track] == 0) {
249+
if (toStore.find(Triplet_t(source, event, particle)) == toStore.end()) {
256250
continue;
257251
}
258252
if (mother0 != -1) {
259-
toStore[source][event][mother0] = 1;
253+
toStore[Triplet_t(source, event, mother0)] = 1;
260254
}
255+
int mother1 = mcParticles[particle].getSecondMotherTrackId();
261256
if (mother1 != -1) {
262-
toStore[source][event][mother1] = 1;
257+
toStore[Triplet_t(source, particle, mother1)] = 1;
263258
}
264-
int daughter0 = mcParticles[track].getFirstDaughterTrackId();
265-
int daughterL = mcParticles[track].getLastDaughterTrackId();
259+
int daughter0 = mcParticles[particle].getFirstDaughterTrackId();
266260
if (daughter0 != -1) {
267-
toStore[source][event][daughter0] = 1;
261+
toStore[Triplet_t(source, event, daughter0)] = 1;
268262
}
263+
int daughterL = mcParticles[particle].getLastDaughterTrackId();
269264
if (daughterL != -1) {
270-
toStore[source][event][daughterL] = 1;
265+
toStore[Triplet_t(source, event, daughterL)] = 1;
266+
}
267+
}
268+
// enumerate reconstructed mc particles and their relatives to get mother/daughter relations
269+
for (int particle = 0; particle < mcParticles.size(); particle++) {
270+
auto mapItem = toStore.find(Triplet_t(source, event, particle));
271+
if (mapItem != toStore.end()) {
272+
mapItem->second = tableIndex;
273+
tableIndex++;
271274
}
272275
}
273276
}
274-
// enumerate tracks to get mother/daughter relations
275-
for (int track = 0; track < toStore[source][event].size(); track++) {
276-
if (!toStore[source][event][track] && mRecoOnly) {
277-
continue;
277+
// if all mc particles are stored, all mc particles will be enumerated
278+
if (!mRecoOnly) {
279+
for (int particle = 0; particle < mcParticles.size(); particle++) {
280+
toStore[Triplet_t(source, event, particle)] = tableIndex;
281+
tableIndex++;
278282
}
279-
toStore[source][event][track] = tableIndex;
280-
tableIndex++;
281283
}
282284
// fill survived mc tracks into the table
283-
for (int track = 0; track < mcParticles.size(); track++) {
284-
if (!toStore[source][event][track] && mRecoOnly) {
285+
for (int particle = 0; particle < mcParticles.size(); particle++) {
286+
if (toStore.find(Triplet_t(source, event, particle)) == toStore.end()) {
285287
continue;
286288
}
287289
int statusCode = 0;
288290
uint8_t flags = 0;
289291
float weight = 0.f;
290-
int mother0 = mcParticles[track].getMotherTrackId() != -1 ? toStore[source][event][mcParticles[track].getMotherTrackId()] - 1 : -1;
291-
int mother1 = mcParticles[track].getSecondMotherTrackId() != -1 ? toStore[source][event][mcParticles[track].getSecondMotherTrackId()] - 1 : -1;
292-
int daughter0 = mcParticles[track].getFirstDaughterTrackId() != -1 ? toStore[source][event][mcParticles[track].getFirstDaughterTrackId()] - 1 : -1;
293-
int daughterL = mcParticles[track].getLastDaughterTrackId() != -1 ? toStore[source][event][mcParticles[track].getLastDaughterTrackId()] - 1 : -1;
292+
int mcMother0 = mcParticles[particle].getMotherTrackId();
293+
auto item = toStore.find(Triplet_t(source, event, mcMother0));
294+
int mother0 = -1;
295+
if (item != toStore.end()) {
296+
mother0 = item->second;
297+
}
298+
int mcMother1 = mcParticles[particle].getSecondMotherTrackId();
299+
int mother1 = -1;
300+
item = toStore.find(Triplet_t(source, event, mcMother1));
301+
if (item != toStore.end()) {
302+
mother1 = item->second;
303+
}
304+
int mcDaughter0 = mcParticles[particle].getFirstDaughterTrackId();
305+
int daughter0 = -1;
306+
item = toStore.find(Triplet_t(source, event, mcDaughter0));
307+
if (item != toStore.end()) {
308+
daughter0 = item->second;
309+
}
310+
int mcDaughterL = mcParticles[particle].getLastDaughterTrackId();
311+
int daughterL = -1;
312+
item = toStore.find(Triplet_t(source, event, mcDaughterL));
313+
if (item != toStore.end()) {
314+
daughterL = item->second;
315+
}
294316
mcParticlesCursor(0,
295317
event,
296-
mcParticles[track].GetPdgCode(),
318+
mcParticles[particle].GetPdgCode(),
297319
statusCode,
298320
flags,
299321
mother0,
300322
mother1,
301323
daughter0,
302324
daughterL,
303325
truncateFloatFraction(weight, mMcParticleW),
304-
truncateFloatFraction((float)mcParticles[track].Px(), mMcParticleMom),
305-
truncateFloatFraction((float)mcParticles[track].Py(), mMcParticleMom),
306-
truncateFloatFraction((float)mcParticles[track].Pz(), mMcParticleMom),
307-
truncateFloatFraction((float)mcParticles[track].GetEnergy(), mMcParticleMom),
308-
truncateFloatFraction((float)mcParticles[track].Vx(), mMcParticlePos),
309-
truncateFloatFraction((float)mcParticles[track].Vy(), mMcParticlePos),
310-
truncateFloatFraction((float)mcParticles[track].Vz(), mMcParticlePos),
311-
truncateFloatFraction((float)mcParticles[track].T(), mMcParticlePos));
326+
truncateFloatFraction((float)mcParticles[particle].Px(), mMcParticleMom),
327+
truncateFloatFraction((float)mcParticles[particle].Py(), mMcParticleMom),
328+
truncateFloatFraction((float)mcParticles[particle].Pz(), mMcParticleMom),
329+
truncateFloatFraction((float)mcParticles[particle].GetEnergy(), mMcParticleMom),
330+
truncateFloatFraction((float)mcParticles[particle].Vx(), mMcParticlePos),
331+
truncateFloatFraction((float)mcParticles[particle].Vy(), mMcParticlePos),
332+
truncateFloatFraction((float)mcParticles[particle].Vz(), mMcParticlePos),
333+
truncateFloatFraction((float)mcParticles[particle].T(), mMcParticlePos));
312334
}
313335
mcReader.releaseTracksForSourceAndEvent(source, event);
314336
}
@@ -715,15 +737,15 @@ void AODProducerWorkflowDPL::run(ProcessingContext& pc)
715737
uint64_t triggerMask = 1;
716738
std::sort(BCIDs.begin(), BCIDs.end());
717739
uint64_t prevBCid = BCIDs.back();
718-
for (auto& BCid : BCIDs) {
719-
if (BCid == prevBCid) {
740+
for (auto& id : BCIDs) {
741+
if (id == prevBCid) {
720742
continue;
721743
}
722744
bcCursor(0,
723745
runNumber,
724-
startBCofTF + minGlBC + BCid,
746+
startBCofTF + minGlBC + id,
725747
triggerMask);
726-
prevBCid = BCid;
748+
prevBCid = id;
727749
}
728750

729751
BCIDs.clear();
@@ -735,15 +757,7 @@ void AODProducerWorkflowDPL::run(ProcessingContext& pc)
735757
}
736758

737759
// filling mc particles table
738-
std::vector<std::vector<std::vector<int>>> toStore;
739-
for (int source = 0; source < mcReader.getNSources(); source++) {
740-
std::vector<std::vector<int>> vEvents;
741-
toStore.push_back(vEvents);
742-
for (int event = 0; event < mcReader.getNEvents(source); event++) {
743-
std::vector<int> vTracks(mcReader.getTracks(source, event).size(), 0);
744-
toStore[source].push_back(vTracks);
745-
}
746-
}
760+
TripletsMap_t toStore;
747761
fillMCParticlesTable(mcReader, mcParticlesCursor, tracksITSMCTruth, tracksTPCMCTruth, toStore);
748762
if (mIgnoreWriter) {
749763
std::shared_ptr<arrow::Table> tableMCParticles = mcParticlesBuilder.finalize();
@@ -771,7 +785,7 @@ void AODProducerWorkflowDPL::run(ProcessingContext& pc)
771785
// TODO: fill label mask
772786
labelMask = 0;
773787
if (mcTruthITS.isValid()) {
774-
labelID = toStore[mcTruthITS.getSourceID()][mcTruthITS.getEventID()][mcTruthITS.getTrackID()];
788+
labelID = toStore.at(Triplet_t(mcTruthITS.getSourceID(), mcTruthITS.getEventID(), mcTruthITS.getTrackID()));
775789
}
776790
if (mcTruthITS.isFake()) {
777791
labelMask |= (0x1 << 15);
@@ -792,7 +806,7 @@ void AODProducerWorkflowDPL::run(ProcessingContext& pc)
792806
// TODO: fill label mask
793807
labelMask = 0;
794808
if (mcTruthTPC.isValid()) {
795-
labelID = toStore[mcTruthTPC.getSourceID()][mcTruthTPC.getEventID()][mcTruthTPC.getTrackID()];
809+
labelID = toStore.at(Triplet_t(mcTruthTPC.getSourceID(), mcTruthTPC.getEventID(), mcTruthTPC.getTrackID()));
796810
}
797811
if (mcTruthTPC.isFake()) {
798812
labelMask |= (0x1 << 15);
@@ -819,8 +833,8 @@ void AODProducerWorkflowDPL::run(ProcessingContext& pc)
819833
// currently using label mask to indicate labelITS != labelTPC
820834
labelMask = 0;
821835
if (mcTruthITS.isValid() && mcTruthTPC.isValid()) {
822-
labelITS = toStore[mcTruthITS.getSourceID()][mcTruthITS.getEventID()][mcTruthITS.getTrackID()];
823-
labelTPC = toStore[mcTruthTPC.getSourceID()][mcTruthTPC.getEventID()][mcTruthTPC.getTrackID()];
836+
labelITS = toStore.at(Triplet_t(mcTruthITS.getSourceID(), mcTruthITS.getEventID(), mcTruthITS.getTrackID()));
837+
labelTPC = toStore.at(Triplet_t(mcTruthTPC.getSourceID(), mcTruthTPC.getEventID(), mcTruthTPC.getTrackID()));
824838
labelID = labelITS;
825839
}
826840
if (mcTruthITS.isFake() || mcTruthTPC.isFake()) {

0 commit comments

Comments
 (0)