Skip to content

Commit 95bb2ff

Browse files
committed
Major changes to make clusterizer parallelizable. Problem remains: different sizes of nnClusterizerBatchedMode lead to different number of clusters if nnClusterizerBatchedMode < clusterer.mPmemory->counters.nClusters
1 parent 3c4c587 commit 95bb2ff

File tree

9 files changed

+445
-712
lines changed

9 files changed

+445
-712
lines changed

Common/ML/include/ML/OrtInterface.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class OrtModel
4141
OrtModel(std::unordered_map<std::string, std::string> optionsMap) { reset(optionsMap); }
4242
void init(std::unordered_map<std::string, std::string> optionsMap) { reset(optionsMap); }
4343
void reset(std::unordered_map<std::string, std::string>);
44+
bool isInitialized() { return mInitialized; }
4445

4546
virtual ~OrtModel() = default;
4647

@@ -79,6 +80,7 @@ class OrtModel
7980
std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes;
8081

8182
// Environment settings
83+
bool mInitialized = false;
8284
std::string modelPath, device = "cpu", dtype = "float"; // device options should be cpu, rocm, migraphx, cuda
8385
int intraOpNumThreads = 0, deviceId = 0, enableProfiling = 0, loggingLevel = 0, allocateDeviceMemory = 0, enableOptimizations = 0;
8486

@@ -89,4 +91,4 @@ class OrtModel
8991

9092
} // namespace o2
9193

92-
#endif // O2_ML_ORTINTERFACE_H
94+
#endif // O2_ML_ORTINTERFACE_H

Common/ML/src/OrtInterface.cxx

