Skip to content

Commit b942bba

Browse files
committed
ITS: allow sharing of arena in Tracker & Vertexer
Signed-off-by: Felix Schlepper <felix.schlepper@cern.ch>
1 parent b9b561d commit b942bba

File tree

9 files changed

+64
-39
lines changed

9 files changed

+64
-39
lines changed

Detectors/ITSMFT/ITS/tracking/include/ITStracking/Tracker.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
#include <utility>
2828
#include <sstream>
2929

30+
#include <oneapi/tbb/task_arena.h>
31+
3032
#include "ITStracking/Configuration.h"
3133
#include "CommonConstants/MathConstants.h"
3234
#include "ITStracking/Definitions.h"
@@ -73,8 +75,7 @@ class Tracker
7375
void setBz(float bz) { mTraits->setBz(bz); }
7476
void setCorrType(const o2::base::PropagatorImpl<float>::MatCorrType type) { mTraits->setCorrType(type); }
7577
bool isMatLUT() const { return mTraits->isMatLUT(); }
76-
void setNThreads(int n) { mTraits->setNThreads(n); }
77-
int getNThreads() const { return mTraits->getNThreads(); }
78+
void setNThreads(int n, std::shared_ptr<tbb::task_arena>& arena) { mTraits->setNThreads(n, arena); }
7879
void printSummary() const;
7980

8081
private:

Detectors/ITSMFT/ITS/tracking/include/ITStracking/TrackerTraits.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ class TrackerTraits
8080
void SetRecoChain(o2::gpu::GPUChainITS* chain) { mChain = chain; }
8181
void setSmoothing(bool v) { mApplySmoothing = v; }
8282
bool getSmoothing() const { return mApplySmoothing; }
83-
void setNThreads(int n);
84-
int getNThreads() const { return mNThreads; }
83+
void setNThreads(int n, std::shared_ptr<tbb::task_arena>& arena);
84+
int getNThreads() { return mTaskArena->max_concurrency(); }
8585

8686
o2::gpu::GPUChainITS* getChain() const { return mChain; }
8787

@@ -94,10 +94,9 @@ class TrackerTraits
9494
track::TrackParCov buildTrackSeed(const Cluster& cluster1, const Cluster& cluster2, const TrackingFrameInfo& tf3);
9595
bool fitTrack(TrackITSExt& track, int start, int end, int step, float chi2clcut = o2::constants::math::VeryBig, float chi2ndfcut = o2::constants::math::VeryBig, float maxQoverPt = o2::constants::math::VeryBig, int nCl = 0);
9696

97-
int mNThreads = 1;
9897
bool mApplySmoothing = false;
9998
std::shared_ptr<BoundedMemoryResource> mMemoryPool;
100-
tbb::task_arena mTaskArena;
99+
std::shared_ptr<tbb::task_arena> mTaskArena;
101100

102101
protected:
103102
o2::base::PropagatorImpl<float>::MatCorrType mCorrType = o2::base::PropagatorImpl<float>::MatCorrType::USEMatCorrNONE;

Detectors/ITSMFT/ITS/tracking/include/ITStracking/TrackingInterface.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include "GPUO2Interface.h"
2929
#include "GPUChainITS.h"
3030

31+
#include <oneapi/tbb/task_arena.h>
32+
3133
namespace o2::its
3234
{
3335
class ITSTrackingInterface
@@ -97,6 +99,7 @@ class ITSTrackingInterface
9799
std::unique_ptr<Vertexer> mVertexer = nullptr;
98100
const o2::dataformats::MeanVertexObject* mMeanVertex;
99101
std::shared_ptr<BoundedMemoryResource> mMemoryPool;
102+
std::shared_ptr<tbb::task_arena> mTaskArena;
100103
};
101104

102105
} // namespace o2::its

Detectors/ITSMFT/ITS/tracking/include/ITStracking/Vertexer.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
#include <iomanip>
2222
#include <array>
2323
#include <iosfwd>
24+
#include <memory>
25+
26+
#include <oneapi/tbb/task_arena.h>
2427

2528
#include "ITStracking/ROframe.h"
2629
#include "ITStracking/Constants.h"
@@ -90,6 +93,8 @@ class Vertexer
9093
const unsigned selectedN, const unsigned int vertexN, const float initT,
9194
const float trackletT, const float selecT, const float vertexT);
9295

