
Commit b5ab60d

ChSonnabend, alibuild, and davidrohr authored
GPU clusterizer with neural networks (#13981)
* Copying kernels to implement NN clusterizer
* First version of clusterizer in GPU code
* Adding a compiling and running version with single-threaded ONNX model executions. Clusters are not getting published yet (FIXME)
* Clusters now working by a hack
* Working implementation of settings via GPUSettings.h and --configKeyValues "GPU_proc.[setting]=...;..."
* Modifying the onnx_interface to include the right headers
* Adjusting initialization for new ONNXRuntime version
* Adjusting global settings and CF code for several settings
* Adding return statement if cluster is rejected
* Adding some statements back
* Update to latest status of GPU clusterization
* Fixing uchar -> uint8_t
* Adding utils header
* Updating kernels.cmake to uint8_t
* Please consider the following formatting changes
* Adding an ONNX CPU library in the O2 framework
* Please consider the following formatting changes
* Fixing macOS build issues with calling O*.data()
* Fixing compiler issues and char -> uint8_t
* Fixing curly braces
* Fixing std::make_shared
* Changing order for <CommonUtils/StringUtils.h>
* Bug-fixing file name
* Making NN clusterizer more efficient
* Changing constexpr
* Fixing build issues
* Major changes to make clusterizer parallelizable. Problem remains: different sizes of nnClusterizerBatchedMode lead to different numbers of clusters if nnClusterizerBatchedMode < clusterer.mPmemory->counters.nClusters
* Adjusting for default CF regression
* Bug-fix for application of CF regression and logging message
* Adding is_boundary check earlier to avoid out-of-bounds access
* Bug-fixes for boundary reading
* Updating to use explicit calls to kernels instead of if-statements
* Bug-fix for class label application
* Explicit casting solves regression issues. To be done: correct publishing for class2 regression
* Bug-fixes
* Adding some documentation
* Please consider the following formatting changes
* Modifying for David's comments
* Modifications from comments on PR
* Please consider the following formatting changes
* iSlice -> iSector
* mISlice -> mISector
* Minor bug-fixes
* Adjusting for comments
* Bug-fix for fullCI build
* Adding GPUd() for on-device functions
* Fixing compile issues; only thing missing: conversion of float to float16
* Let's see if this does the trick
* Making functions (constructors) GPUd() (GPUdDefault())
* GPU kernels should now be findable
* Adding ifdefs for standalone build and header exclusions in GPUORTFloat16
* Modifying the approach to not use std:: types. Still needs to be tested and need to do proper memory allocation
* New version of clusterizer. Compiles locally, but segfaults in fillInput kernel. Testing with the CI now.
* Please consider the following formatting changes
* Adjust for comments
* Please consider the following formatting changes
* Merging dev and adjusting build issues
* Adjusting for comments
* Fixing incorrect #endif
* Please consider the following formatting changes
* Fix indentation, remove duplicate define
* Fixing one memory issue. Segfault / memory leak persists
* Adjusting for new toNative function
* Fixing .finalize
* Adjusting CMakeLists and other bugs
* Adding GPUCA_HAS_ONNX only to tracking
* Changing to fixed size for number of clusters
* Fixed segfault. Not producing the right number of clusters yet.
* Network now accepts clusters over all sectors
* Whitespaces...
* Some weird formatting
* Please consider the following formatting changes
* Removing white-spaces
* Adding necessary if-statement to avoid automatic model loading
* Removing GPUConstantMem, adding interOpNumThreads option
* Found the bug where I lose clusters
* Editor configured for whitespaces at EOF

---------

Co-authored-by: ALICE Action Bot <alibuild@cern.ch>
Co-authored-by: David Rohr <github@jwdt.org>
1 parent 2626074 commit b5ab60d

20 files changed: +1002 −112 lines

Common/ML/include/ML/OrtInterface.h

Lines changed: 7 additions & 2 deletions
@@ -41,6 +41,7 @@ class OrtModel
   OrtModel(std::unordered_map<std::string, std::string> optionsMap) { reset(optionsMap); }
   void init(std::unordered_map<std::string, std::string> optionsMap) { reset(optionsMap); }
   void reset(std::unordered_map<std::string, std::string>);
+  bool isInitialized() { return mInitialized; }
 
   virtual ~OrtModel() = default;
 
@@ -55,6 +56,9 @@ class OrtModel
   template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h
   std::vector<O> inference(std::vector<std::vector<I>>&);
 
+  template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h
+  void inference(I*, size_t, O*);
+
   // template<class I, class T, class O> // class I is the input data type, e.g. float, class T the throughput data type and class O is the output data type
   // std::vector<O> inference(std::vector<I>&);
 
@@ -79,8 +83,9 @@ class OrtModel
   std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes;
 
   // Environment settings
-  std::string modelPath, device = "cpu", dtype = "float"; // device options should be cpu, rocm, migraphx, cuda
-  int intraOpNumThreads = 0, deviceId = 0, enableProfiling = 0, loggingLevel = 0, allocateDeviceMemory = 0, enableOptimizations = 0;
+  bool mInitialized = false;
+  std::string modelPath, device = "cpu", dtype = "float", thread_affinity = ""; // device options should be cpu, rocm, migraphx, cuda
+  int intraOpNumThreads = 1, interOpNumThreads = 1, deviceId = 0, enableProfiling = 0, loggingLevel = 0, allocateDeviceMemory = 0, enableOptimizations = 0;
 
   std::string printShape(const std::vector<int64_t>&);
 };