Lines changed: 83 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,19 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
4444
if (!optionsMap.contains("model-path")) {
4545
LOG(fatal) << "(ORT) Model path cannot be empty!";
4646
}
47-
modelPath = optionsMap["model-path"];
48-
device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
49-
dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float");
50-
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
51-
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
52-
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
53-
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2);
54-
enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
55-
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
56-
57-
std::string dev_mem_str = "Hip";
47+
48+
if (!optionsMap["model-path"].empty()) {
49+
modelPath = optionsMap["model-path"];
50+
device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
51+
dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float");
52+
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
53+
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
54+
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
55+
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
56+
enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
57+
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
58+
59+
std::string dev_mem_str = "Hip";
5860
#if defined(ORT_ROCM_BUILD)
5961
#if ORT_ROCM_BUILD == 1
6062
if (device == "ROCM") {
@@ -81,89 +83,85 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
8183
#endif
8284
#endif
8385

84-
if (allocateDeviceMemory) {
85-
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
86-
LOG(info) << "(ORT) Memory info set to on-device memory";
87-
}
86+
if (allocateDeviceMemory) {
87+
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
88+
LOG(info) << "(ORT) Memory info set to on-device memory";
89+
}
8890

89-
if (device == "CPU") {
90-
(pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
91-
if (intraOpNumThreads > 1) {
92-
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
93-
} else if (intraOpNumThreads == 1) {
94-
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
91+
if (device == "CPU") {
92+
(pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
93+
if (intraOpNumThreads > 1) {
94+
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
95+
} else if (intraOpNumThreads == 1) {
96+
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
97+
}
98+
if (loggingLevel < 2) {
99+
LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " threads";
100+
}
95101
}
96-
LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " threads";
97-
}
98102

99-
(pImplOrt->sessionOptions).DisableMemPattern();
100-
(pImplOrt->sessionOptions).DisableCpuMemArena();
103+
(pImplOrt->sessionOptions).DisableMemPattern();
104+
(pImplOrt->sessionOptions).DisableCpuMemArena();
101105

102-
if (enableProfiling) {
103-
if (optionsMap.contains("profiling-output-path")) {
104-
(pImplOrt->sessionOptions).EnableProfiling((optionsMap["profiling-output-path"] + "/ORT_LOG_").c_str());
106+
if (enableProfiling) {
107+
if (optionsMap.contains("profiling-output-path")) {
108+
(pImplOrt->sessionOptions).EnableProfiling((optionsMap["profiling-output-path"] + "/ORT_LOG_").c_str());
109+
} else {
110+
LOG(warning) << "(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
111+
(pImplOrt->sessionOptions).DisableProfiling();
112+
}
105113
} else {
106-
LOG(warning) << "(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
107114
(pImplOrt->sessionOptions).DisableProfiling();
108115
}
109-
} else {
110-
(pImplOrt->sessionOptions).DisableProfiling();
111-
}
112-
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
113-
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
114-
115-
pImplOrt->env = std::make_shared<Ort::Env>(
116-
OrtLoggingLevel(loggingLevel),
117-
(optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
118-
// Integrate ORT logging into Fairlogger
119-
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
120-
if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
121-
LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
122-
} else if (severity == ORT_LOGGING_LEVEL_INFO) {
123-
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
124-
} else if (severity == ORT_LOGGING_LEVEL_WARNING) {
125-
LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
126-
} else if (severity == ORT_LOGGING_LEVEL_ERROR) {
127-
LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
128-
} else if (severity == ORT_LOGGING_LEVEL_FATAL) {
129-
LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
130-
} else {
131-
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
132-
}
133-
},
134-
(void*)3);
135-
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
136-
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
137116

138-
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
139-
mInputNames.push_back((pImplOrt->session)->GetInputNameAllocated(i, pImplOrt->allocator).get());
140-
}
141-
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
142-
mInputShapes.emplace_back((pImplOrt->session)->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
143-
}
144-
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
145-
mOutputNames.push_back((pImplOrt->session)->GetOutputNameAllocated(i, pImplOrt->allocator).get());
146-
}
147-
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
148-
mOutputShapes.emplace_back((pImplOrt->session)->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
149-
}
117+
mInitialized = true;
150118

151-
inputNamesChar.resize(mInputNames.size(), nullptr);
152-
std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(inputNamesChar),
153-
[&](const std::string& str) { return str.c_str(); });
154-
outputNamesChar.resize(mOutputNames.size(), nullptr);
155-
std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
156-
[&](const std::string& str) { return str.c_str(); });
157-
158-
// Print names
159-
LOG(info) << "\tInput Nodes:";
160-
for (size_t i = 0; i < mInputNames.size(); i++) {
161-
LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
162-
}
119+
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
120+
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
121+
122+
pImplOrt->env = std::make_shared<Ort::Env>(
123+
OrtLoggingLevel(loggingLevel),
124+
(optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
125+
// Integrate ORT logging into Fairlogger
126+
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
127+
if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
128+
LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
129+
} else if (severity == ORT_LOGGING_LEVEL_INFO) {
130+
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
131+
} else if (severity == ORT_LOGGING_LEVEL_WARNING) {
132+
LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
133+
} else if (severity == ORT_LOGGING_LEVEL_ERROR) {
134+
LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
135+
} else if (severity == ORT_LOGGING_LEVEL_FATAL) {
136+
LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
137+
} else {
138+
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
139+
}
140+
},
141+
(void*)3);
142+
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
143+
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
144+
145+
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
146+
mInputNames.push_back((pImplOrt->session)->GetInputNameAllocated(i, pImplOrt->allocator).get());
147+
}
148+
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
149+
mInputShapes.emplace_back((pImplOrt->session)->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
150+
}
151+
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
152+
mOutputNames.push_back((pImplOrt->session)->GetOutputNameAllocated(i, pImplOrt->allocator).get());
153+
}
154+
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
155+
mOutputShapes.emplace_back((pImplOrt->session)->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
156+
}
157+
158+
inputNamesChar.resize(mInputNames.size(), nullptr);
159+
std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(inputNamesChar),
160+
[&](const std::string& str) { return str.c_str(); });
161+
outputNamesChar.resize(mOutputNames.size(), nullptr);
162+
std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
163+
[&](const std::string& str) { return str.c_str(); });
163164

164-
LOG(info) << "\tOutput Nodes:";
165-
for (size_t i = 0; i < mOutputNames.size(); i++) {
166-
LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
167165
}
168166
}
169167

@@ -301,4 +299,4 @@ std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t,
301299

302300
} // namespace ml
303301

304-
} // namespace o2
302+
} // namespace o2