96+
void setNThreads(int n, std::shared_ptr<tbb::task_arena>& arena) { mTraits->setNThreads(n, arena); }
97+
9398
private:
9499
std::uint32_t mTimeFrameCounter = 0;
95100

Detectors/ITSMFT/ITS/tracking/include/ITStracking/VertexerTraits.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#define O2_ITS_TRACKING_VERTEXER_TRAITS_H_
1818

1919
#include <array>
20+
#include <memory>
2021
#include <string>
2122
#include <vector>
2223

@@ -93,8 +94,8 @@ class VertexerTraits
9394
auto getVertexingParameters() const { return mVrtParams; }
9495
void setVertexingParameters(std::vector<VertexingParameters>& vertParams) { mVrtParams = vertParams; }
9596
void dumpVertexerTraits();
96-
void setNThreads(int n);
97-
int getNThreads() const { return mNThreads; }
97+
void setNThreads(int n, std::shared_ptr<tbb::task_arena>& arena);
98+
int getNThreads() { return mTaskArena->max_concurrency(); }
9899
virtual bool isGPU() const noexcept { return false; }
99100
virtual const char* getName() const noexcept { return "CPU"; }
100101
virtual bool usesMemoryPool() const noexcept { return true; }
@@ -116,16 +117,14 @@ class VertexerTraits
116117
}
117118

118119
protected:
119-
int mNThreads = 1;
120-
121120
std::vector<VertexingParameters> mVrtParams;
122121
IndexTableUtils mIndexTableUtils;
123122

124123
// Frame related quantities
125124
TimeFrame7* mTimeFrame = nullptr; // observer ptr
126125
private:
127126
std::shared_ptr<BoundedMemoryResource> mMemoryPool;
128-
tbb::task_arena mTaskArena;
127+
std::shared_ptr<tbb::task_arena> mTaskArena;
129128
};
130129

131130
inline void VertexerTraits::initialise(const TrackingParameters& trackingParams, const int iteration)

