Skip to content

Commit 6cecb46

Browse files
committed
Merge pull request #53 from mpuccio/its/trk/stag
Fix compilation of ALICE3 tracking with staggering
2 parents 8de4135 + 36eba63 commit 6cecb46

File tree

3 files changed

+101
-91
lines changed

3 files changed

+101
-91
lines changed

Detectors/Upgrades/ALICE3/TRK/reconstruction/include/TRKReconstruction/TimeFrame.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ class GeometryTGeo;
3838

3939
/// TRK TimeFrame class that extends ITS TimeFrame functionality
4040
/// This allows for customization of tracking algorithms specific to the TRK detector
41-
template <int nLayers = 11>
42-
class TimeFrame : public o2::its::TimeFrame<nLayers>
41+
template <int NLayers = 11>
42+
class TimeFrame : public o2::its::TimeFrame<NLayers>
4343
{
4444
public:
4545
TimeFrame() = default;
@@ -50,8 +50,6 @@ class TimeFrame : public o2::its::TimeFrame<nLayers>
5050

5151
/// Process hits from TTree to initialize ROFs
5252
/// \param hitsTree Tree containing TRK hits
53-
/// \param mcHeaderTree Tree containing MC event headers
54-
/// \param nEvents Number of events to process
5553
/// \param gman TRK geometry manager instance
5654
/// \param config Configuration parameters for hit reconstruction
5755
int loadROFsFromHitTree(TTree* hitsTree, GeometryTGeo* gman, const nlohmann::json& config);
@@ -61,7 +59,8 @@ class TimeFrame : public o2::its::TimeFrame<nLayers>
6159
/// \param nRofs Number of ROFs (Read-Out Frames)
6260
/// \param nEvents Number of events to process
6361
/// \param inROFpileup Number of events per ROF
64-
void getPrimaryVerticesFromMC(TTree* mcHeaderTree, int nRofs, Long64_t nEvents, int inROFpileup);
62+
/// \param rofLength ROF length in BCs (must match what was used in loadROFsFromHitTree)
63+
void getPrimaryVerticesFromMC(TTree* mcHeaderTree, int nRofs, Long64_t nEvents, int inROFpileup, uint32_t rofLength = 198);
6564
};
6665

6766
} // namespace trk

Detectors/Upgrades/ALICE3/TRK/reconstruction/src/TimeFrame.cxx

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
#include <vector>
2424
#include <array>
2525

