Skip to content

Commit ace8b80

Browse files
committed
ITS: unify hybrid/cpu clusterToTracks interface
1 parent 4c6a087 commit ace8b80

File tree

9 files changed

+133
-204
lines changed

9 files changed

+133
-204
lines changed

Detectors/ITSMFT/ITS/tracking/GPU/ITStrackingGPU/TrackerTraitsGPU.h

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,36 @@
1717
#include "ITStracking/Definitions.h"
1818
#include "ITStracking/TrackerTraits.h"
1919
#include "ITStrackingGPU/TimeFrameGPU.h"
20+
#include "Framework/Logger.h"
2021

2122
namespace o2
2223
{
2324
namespace its
2425
{
2526

2627
template <int nLayers = 7>
27-
class TrackerTraitsGPU : public TrackerTraits
28+
class TrackerTraitsGPU final : public TrackerTraits
2829
{
2930
public:
3031
TrackerTraitsGPU() = default;
3132
~TrackerTraitsGPU() override = default;
3233

33-
// void computeLayerCells() final;
34-
void adoptTimeFrame(TimeFrame* tf) override;
35-
void initialiseTimeFrame(const int iteration) override;
36-
void computeLayerTracklets(const int iteration, int, int) final;
37-
void computeLayerCells(const int iteration) override;
38-
void setBz(float) override;
39-
void findCellsNeighbours(const int iteration) override;
40-
void findRoads(const int iteration) override;
34+
void adoptTimeFrame(TimeFrame* tf) final;
35+
void initialiseTimeFrame(const int iteration) final;
36+
void setBz(float) final;
4137

42-
// Methods to get CPU execution from traits
43-
void initialiseTimeFrameHybrid(const int iteration) override { initialiseTimeFrame(iteration); };
44-
void computeTrackletsHybrid(const int iteration, int, int) override;
45-
void computeCellsHybrid(const int iteration) override;
46-
void findCellsNeighboursHybrid(const int iteration) override;
38+
void computeLayerTracklets(const int iteration, int, int) final { LOGP(fatal, "computeLayerTracklers must never be called from Hybrid traits!"); };
39+
void computeLayerCells(const int iteration) final { LOGP(fatal, "computeLayerCells must never be called from Hybrid traits!"); };
40+
void findCellsNeighbours(const int iteration) final { LOGP(fatal, "findCellsNeighbours must never be called from Hybrid traits!"); };
41+
void findRoads(const int iteration) final { LOGP(fatal, "findRoads must never be called from Hybrid traits!"); };
42+
void extendTracks(const int iteration) final { LOGP(fatal, "extendTracks must never be called from Hybrid traits!"); };
43+
void findShortPrimaries() final { LOGP(fatal, "findShortPrimaries must never be called from Hybrid traits!"); };
4744

48-
void extendTracks(const int iteration) override;
45+
void initialiseTimeFrameHybrid(const int iteration) final { initialiseTimeFrame(iteration); };
46+
void computeTrackletsHybrid(const int iteration, int, int) final;
47+
void computeCellsHybrid(const int iteration) final;
48+
void findCellsNeighboursHybrid(const int iteration) final;
49+
void findRoadsHybrid(const int iteration) final;
4950

5051
// TimeFrameGPU information forwarding
5152
int getTFNumberOfClusters() const override;

Detectors/ITSMFT/ITS/tracking/GPU/cuda/TrackerTraitsGPU.cxx

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111
///
1212

1313
#include <array>
14-
#include <sstream>
15-
#include <iostream>
1614
#include <unistd.h>
17-
#include <thread>
1815

1916
#include "DataFormatsITS/TrackITS.h"
2017

@@ -40,26 +37,6 @@ void TrackerTraitsGPU<nLayers>::initialiseTimeFrame(const int iteration)
4037
mTimeFrameGPU->loadIndexTableUtils(iteration);
4138
}
4239

43-
template <int nLayers>
44-
void TrackerTraitsGPU<nLayers>::computeLayerTracklets(const int iteration, int, int)
45-
{
46-
}
47-
48-
template <int nLayers>
49-
void TrackerTraitsGPU<nLayers>::computeLayerCells(const int iteration)
50-
{
51-
}
52-
53-
template <int nLayers>
54-
void TrackerTraitsGPU<nLayers>::findCellsNeighbours(const int iteration)
55-
{
56-
}
57-
58-
template <int nLayers>
59-
void TrackerTraitsGPU<nLayers>::extendTracks(const int iteration)
60-
{
61-
}
62-
6340
template <int nLayers>
6441
void TrackerTraitsGPU<nLayers>::setBz(float bz)
6542
{
@@ -260,7 +237,7 @@ void TrackerTraitsGPU<nLayers>::findCellsNeighboursHybrid(const int iteration)
260237
};
261238

262239
template <int nLayers>
263-
void TrackerTraitsGPU<nLayers>::findRoads(const int iteration)
240+
void TrackerTraitsGPU<nLayers>::findRoadsHybrid(const int iteration)
264241
{
265242
auto& conf = o2::its::ITSGpuTrackingParamConfig::Instance();
266243
for (int startLevel{mTrkParams[iteration].CellsPerRoad()}; startLevel >= mTrkParams[iteration].CellMinimumLevel(); --startLevel) {

Detectors/ITSMFT/ITS/tracking/include/ITStracking/Tracker.h

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,17 @@ class TrackerTraits;
5050

5151
class Tracker
5252
{
53+
using LogFunc = std::function<void(const std::string& s)>;
5354

5455
public:
5556
Tracker(TrackerTraits* traits);
5657

57-
Tracker(const Tracker&) = delete;
58-
Tracker& operator=(const Tracker&) = delete;
59-
~Tracker();
60-
6158
void adoptTimeFrame(TimeFrame& tf);
6259

6360
void clustersToTracks(
64-
std::function<void(std::string s)> = [](std::string s) { std::cout << s << std::endl; }, std::function<void(std::string s)> = [](std::string s) { std::cerr << s << std::endl; });
61+
LogFunc = [](std::string s) { std::cout << s << std::endl; }, LogFunc = [](std::string s) { std::cerr << s << std::endl; });
6562
void clustersToTracksHybrid(
66-
std::function<void(std::string s)> = [](std::string s) { std::cout << s << std::endl; }, std::function<void(std::string s)> = [](std::string s) { std::cerr << s << std::endl; });
63+
LogFunc = [](std::string s) { std::cout << s << std::endl; }, LogFunc = [](std::string s) { std::cerr << s << std::endl; });
6764
std::vector<TrackITSExt>& getTracks();
6865

6966
void setParameters(const std::vector<TrackingParameters>&);
@@ -74,41 +71,48 @@ class Tracker
7471
bool isMatLUT() const;
7572
void setNThreads(int n);
7673
int getNThreads() const;
77-
std::uint32_t mTimeFrameCounter = 0;
74+
void printSummary() const;
7875

7976
private:
77+
enum TrackerType : uint8_t { CPU = 0,
78+
Hybrid,
79+
NSize };
80+
template <TrackerType>
81+
void clusterToTracksImpl(LogFunc, LogFunc);
82+
static constexpr const char* sTrackerNames[TrackerType::NSize] = {"CPU", "Hybrid"};
83+
84+
// CPU
8085
void initialiseTimeFrame(int& iteration);
8186
void computeTracklets(int& iteration, int& iROFslice, int& iVertex);
8287
void computeCells(int& iteration);
8388
void findCellsNeighbours(int& iteration);
8489
void findRoads(int& iteration);
85-
90+
void findShortPrimaries();
91+
void extendTracks(int& iteration);
92+
// Hyrbid
8693
void initialiseTimeFrameHybrid(int& iteration);
8794
void computeTrackletsHybrid(int& iteration, int& iROFslice, int& iVertex);
8895
void computeCellsHybrid(int& iteration);
8996
void findCellsNeighboursHybrid(int& iteration);
9097
void findRoadsHybrid(int& iteration);
9198
void findTracksHybrid(int& iteration);
9299

93-
void findShortPrimaries();
94-
void findTracks();
95-
void extendTracks(int& iteration);
96-
97100
// MC interaction
98101
void computeRoadsMClabels();
99102
void computeTracksMClabels();
100103
void rectifyClusterIndices();
101104

102105
template <typename... T>
103-
float evaluateTask(void (Tracker::*)(T...), const char*, std::function<void(std::string s)> logger, T&&... args);
106+
float evaluateTask(void (Tracker::*)(T...), const char*, LogFunc logger, T&&... args);
104107

105108
TrackerTraits* mTraits = nullptr; /// Observer pointer, not owned by this class
106109
TimeFrame* mTimeFrame = nullptr; /// Observer pointer, not owned by this class
107110

108111
std::vector<TrackingParameters> mTrkParams;
109112
o2::gpu::GPUChainITS* mRecoChain = nullptr;
110113

111-
unsigned int mNumberOfRuns{0};
114+
unsigned int mNumberOfDroppedTFs{0};
115+
unsigned int mTimeFrameCounter{0};
112116
};
113117

114118
inline void Tracker::setParameters(const std::vector<TrackingParameters>& trkPars)
@@ -117,8 +121,7 @@ inline void Tracker::setParameters(const std::vector<TrackingParameters>& trkPar
117121
}
118122

119123
template <typename... T>
120-
float Tracker::evaluateTask(void (Tracker::*task)(T...), const char* taskName, std::function<void(std::string s)> logger,
121-
T&&... args)
124+
float Tracker::evaluateTask(void (Tracker::*task)(T...), const char* taskName, LogFunc logger, T&&... args)
122125
{
123126
float diff{0.f};
124127

Detectors/ITSMFT/ITS/tracking/include/ITStracking/TrackerTraits.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,21 @@ class TrackerTraits
5151
public:
5252
virtual ~TrackerTraits() = default;
5353
virtual void adoptTimeFrame(TimeFrame* tf);
54+
5455
virtual void initialiseTimeFrame(const int iteration);
5556
virtual void computeLayerTracklets(const int iteration, int iROFslice, int iVertex);
5657
virtual void computeLayerCells(const int iteration);
5758
virtual void findCellsNeighbours(const int iteration);
5859
virtual void findRoads(const int iteration);
59-
virtual void initialiseTimeFrameHybrid(const int iteration) { LOGP(error, "initialiseTimeFrameHybrid: this method should never be called with CPU traits"); }
60-
virtual void computeTrackletsHybrid(const int iteration, int, int) { LOGP(error, "computeTrackletsHybrid: this method should never be called with CPU traits"); }
61-
virtual void computeCellsHybrid(const int iteration) { LOGP(error, "computeCellsHybrid: this method should never be called with CPU traits"); }
62-
virtual void findCellsNeighboursHybrid(const int iteration) { LOGP(error, "findCellsNeighboursHybrid: this method should never be called with CPU traits"); }
63-
virtual void findRoadsHybrid(const int iteration) { LOGP(error, "findRoadsHybrid: this method should never be called with CPU traits"); }
64-
virtual void findTracksHybrid(const int iteration) { LOGP(error, "findTracksHybrid: this method should never be called with CPU traits"); }
65-
virtual void findTracks() { LOGP(error, "findTracks: this method is deprecated."); }
6660
virtual void extendTracks(const int iteration);
6761
virtual void findShortPrimaries();
62+
63+
virtual void initialiseTimeFrameHybrid(const int iteration) { LOGP(fatal, "initialiseTimeFrameHybrid: this method should never be called with CPU traits"); }
64+
virtual void computeTrackletsHybrid(const int iteration, int, int) { LOGP(fatal, "computeTrackletsHybrid: this method should never be called with CPU traits"); }
65+
virtual void computeCellsHybrid(const int iteration) { LOGP(fatal, "computeCellsHybrid: this method should never be called with CPU traits"); }
66+
virtual void findCellsNeighboursHybrid(const int iteration) { LOGP(fatal, "findCellsNeighboursHybrid: this method should never be called with CPU traits"); }
67+
virtual void findRoadsHybrid(const int iteration) { LOGP(fatal, "findRoadsHybrid: this method should never be called with CPU traits"); }
68+
6869
virtual void setBz(float bz);
6970
virtual bool trackFollowing(TrackITSExt* track, int rof, bool outward, const int iteration);
7071
virtual void processNeighbours(int iLayer, int iLevel, const std::vector<CellSeed>& currentCellSeed, const std::vector<int>& currentCellId, std::vector<CellSeed>& updatedCellSeed, std::vector<int>& updatedCellId);

Detectors/ITSMFT/ITS/tracking/include/ITStracking/TrackingInterface.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@ class ITSTrackingInterface
3737
const bool overrBeamEst)
3838
: mIsMC{isMC},
3939
mUseTriggers{trgType},
40-
mOverrideBeamEstimation{overrBeamEst}
41-
{
42-
}
40+
mOverrideBeamEstimation{overrBeamEst} {}
4341

4442
void setClusterDictionary(const o2::itsmft::TopologyDictionary* d) { mDict = d; }
4543
void setMeanVertex(const o2::dataformats::MeanVertexObject* v)
@@ -56,6 +54,7 @@ class ITSTrackingInterface
5654
void initialise();
5755
template <bool isGPU = false>
5856
void run(framework::ProcessingContext& pc);
57+
void printSummary() const;
5958

6059
virtual void updateTimeDependentParams(framework::ProcessingContext& pc);
6160
virtual void finaliseCCDB(framework::ConcreteDataMatcher& matcher, void* obj);

0 commit comments

Comments
 (0)