Common/ML/src/OrtInterface.cxx

Lines changed: 65 additions & 99 deletions
@@ -44,17 +44,20 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
   if (!optionsMap.contains("model-path")) {
     LOG(fatal) << "(ORT) Model path cannot be empty!";
   }
-  modelPath = optionsMap["model-path"];
-  device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
-  dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float");
-  deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
-  allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
-  intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
-  loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2);
-  enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
-  enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
-
-  std::string dev_mem_str = "Hip";
+
+  if (!optionsMap["model-path"].empty()) {
+    modelPath = optionsMap["model-path"];
+    device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
+    dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float");
+    deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
+    allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
+    intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
+    interOpNumThreads = (optionsMap.contains("inter-op-num-threads") ? std::stoi(optionsMap["inter-op-num-threads"]) : 0);
+    loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
+    enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
+    enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
+
+    std::string dev_mem_str = "Hip";
 #if defined(ORT_ROCM_BUILD)
 #if ORT_ROCM_BUILD == 1
   if (device == "ROCM") {
@@ -88,12 +91,15 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
 
   if (device == "CPU") {
     (pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
-    if (intraOpNumThreads > 1) {
+    (pImplOrt->sessionOptions).SetInterOpNumThreads(interOpNumThreads);
+    if (intraOpNumThreads > 1 || interOpNumThreads > 1) {
       (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
     } else if (intraOpNumThreads == 1) {
       (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
     }
-    LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " threads";
+    if (loggingLevel < 2) {
+      LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " (intraOpNumThreads) and " << interOpNumThreads << " (interOpNumThreads) threads";
+    }
   }
 
   (pImplOrt->sessionOptions).DisableMemPattern();
@@ -109,6 +115,9 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
   } else {
     (pImplOrt->sessionOptions).DisableProfiling();
   }
+
+  mInitialized = true;
+
   (pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
   (pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
 
@@ -154,16 +163,9 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
   outputNamesChar.resize(mOutputNames.size(), nullptr);
   std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
                  [&](const std::string& str) { return str.c_str(); });
-
-  // Print names
-  LOG(info) << "\tInput Nodes:";
-  for (size_t i = 0; i < mInputNames.size(); i++) {
-    LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
   }
-
-  LOG(info) << "\tOutput Nodes:";
-  for (size_t i = 0; i < mOutputNames.size(); i++) {
-    LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
+  if (loggingLevel < 2) {
+    LOG(info) << "(ORT) Model loaded successfully! (input: " << printShape(mInputShapes[0]) << ", output: " << printShape(mOutputShapes[0]) << ")";
   }
 }
 
@@ -187,36 +189,6 @@ std::vector<O> OrtModel::v2v(std::vector<I>& input, bool clearInput)
   }
 }
 
-template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h
-std::vector<O> OrtModel::inference(std::vector<I>& input)
-{
-  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-  std::vector<Ort::Value> inputTensor;
-  inputTensor.emplace_back(Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, reinterpret_cast<O*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
-  // input.clear();
-  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
-  std::vector<O> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
-  outputTensors.clear();
-  return outputValuesVec;
-}
-
-template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h
-std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input)
-{
-  std::vector<Ort::Value> inputTensor;
-  for (auto i : input) {
-    std::vector<int64_t> inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-    inputTensor.emplace_back(Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, reinterpret_cast<O*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
-  }
-  // input.clear();
-  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
-  std::vector<O> outputValuesVec{outputValues, outputValues + inputTensor.size() / mInputShapes[0][1] * mOutputShapes[0][1]};
-  outputTensors.clear();
-  return outputValuesVec;
-}
-
 std::string OrtModel::printShape(const std::vector<int64_t>& v)
 {
   std::stringstream ss("");
@@ -227,74 +199,68 @@ std::string OrtModel::printShape(const std::vector<int64_t>& v)
   return ss.str();
 }
 
-template <>
-std::vector<float> OrtModel::inference<float, float>(std::vector<float>& input)
+template <class I, class O>
+std::vector<O> OrtModel::inference(std::vector<I>& input)
 {
   std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
   std::vector<Ort::Value> inputTensor;
-  inputTensor.emplace_back(Ort::Value::CreateTensor<float>(pImplOrt->memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size()));
+  if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
+    inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
+  } else {
+    inputTensor.emplace_back(Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size()));
+  }
   // input.clear();
   auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  float* outputValues = outputTensors[0].template GetTensorMutableData<float>();
-  std::vector<float> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
+  O* outputValues = outputTensors[0].template GetTensorMutableData<O>();
+  std::vector<O> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
   outputTensors.clear();
   return outputValuesVec;
 }
 
-template <>
-std::vector<float> OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>& input)
-{
-  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-  std::vector<Ort::Value> inputTensor;
-  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
-  // input.clear();
-  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  float* outputValues = outputTensors[0].template GetTensorMutableData<float>();
-  std::vector<float> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
-  outputTensors.clear();
-  return outputValuesVec;
-}
+template std::vector<float> OrtModel::inference<float, float>(std::vector<float>&);
 
-template <>
-std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>& input)
-{
-  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-  std::vector<Ort::Value> inputTensor;
-  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
-  // input.clear();
-  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
-  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
-  outputTensors.clear();
-  return outputValuesVec;
-}
+template std::vector<float> OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&);
 