26+
using o2::its::clearResizeBoundedVector;
27+
2628
namespace o2::trk
2729
{
2830

29-
template <int nLayers>
30-
int TimeFrame<nLayers>::loadROFsFromHitTree(TTree* hitsTree, GeometryTGeo* gman, const nlohmann::json& config)
31+
template <int NLayers>
32+
int TimeFrame<NLayers>::loadROFsFromHitTree(TTree* hitsTree, GeometryTGeo* gman, const nlohmann::json& config)
3133
{
3234
constexpr std::array<int, 2> startLayer{0, 3};
3335
const Long64_t nEvents = hitsTree->GetEntries();
@@ -39,23 +41,39 @@ int TimeFrame<nLayers>::loadROFsFromHitTree(TTree* hitsTree, GeometryTGeo* gman,
3941

4042
const int inROFpileup{config.contains("inROFpileup") ? config["inROFpileup"].get<int>() : 1};
4143

42-
// Calculate number of ROFs and initialize data structures
43-
this->mNrof = (nEvents + inROFpileup - 1) / inROFpileup;
44+
// Calculate number of ROFs
45+
const int nRofs = (nEvents + inROFpileup - 1) / inROFpileup;
46+
47+
// Set up ROF timing for all layers (no staggering in TRK simulation, all layers read out together)
48+
constexpr uint32_t rofLength = 198; // ROF length in BC
49+
o2::its::ROFOverlapTable<NLayers> overlapTable;
50+
for (int iLayer = 0; iLayer < NLayers; ++iLayer) {
51+
overlapTable.defineLayer(iLayer, nRofs, rofLength, 0, 0, 0);
52+
}
53+
overlapTable.init();
54+
this->setROFOverlapTable(overlapTable);
55+
56+
// Set up the vertex lookup table timing (pre-allocate, vertices will be filled later)
57+
o2::its::ROFVertexLookupTable<NLayers> vtxLookupTable;
58+
for (int iLayer = 0; iLayer < NLayers; ++iLayer) {
59+
vtxLookupTable.defineLayer(iLayer, nRofs, rofLength, 0, 0, 0);
60+
}
61+
vtxLookupTable.init(); // pre-allocate without vertices
62+
this->setROFVertexLookupTable(vtxLookupTable);
4463

4564
// Reset and prepare ROF data structures
46-
for (int iLayer{0}; iLayer < nLayers; ++iLayer) {
65+
for (int iLayer{0}; iLayer < NLayers; ++iLayer) {
4766
this->mMinR[iLayer] = std::numeric_limits<float>::max();
4867
this->mMaxR[iLayer] = std::numeric_limits<float>::lowest();
4968
this->mROFramesClusters[iLayer].clear();
50-
this->mROFramesClusters[iLayer].resize(this->mNrof + 1, 0);
69+
this->mROFramesClusters[iLayer].resize(nRofs + 1, 0);
5170
this->mUnsortedClusters[iLayer].clear();
5271
this->mTrackingFrameInfo[iLayer].clear();
5372
this->mClusterExternalIndices[iLayer].clear();
5473
}
5574

5675
// Pre-count hits to reserve memory efficiently
57-
int totalNHits{0};
58-
std::array<int, nLayers> clusterCountPerLayer{};
76+
std::array<int, NLayers> clusterCountPerLayer{};
5977
for (Long64_t iEvent = 0; iEvent < nEvents; ++iEvent) {
6078
hitsTree->GetEntry(iEvent);
6179
for (const auto& hit : *trkHit) {
@@ -64,35 +82,35 @@ int TimeFrame<nLayers>::loadROFsFromHitTree(TTree* hitsTree, GeometryTGeo* gman,
6482
}
6583
int subDetID = gman->getSubDetID(hit.GetDetectorID());
6684
const int layer = startLayer[subDetID] + gman->getLayer(hit.GetDetectorID());
67-
if (layer >= nLayers) {
85+
if (layer >= NLayers) {
6886
continue;
6987
}
7088
++clusterCountPerLayer[layer];
71-
totalNHits++;
7289
}
7390
}
7491

75-
// Reserve memory for all layers
76-
for (int iLayer{0}; iLayer < nLayers; ++iLayer) {
92+
// Reserve memory for all layers (mClusterSize is now per-layer)
93+
for (int iLayer{0}; iLayer < NLayers; ++iLayer) {
7794
this->mUnsortedClusters[iLayer].reserve(clusterCountPerLayer[iLayer]);
7895
this->mTrackingFrameInfo[iLayer].reserve(clusterCountPerLayer[iLayer]);
7996
this->mClusterExternalIndices[iLayer].reserve(clusterCountPerLayer[iLayer]);
97+
clearResizeBoundedVector(this->mClusterSize[iLayer], clusterCountPerLayer[iLayer], this->mMemoryPool.get());
8098
}
81-
clearResizeBoundedVector(this->mClusterSize, totalNHits, this->mMemoryPool.get());
8299

83100
std::array<float, 11> resolution{0.001, 0.001, 0.001, 0.001, 0.004, 0.004, 0.004, 0.004, 0.004, 0.004, 0.004};
84-
if (config["geometry"]["pitch"].size() == nLayers) {
85-
for (int iLayer{0}; iLayer < config["geometry"]["pitch"].size(); ++iLayer) {
101+
if (config["geometry"]["pitch"].size() == static_cast<size_t>(NLayers)) {
102+
for (size_t iLayer{0}; iLayer < config["geometry"]["pitch"].size(); ++iLayer) {
86103
LOGP(info, "Setting resolution for layer {} from config", iLayer);
87104
LOGP(info, "Layer {} pitch {} cm", iLayer, config["geometry"]["pitch"][iLayer].get<float>());
88105
resolution[iLayer] = config["geometry"]["pitch"][iLayer].get<float>() / std::sqrt(12.f);
89106
}
90107
}
91108
LOGP(info, "Number of active parts in VD: {}", gman->getNumberOfActivePartsVD());
92109

93-
int hitCounter{0};
94-
auto labels = new dataformats::MCTruthContainer<MCCompLabel>();
110+
// One shared MC label container for all layers
111+
auto* labels = new dataformats::MCTruthContainer<MCCompLabel>();
95112

113+
int hitCounter{0};
96114
int iRof{0}; // Current ROF index
97115
for (Long64_t iEvent = 0; iEvent < nEvents; ++iEvent) {
98116
hitsTree->GetEntry(iEvent);
@@ -108,7 +126,7 @@ int TimeFrame<nLayers>::loadROFsFromHitTree(TTree* hitsTree, GeometryTGeo* gman,
108126
o2::math_utils::Point3D<float> gloXYZ;
109127
o2::math_utils::Point3D<float> trkXYZ;
110128
float r{0.f};
111-
if (layer >= nLayers) {
129+
if (layer >= NLayers) {
112130
continue;
113131
}
114132
if (layer >= 3) {
@@ -139,11 +157,12 @@ int TimeFrame<nLayers>::loadROFsFromHitTree(TTree* hitsTree, GeometryTGeo* gman,
139157
std::array<float, 2>{trkXYZ.y(), trkXYZ.z()},
140158
std::array<float, 3>{resolution[layer] * resolution[layer], 0., resolution[layer] * resolution[layer]});
141159
/// Rotate to the global frame
142-
this->addClusterToLayer(layer, gloXYZ.x(), gloXYZ.y(), gloXYZ.z(), this->mUnsortedClusters[layer].size());
160+
const int clusterIdxInLayer = this->mUnsortedClusters[layer].size();
161+
this->addClusterToLayer(layer, gloXYZ.x(), gloXYZ.y(), gloXYZ.z(), clusterIdxInLayer);
143162
this->addClusterExternalIndexToLayer(layer, hitCounter);
144163
MCCompLabel label{hit.GetTrackID(), static_cast<int>(iEvent), 0};
145164
labels->addElement(hitCounter, label);
146-
this->mClusterSize[hitCounter] = 1; // For compatibility with cluster-based tracking, set cluster size to 1 for hits
165+
this->mClusterSize[layer][clusterIdxInLayer] = 1;
147166
hitCounter++;
148167
}
149168
trkHit->clear();
@@ -154,21 +173,23 @@ int TimeFrame<nLayers>::loadROFsFromHitTree(TTree* hitsTree, GeometryTGeo* gman,
154173
for (unsigned int iLayer{0}; iLayer < this->mUnsortedClusters.size(); ++iLayer) {
155174
this->mROFramesClusters[iLayer][iRof] = this->mUnsortedClusters[iLayer].size(); // effectively calculating an exclusive sum
156175
}
157-
// Update primary vertices ROF structure
158176
}
159-
this->mClusterLabels = labels;
160177
}
161-
return this->mNrof;
178+
179+
// Set the shared labels container for all layers
180+
for (int iLayer = 0; iLayer < NLayers; ++iLayer) {
181+
this->mClusterLabels[iLayer] = labels;
182+
}
183+
184+
return nRofs;
162185
}
163186

164-
template <int nLayers>
165-
void TimeFrame<nLayers>::getPrimaryVerticesFromMC(TTree* mcHeaderTree, int nRofs, Long64_t nEvents, int inROFpileup)
187+
template <int NLayers>
188+
void TimeFrame<NLayers>::getPrimaryVerticesFromMC(TTree* mcHeaderTree, int nRofs, Long64_t nEvents, int inROFpileup, uint32_t rofLength)
166189
{
167190
auto mcheader = new o2::dataformats::MCEventHeader;
168191
mcHeaderTree->SetBranchAddress("MCEventHeader.", &mcheader);
169192

170-
this->mROFramesPV.clear();
171-
this->mROFramesPV.resize(nRofs + 1, 0);
172193
this->mPrimaryVertices.clear();
173194

174195
int iRof{0};
@@ -178,14 +199,24 @@ void TimeFrame<nLayers>::getPrimaryVerticesFromMC(TTree* mcHeaderTree, int nRofs
178199
vertex.setXYZ(mcheader->GetX(), mcheader->GetY(), mcheader->GetZ());
179200
vertex.setNContributors(30);
180201
vertex.setChi2(0.f);
181-
LOGP(debug, "ROF {}: Added primary vertex at ({}, {}, {})", iRof, mcheader->GetX(), mcheader->GetY(), mcheader->GetZ());
182-
this->mPrimaryVertices.push_back(vertex);
202+
203+
// Set proper BC timestamp for vertex-ROF compatibility
204+
// The vertex timestamp is set to the center of its ROF with half-ROF as error
205+
const uint32_t rofCenter = static_cast<uint32_t>(rofLength * iRof + rofLength / 2);
206+
const uint16_t rofHalf = static_cast<uint16_t>(rofLength / 2);
207+
vertex.setTimeStamp({rofCenter, rofHalf});
208+
209+
LOGP(debug, "ROF {}: Added primary vertex at ({}, {}, {}) with BC timestamp [{}, +/-{}]",
210+
iRof, mcheader->GetX(), mcheader->GetY(), mcheader->GetZ(), rofCenter, rofHalf);
211+
this->addPrimaryVertex(vertex);
183212
if ((iEvent + 1) % inROFpileup == 0 || iEvent == nEvents - 1) {
184213
iRof++;
185-
this->mROFramesPV[iRof] = this->mPrimaryVertices.size(); // effectively calculating an exclusive sum
186214
}
187215
}
188216
this->mMultiplicityCutMask.resize(nRofs, true); /// all ROFs are valid with MC primary vertices.
217+
218+
// Update the vertex lookup table with the newly added vertices
219+
this->updateROFVertexLookupTable();
189220
}
190221

191222
// Explicit template instantiation for TRK with 11 layers

Detectors/Upgrades/ALICE3/TRK/workflow/src/TrackerSpec.cxx

Lines changed: 36 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ namespace o2
3737
using namespace framework;
3838
namespace trk
3939
{
40-
using Vertex = o2::dataformats::Vertex<o2::dataformats::TimeStamp<int>>;
4140

4241
TrackerDPL::TrackerDPL(std::shared_ptr<o2::base::GRPGeomRequest> gr,
4342
bool isMC,
@@ -84,18 +83,12 @@ std::vector<o2::its::TrackingParameters> TrackerDPL::createTrackingParamsFromCon
8483
if (paramConfig.contains("NLayers")) {
8584
params.NLayers = paramConfig["NLayers"].get<int>();
8685
}
87-
if (paramConfig.contains("DeltaROF")) {
88-
params.DeltaROF = paramConfig["DeltaROF"].get<int>();
89-
}
9086
if (paramConfig.contains("ZBins")) {
9187
params.ZBins = paramConfig["ZBins"].get<int>();
9288
}
9389
if (paramConfig.contains("PhiBins")) {
9490
params.PhiBins = paramConfig["PhiBins"].get<int>();
9591
}
96-
if (paramConfig.contains("nROFsPerIterations")) {
97-
params.nROFsPerIterations = paramConfig["nROFsPerIterations"].get<int>();
98-
}
9992
if (paramConfig.contains("ClusterSharing")) {
10093
params.ClusterSharing = paramConfig["ClusterSharing"].get<int>();
10194
}
@@ -119,27 +112,21 @@ std::vector<o2::its::TrackingParameters> TrackerDPL::createTrackingParamsFromCon
119112
if (paramConfig.contains("TrackletMinPt")) {
120113
params.TrackletMinPt = paramConfig["TrackletMinPt"].get<float>();
121114
}
122-
if (paramConfig.contains("TrackletsPerClusterLimit")) {
123-
params.TrackletsPerClusterLimit = paramConfig["TrackletsPerClusterLimit"].get<float>();
124-
}
125115
if (paramConfig.contains("CellDeltaTanLambdaSigma")) {
126116
params.CellDeltaTanLambdaSigma = paramConfig["CellDeltaTanLambdaSigma"].get<float>();
127117
}
128-
if (paramConfig.contains("CellsPerClusterLimit")) {
129-
params.CellsPerClusterLimit = paramConfig["CellsPerClusterLimit"].get<float>();
130-
}
131118
if (paramConfig.contains("MaxChi2ClusterAttachment")) {
132119
params.MaxChi2ClusterAttachment = paramConfig["MaxChi2ClusterAttachment"].get<float>();
133120
}
134121
if (paramConfig.contains("MaxChi2NDF")) {
135122
params.MaxChi2NDF = paramConfig["MaxChi2NDF"].get<float>();
136123
}
137-
if (paramConfig.contains("TrackFollowerNSigmaCutZ")) {
138-
params.TrackFollowerNSigmaCutZ = paramConfig["TrackFollowerNSigmaCutZ"].get<float>();
139-
}
140-
if (paramConfig.contains("TrackFollowerNSigmaCutPhi")) {
141-
params.TrackFollowerNSigmaCutPhi = paramConfig["TrackFollowerNSigmaCutPhi"].get<float>();
142-
}
124+
// if (paramConfig.contains("TrackFollowerNSigmaCutZ")) {
125+
// params.TrackFollowerNSigmaCutZ = paramConfig["TrackFollowerNSigmaCutZ"].get<float>();
126+
// }
127+
// if (paramConfig.contains("TrackFollowerNSigmaCutPhi")) {
128+
// params.TrackFollowerNSigmaCutPhi = paramConfig["TrackFollowerNSigmaCutPhi"].get<float>();
129+
// }
143130

144131
// Parse boolean parameters
145132
if (paramConfig.contains("UseDiamond")) {
@@ -154,9 +141,9 @@ std::vector<o2::its::TrackingParameters> TrackerDPL::createTrackingParamsFromCon
154141
if (paramConfig.contains("ShiftRefToCluster")) {
155142
params.ShiftRefToCluster = paramConfig["ShiftRefToCluster"].get<bool>();
156143
}
157-
if (paramConfig.contains("FindShortTracks")) {
158-
params.FindShortTracks = paramConfig["FindShortTracks"].get<bool>();
159-
}
144+
// if (paramConfig.contains("FindShortTracks")) {
145+
// params.FindShortTracks = paramConfig["FindShortTracks"].get<bool>();
146+
// }
160147
if (paramConfig.contains("PerPrimaryVertexProcessing")) {
161148
params.PerPrimaryVertexProcessing = paramConfig["PerPrimaryVertexProcessing"].get<bool>();
162149
}
@@ -169,18 +156,18 @@ std::vector<o2::its::TrackingParameters> TrackerDPL::createTrackingParamsFromCon
169156
if (paramConfig.contains("FataliseUponFailure")) {
170157
params.FataliseUponFailure = paramConfig["FataliseUponFailure"].get<bool>();
171158
}
172-
if (paramConfig.contains("UseTrackFollower")) {
173-
params.UseTrackFollower = paramConfig["UseTrackFollower"].get<bool>();
174-
}
175-
if (paramConfig.contains("UseTrackFollowerTop")) {
176-
params.UseTrackFollowerTop = paramConfig["UseTrackFollowerTop"].get<bool>();
177-
}
178-
if (paramConfig.contains("UseTrackFollowerBot")) {
179-
params.UseTrackFollowerBot = paramConfig["UseTrackFollowerBot"].get<bool>();
180-
}
181-
if (paramConfig.contains("UseTrackFollowerMix")) {
182-
params.UseTrackFollowerMix = paramConfig["UseTrackFollowerMix"].get<bool>();
183-
}
159+
// if (paramConfig.contains("UseTrackFollower")) {
160+
// params.UseTrackFollower = paramConfig["UseTrackFollower"].get<bool>();
161+
// }
162+
// if (paramConfig.contains("UseTrackFollowerTop")) {
163+
// params.UseTrackFollowerTop = paramConfig["UseTrackFollowerTop"].get<bool>();
164+
// }
165+
// if (paramConfig.contains("UseTrackFollowerBot")) {
166+
// params.UseTrackFollowerBot = paramConfig["UseTrackFollowerBot"].get<bool>();
167+
// }
168+
// if (paramConfig.contains("UseTrackFollowerMix")) {
169+
// params.UseTrackFollowerMix = paramConfig["UseTrackFollowerMix"].get<bool>();
170+
// }
184171
if (paramConfig.contains("createArtefactLabels")) {
185172
params.createArtefactLabels = paramConfig["createArtefactLabels"].get<bool>();
186173
}
@@ -297,44 +284,37 @@ void TrackerDPL::run(ProcessingContext& pc)
297284
for (size_t iter{0}; iter < trackingParams.size(); ++iter) {
298285
LOGP(info, "{}", trackingParams[iter].asString());
299286
timeFrame.initialise(iter, trackingParams[iter], 11, false);
300-
itsTrackerTraits.computeLayerTracklets(iter, -1, -1);
287+
itsTrackerTraits.computeLayerTracklets(iter, -1);
301288
LOGP(info, "Number of tracklets in iteration {}: {}", iter, timeFrame.getNumberOfTracklets());
302289
itsTrackerTraits.computeLayerCells(iter);
303290
LOGP(info, "Number of cells in iteration {}: {}", iter, timeFrame.getNumberOfCells());
304291
itsTrackerTraits.findCellsNeighbours(iter);
305292
LOGP(info, "Number of cell neighbours in iteration {}: {}", iter, timeFrame.getNumberOfNeighbours());
306293
itsTrackerTraits.findRoads(iter);
307-
LOGP(info, "Number of roads in iteration {}: {}", iter, timeFrame.getNumberOfTracks());
308-
itsTrackerTraits.extendTracks(iter);
294+
LOGP(info, "Number of tracks in iteration {}: {}", iter, timeFrame.getNumberOfTracks());
309295
}
310296
const auto trackingLoopElapsedMs = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - trackingLoopStart).count();
311297
LOGP(info, "Tracking iterations block took {} ms", trackingLoopElapsedMs);
312298

313299
itsTracker.computeTracksMClabels();
314300

315-
// Stream tracks and their MC labels to the output
316-
// Collect all tracks and labels from all ROFs
317-
std::vector<o2::its::TrackITS> allTracks;
318-
std::vector<o2::MCCompLabel> allLabels;
301+
// Collect tracks and labels (flat vectors in the new interface)
302+
const auto& tracks = timeFrame.getTracks();
303+
const auto& labels = timeFrame.getTracksLabel();
319304

320-
int totalTracks = 0;
305+
// Copy to output vectors (TrackITSExt -> TrackITS slicing for output compatibility)
306+
std::vector<o2::its::TrackITS> allTracks(tracks.begin(), tracks.end());
307+
std::vector<o2::MCCompLabel> allLabels(labels.begin(), labels.end());
308+
309+
int totalTracks = allTracks.size();
321310
int goodTracks = 0;
322311
int fakeTracks = 0;
323312

324-
for (int iRof = 0; iRof < nRofs; ++iRof) {
325-
const auto& rofTracks = timeFrame.getTracks(iRof);
326-
const auto& rofLabels = timeFrame.getTracksLabel(iRof);
327-
328-
allTracks.insert(allTracks.end(), rofTracks.begin(), rofTracks.end());
329-
allLabels.insert(allLabels.end(), rofLabels.begin(), rofLabels.end());
330-
331-
totalTracks += rofTracks.size();
332-
for (const auto& label : rofLabels) {
333-
if (label.isFake()) {
334-
fakeTracks++;
335-
} else {
336-
goodTracks++;
337-
}
313+
for (const auto& label : allLabels) {
314+
if (label.isFake()) {
315+
fakeTracks++;
316+
} else {
317+
goodTracks++;
338318
}
339319
}
340320

0 commit comments

Comments
 (0)