GPU/GPUTracking/CMakeLists.txt

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -159,37 +159,37 @@ set(HDRS_INSTALL
159159

160160
set(SRCS_NO_CINT ${SRCS_NO_CINT} display/GPUDisplayInterface.cxx)
161161
set(SRCS_NO_CINT
162-
${SRCS_NO_CINT}
163-
Global/GPUChainITS.cxx
164-
ITS/GPUITSFitter.cxx
165-
ITS/GPUITSFitterKernels.cxx
166-
dEdx/GPUdEdx.cxx
167-
TPCConvert/GPUTPCConvert.cxx
168-
TPCConvert/GPUTPCConvertKernel.cxx
169-
DataCompression/GPUTPCCompression.cxx
170-
DataCompression/GPUTPCCompressionTrackModel.cxx
171-
DataCompression/GPUTPCCompressionKernels.cxx
172-
DataCompression/GPUTPCDecompression.cxx
173-
DataCompression/GPUTPCDecompressionKernels.cxx
174-
DataCompression/TPCClusterDecompressor.cxx
175-
DataCompression/GPUTPCClusterStatistics.cxx
176-
TPCClusterFinder/GPUTPCClusterFinder.cxx
177-
TPCClusterFinder/ClusterAccumulator.cxx
178-
TPCClusterFinder/MCLabelAccumulator.cxx
179-
TPCClusterFinder/GPUTPCCFCheckPadBaseline.cxx
180-
TPCClusterFinder/GPUTPCCFStreamCompaction.cxx
181-
TPCClusterFinder/GPUTPCCFChargeMapFiller.cxx
182-
TPCClusterFinder/GPUTPCCFPeakFinder.cxx
183-
TPCClusterFinder/GPUTPCCFNoiseSuppression.cxx
184-
TPCClusterFinder/GPUTPCCFClusterizer.cxx
185-
TPCClusterFinder/GPUTPCNNClusterizer.cxx
186-
TPCClusterFinder/GPUTPCCFDeconvolution.cxx
187-
TPCClusterFinder/GPUTPCCFMCLabelFlattener.cxx
188-
TPCClusterFinder/GPUTPCCFDecodeZS.cxx
189-
TPCClusterFinder/GPUTPCCFGather.cxx
190-
Refit/GPUTrackingRefit.cxx
191-
Refit/GPUTrackingRefitKernel.cxx
192-
Merger/GPUTPCGMO2Output.cxx)
162+
${SRCS_NO_CINT}
163+
Global/GPUChainITS.cxx
164+
ITS/GPUITSFitter.cxx
165+
ITS/GPUITSFitterKernels.cxx
166+
dEdx/GPUdEdx.cxx
167+
TPCConvert/GPUTPCConvert.cxx
168+
TPCConvert/GPUTPCConvertKernel.cxx
169+
DataCompression/GPUTPCCompression.cxx
170+
DataCompression/GPUTPCCompressionTrackModel.cxx
171+
DataCompression/GPUTPCCompressionKernels.cxx
172+
DataCompression/GPUTPCDecompression.cxx
173+
DataCompression/GPUTPCDecompressionKernels.cxx
174+
DataCompression/TPCClusterDecompressor.cxx
175+
DataCompression/GPUTPCClusterStatistics.cxx
176+
TPCClusterFinder/GPUTPCClusterFinder.cxx
177+
TPCClusterFinder/ClusterAccumulator.cxx
178+
TPCClusterFinder/MCLabelAccumulator.cxx
179+
TPCClusterFinder/GPUTPCCFCheckPadBaseline.cxx
180+
TPCClusterFinder/GPUTPCCFStreamCompaction.cxx
181+
TPCClusterFinder/GPUTPCCFChargeMapFiller.cxx
182+
TPCClusterFinder/GPUTPCCFPeakFinder.cxx
183+
TPCClusterFinder/GPUTPCCFNoiseSuppression.cxx
184+
TPCClusterFinder/GPUTPCCFClusterizer.cxx
185+
TPCClusterFinder/GPUTPCNNClusterizer.cxx
186+
TPCClusterFinder/GPUTPCCFDeconvolution.cxx
187+
TPCClusterFinder/GPUTPCCFMCLabelFlattener.cxx
188+
TPCClusterFinder/GPUTPCCFDecodeZS.cxx
189+
TPCClusterFinder/GPUTPCCFGather.cxx
190+
Refit/GPUTrackingRefit.cxx
191+
Refit/GPUTrackingRefitKernel.cxx
192+
Merger/GPUTPCGMO2Output.cxx)
193193

194194
set(SRCS_DATATYPES
195195
${SRCS_DATATYPES}

0 commit comments

Comments
 (0)