-template <>
-std::vector<OrtDataType::Float16_t> OrtModel::inference<float, OrtDataType::Float16_t>(std::vector<float>& input)
+template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);
+
+template <class I, class O>
+void OrtModel::inference(I* input, size_t input_size, O* output)
 {
-  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-  std::vector<Ort::Value> inputTensor;
-  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
-  // input.clear();
-  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
-  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
-  outputTensors.clear();
-  return outputValuesVec;
+  std::vector<int64_t> inputShape{(int64_t)(input_size / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
+  Ort::Value inputTensor = Ort::Value(nullptr);
+  if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
+    inputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input), input_size, inputShape.data(), inputShape.size());
+  } else {
+    inputTensor = Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, input, input_size, inputShape.data(), inputShape.size());
+  }
+
+  std::vector<int64_t> outputShape{inputShape[0], mOutputShapes[0][1]};
+  size_t outputSize = (int64_t)(input_size * mOutputShapes[0][1] / mInputShapes[0][1]);
+  Ort::Value outputTensor = Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, output, outputSize, outputShape.data(), outputShape.size());
+
+  (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), &inputTensor, 1, outputNamesChar.data(), &outputTensor, outputNamesChar.size()); // TODO: Not sure if 1 is correct here
 }
 
-template <>
-std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>& input)
+template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, size_t, float*);
+
+template void OrtModel::inference<float, float>(float*, size_t, float*);
+
+template <class I, class O>
+std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input)
 {
   std::vector<Ort::Value> inputTensor;
   for (auto i : input) {
     std::vector<int64_t> inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-    inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
+    if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
+      inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
+    } else {
+      inputTensor.emplace_back(Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, i.data(), i.size(), inputShape.data(), inputShape.size()));
+    }
   }
   // input.clear();
   auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
-  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + inputTensor.size() / mInputShapes[0][1] * mOutputShapes[0][1]};
+  O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
+  std::vector<O> outputValuesVec{outputValues, outputValues + inputTensor.size() / mInputShapes[0][1] * mOutputShapes[0][1]};
   outputTensors.clear();
   return outputValuesVec;
 }

GPU/GPUTracking/Base/GPUConstantMem.h

Lines changed: 8 additions & 2 deletions
@@ -34,12 +34,15 @@
 #include "GPUKernelDebugOutput.h"
 #endif
 