Detectors/ITSMFT/ITS/tracking/src/Tracker.cxx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,6 @@ void Tracker::getGlobalConfiguration()
342342
} else {
343343
mTraits->setCorrType(o2::base::PropagatorImpl<float>::MatCorrType::USEMatCorrLUT);
344344
}
345-
setNThreads(tc.nThreads);
346345
int nROFsPerIterations = tc.nROFsPerIterations > 0 ? tc.nROFsPerIterations : -1;
347346
if (tc.nOrbitsPerIterations > 0) {
348347
/// code to be used when the number of ROFs per orbit is known, this gets priority over the number of ROFs per iteration

Detectors/ITSMFT/ITS/tracking/src/TrackerTraits.cxx

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ void TrackerTraits<nLayers>::computeLayerTracklets(const int iteration, int iROF
7171
int minRof = o2::gpu::CAMath::Max(startROF, rof0 - mTrkParams[iteration].DeltaROF);
7272
int maxRof = o2::gpu::CAMath::Min(endROF - 1, rof0 + mTrkParams[iteration].DeltaROF);
7373

74-
mTaskArena.execute([&] {
74+
mTaskArena->execute([&] {
7575
tbb::parallel_for(
7676
tbb::blocked_range<int>(0, mTrkParams[iteration].TrackletsPerRoad()),
7777
[&](const tbb::blocked_range<int>& Layers) {
@@ -200,7 +200,7 @@ void TrackerTraits<nLayers>::computeLayerTracklets(const int iteration, int iROF
200200
return a.firstClusterIndex == b.firstClusterIndex && a.secondClusterIndex == b.secondClusterIndex;
201201
};
202202

203-
mTaskArena.execute([&] {
203+
mTaskArena->execute([&] {
204204
tbb::parallel_for(
205205
tbb::blocked_range<int>(0, mTrkParams[iteration].CellsPerRoad()),
206206
[&](const tbb::blocked_range<int>& Layers) {
@@ -229,7 +229,7 @@ void TrackerTraits<nLayers>::computeLayerTracklets(const int iteration, int iROF
229229
/// Layer 0 is done outside the loop
230230
// in-place deduplication
231231
auto& trklt0 = mTimeFrame->getTracklets()[0];
232-
mTaskArena.execute([&] { tbb::parallel_sort(trklt0.begin(), trklt0.end(), sortTracklets); });
232+
mTaskArena->execute([&] { tbb::parallel_sort(trklt0.begin(), trklt0.end(), sortTracklets); });
233233
trklt0.erase(std::unique(trklt0.begin(), trklt0.end(), equalTracklets), trklt0.end());
234234
trklt0.shrink_to_fit();
235235

@@ -275,7 +275,7 @@ void TrackerTraits<nLayers>::computeLayerCells(const int iteration)
275275
}
276276
}
277277

278-
mTaskArena.execute([&] {
278+
mTaskArena->execute([&] {
279279
tbb::parallel_for(
280280
tbb::blocked_range<int>(0, mTrkParams[iteration].CellsPerRoad()),
281281
[&](const tbb::blocked_range<int>& Layers) {
@@ -496,7 +496,7 @@ void TrackerTraits<nLayers>::findCellsNeighbours(const int iteration)
496496
continue;
497497
}
498498

499-
mTaskArena.execute([&] {
499+
mTaskArena->execute([&] {
500500
int layerCellsNum{static_cast<int>(mTimeFrame->getCells()[iLayer].size())};
501501

502502
bounded_vector<int> perCellCount(layerCellsNum + 1, 0, mMemoryPool.get());
@@ -621,7 +621,7 @@ void TrackerTraits<nLayers>::processNeighbours(int iLayer, int iLevel, const bou
621621
int failed[5]{0, 0, 0, 0, 0}, attempts{0}, failedByMismatch{0};
622622
#endif
623623

624-
mTaskArena.execute([&] {
624+
mTaskArena->execute([&] {
625625
bounded_vector<int> perCellCount(currentCellSeed.size() + 1, 0, mMemoryPool.get());
626626
tbb::parallel_for(
627627
tbb::blocked_range<int>(0, (int)currentCellSeed.size()),
@@ -812,7 +812,7 @@ void TrackerTraits<nLayers>::findRoads(const int iteration)
812812
}
813813

814814
bounded_vector<TrackITSExt> tracks(mMemoryPool.get());
815-
mTaskArena.execute([&] {
815+
mTaskArena->execute([&] {
816816
bounded_vector<int> perSeedCount(trackSeeds.size() + 1, 0, mMemoryPool.get());
817817
tbb::parallel_for(
818818
tbb::blocked_range<int>(0, (int)trackSeeds.size()),
@@ -1258,17 +1258,19 @@ bool TrackerTraits<nLayers>::isMatLUT() const
12581258
}
12591259

12601260
template <int nLayers>
1261-
void TrackerTraits<nLayers>::setNThreads(int n)
1261+
void TrackerTraits<nLayers>::setNThreads(int n, std::shared_ptr<tbb::task_arena>& arena)
12621262
{
1263-
if (mNThreads == n && mTaskArena.is_active()) {
1264-
return;
1265-
}
1266-
mNThreads = n > 0 ? n : 1;
12671263
#if defined(OPTIMISATION_OUTPUT) || defined(CA_DEBUG)
1268-
mNThreads = 1; // only works while serial
1264+
mTaskArena = std::make_shared<tbb::task_arena>(1);
1265+
#else
1266+
if (arena == nullptr) {
1267+
mTaskArena = std::make_shared<tbb::task_arena>(std::abs(n));
1268+
LOGP(info, "Setting tracker with {} threads.", n);
1269+
} else {
1270+
mTaskArena = arena;
1271+
LOGP(info, "Attaching tracker to calling thread's arena");
1272+
}
12691273
#endif
1270-
mTaskArena.initialize(mNThreads);
1271-
LOGP(info, "Setting tracker with {} threads.", mNThreads);
12721274
}
12731275

12741276
template class TrackerTraits<7>;

Detectors/ITSMFT/ITS/tracking/src/TrackingInterface.cxx

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "ITSReconstruction/FastMultEst.h"
1717

1818
#include "ITStracking/TrackingInterface.h"
19+
#include <oneapi/tbb/task_arena.h>
1920
#include <memory>
2021

2122
#include "DataFormatsITSMFT/ROFRecord.h"
@@ -148,6 +149,20 @@ void ITSTrackingInterface::initialise()
148149
}
149150
mTracker->setParameters(trackParams);
150151
mVertexer->setParameters(vertParams);
152+
if (trackConf.nThreads == vertConf.nThreads) {
153+
bool clamped{false};
154+
int nThreads = trackConf.nThreads;
155+
if (nThreads > 0) {
156+
const int hw = std::thread::hardware_concurrency();
157+
const int maxThreads = (hw == 0 ? 1 : hw);
158+
nThreads = std::clamp(nThreads, 1, maxThreads);
159+
clamped = trackConf.nThreads > maxThreads;
160+
}
161+
LOGP(info, "Tracker and Vertexer will share the task arena with {} thread(s){}", nThreads, (clamped) ? " (clamped)" : "");
162+
mTaskArena = std::make_shared<tbb::task_arena>(std::abs(nThreads));
163+
}
164+
mVertexer->setNThreads(vertConf.nThreads, mTaskArena);
165+
mTracker->setNThreads(trackConf.nThreads, mTaskArena);
151166
}
152167

153168
void ITSTrackingInterface::run(framework::ProcessingContext& pc)

Detectors/ITSMFT/ITS/tracking/src/VertexerTraits.cxx

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
///
1212

1313
#include <iostream>
14+
#include <memory>
1415
#include <string>
1516
#include <chrono>
1617

@@ -168,13 +169,12 @@ void VertexerTraits::updateVertexingParameters(const std::vector<VertexingParame
168169
par.phiSpan = static_cast<int>(std::ceil(mIndexTableUtils.getNphiBins() * par.phiCut / o2::constants::math::TwoPI));
169170
par.zSpan = static_cast<int>(std::ceil(par.zCut * mIndexTableUtils.getInverseZCoordinate(0)));
170171
}
171-
setNThreads(vrtPar[0].nThreads);
172172
}
173173

174174
// Main functions
175175
void VertexerTraits::computeTracklets(const int iteration)
176176
{
177-
mTaskArena.execute([&] {
177+
mTaskArena->execute([&] {
178178
tbb::parallel_for(
179179
tbb::blocked_range<short>(0, (short)mTimeFrame->getNrof()),
180180
[&](const tbb::blocked_range<short>& Rofs) {
@@ -220,7 +220,7 @@ void VertexerTraits::computeTracklets(const int iteration)
220220
mTimeFrame->getTracklets()[0].resize(mTimeFrame->getTotalTrackletsTF(0));
221221
mTimeFrame->getTracklets()[1].resize(mTimeFrame->getTotalTrackletsTF(1));
222222

223-
mTaskArena.execute([&] {
223+
mTaskArena->execute([&] {
224224
tbb::parallel_for(
225225
tbb::blocked_range<short>(0, (short)mTimeFrame->getNrof()),
226226
[&](const tbb::blocked_range<short>& Rofs) {
@@ -329,7 +329,7 @@ void VertexerTraits::computeTracklets(const int iteration)
329329

330330
void VertexerTraits::computeTrackletMatching(const int iteration)
331331
{
332-
mTaskArena.execute([&] {
332+
mTaskArena->execute([&] {
333333
tbb::parallel_for(
334334
tbb::blocked_range<short>(0, (short)mTimeFrame->getNrof()),
335335
[&](const tbb::blocked_range<short>& Rofs) {
@@ -687,15 +687,17 @@ void VertexerTraits::computeVerticesInRof(int rofId,
687687
verticesInRof.push_back(foundVertices);
688688
}
689689

690-
void VertexerTraits::setNThreads(int n)
690+
void VertexerTraits::setNThreads(int n, std::shared_ptr<tbb::task_arena>& arena)
691691
{
692-
if (mNThreads == n && mTaskArena.is_active()) {
693-
return;
694-
}
695-
mNThreads = n > 0 ? n : 1;
696692
#if defined(VTX_DEBUG)
697-
mNThreads = 1;
693+
mTaskArena = std::make_shared<tbb::task_arena>(1);
694+
#else
695+
if (arena == nullptr) {
696+
mTaskArena = std::make_shared<tbb::task_arena>(std::abs(n));
697+
LOGP(info, "Setting seeding vertexer with {} threads.", n);
698+
} else {
699+
mTaskArena = arena;
700+
LOGP(info, "Attaching vertexer to calling thread's arena");
701+
}
698702
#endif
699-
mTaskArena.initialize(mNThreads);
700-
LOGP(info, "Setting seeding vertexer with {} threads.", mNThreads);
701703
}

0 commit comments

Comments
 (0)