+#ifdef GPUCA_HAS_ONNX
+#include "GPUTPCNNClusterizer.h"
+#endif
+
 namespace o2::gpu
 {
 struct GPUConstantMem {
   GPUParam param;
-  GPUTPCTracker
-    tpcTrackers[GPUCA_NSECTORS];
+  GPUTPCTracker tpcTrackers[GPUCA_NSECTORS];
   GPUTPCConvert tpcConverter;
   GPUTPCCompression tpcCompressor;
   GPUTPCDecompression tpcDecompressor;
@@ -55,6 +58,9 @@ struct GPUConstantMem {
 #ifdef GPUCA_KERNEL_DEBUGGER_OUTPUT
   GPUKernelDebugOutput debugOutput;
 #endif
+#ifdef GPUCA_HAS_ONNX
+  GPUTPCNNClusterizer tpcNNClusterer[GPUCA_NSECTORS];
+#endif
 
   template <int32_t I>
   GPUd() auto& getTRDTracker();

GPU/GPUTracking/Base/GPUMemoryResource.h

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ struct GPUMemoryReuse {
 };
 enum Group : uint16_t {
   ClustererScratch,
+  NNClusterer,
   ClustererZS,
   TrackerScratch,
   TrackerDataLinks,

GPU/GPUTracking/Base/GPUReconstruction.cxx

Lines changed: 3 additions & 0 deletions
@@ -93,6 +93,9 @@ GPUReconstruction::GPUReconstruction(const GPUSettingsDeviceBackend& cfg) : mHos
   for (uint32_t i = 0; i < NSECTORS; i++) {
     processors()->tpcTrackers[i].SetSector(i); // TODO: Move to a better place
     processors()->tpcClusterer[i].mISector = i;
+#ifdef GPUCA_HAS_ONNX
+    processors()->tpcNNClusterer[i].mISector = i;
+#endif
   }
 #ifndef GPUCA_NO_ROOT
   mROOTDump = GPUROOTDumpCore::getAndCreate();

GPU/GPUTracking/CMakeLists.txt

Lines changed: 9 additions & 3 deletions
@@ -159,8 +159,8 @@ set(HDRS_INSTALL
 )
 
 set(SRCS_NO_CINT ${SRCS_NO_CINT} display/GPUDisplayInterface.cxx)
-set(SRCS_NO_CINT
-    ${SRCS_NO_CINT}
+
+set(SRCS_NO_CINT ${SRCS_NO_CINT}
     Global/GPUChainITS.cxx
     ITS/GPUITSFitter.cxx
     ITS/GPUITSFitterKernels.cxx
@@ -191,6 +191,10 @@ set(SRCS_NO_CINT
     Refit/GPUTrackingRefitKernel.cxx
     Merger/GPUTPCGMO2Output.cxx)
 
+if(NOT ALIGPU_BUILD_TYPE STREQUAL "Standalone")
+  list(APPEND SRCS_NO_CINT TPCClusterFinder/GPUTPCNNClusterizerKernels.cxx TPCClusterFinder/GPUTPCNNClusterizer.cxx TPCClusterFinder/GPUTPCNNClusterizerHost.cxx)
+endif()
+
 set(SRCS_DATATYPES
     ${SRCS_DATATYPES}
     DataTypes/TPCPadGainCalib.cxx
@@ -273,6 +277,7 @@ if(ALIGPU_BUILD_TYPE STREQUAL "O2")
     PRIVATE_LINK_LIBRARIES O2::DataFormatsTPC
     SOURCES ${SRCS_DATATYPES})
   target_compile_definitions(${targetName} PRIVATE GPUCA_O2_LIB GPUCA_TPC_GEOMETRY_O2)
+
   o2_target_root_dictionary(GPUDataTypes
     HEADERS ${HDRS_CINT_DATATYPES} ${HDRS_CINT_O2_ADDITIONAL}
     LINKDEF GPUTrackingLinkDef_O2_DataTypes.h)
@@ -292,6 +297,7 @@ if(ALIGPU_BUILD_TYPE STREQUAL "O2")
     O2::TPCFastTransformation
     O2::DetectorsRaw
     O2::Steer
+    O2::ML
     PUBLIC_INCLUDE_DIRECTORIES .
     Definitions
     DataTypes
@@ -317,7 +323,7 @@ if(ALIGPU_BUILD_TYPE STREQUAL "O2")
     ${targetName}
     PRIVATE $<TARGET_PROPERTY:O2::Framework,INTERFACE_INCLUDE_DIRECTORIES>)
 
-  target_compile_definitions(${targetName} PRIVATE GPUCA_O2_LIB GPUCA_TPC_GEOMETRY_O2)
+  target_compile_definitions(${targetName} PRIVATE GPUCA_O2_LIB GPUCA_TPC_GEOMETRY_O2 GPUCA_HAS_ONNX=1)
 
   o2_target_root_dictionary(${MODULE}
     HEADERS ${HDRS_CINT_O2} ${HDRS_CINT_O2_ADDITIONAL}
