Skip to content

Commit 051b0b3

Browse files
authored
ORT GPU implementation (#13755)
1 parent 922cad6 commit 051b0b3

File tree

3 files changed

+59
-21
lines changed

3 files changed

+59
-21
lines changed

Common/ML/CMakeLists.txt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,21 @@
99
# granted to it by virtue of its status as an Intergovernmental Organization
1010
# or submit itself to any jurisdiction.
1111

12+
# Pass ORT variables as a preprocessor definition
13+
if(DEFINED ENV{ORT_ROCM_BUILD})
14+
add_compile_definitions(ORT_ROCM_BUILD=$ENV{ORT_ROCM_BUILD})
15+
endif()
16+
if(DEFINED ENV{ORT_CUDA_BUILD})
17+
add_compile_definitions(ORT_CUDA_BUILD=$ENV{ORT_CUDA_BUILD})
18+
endif()
19+
if(DEFINED ENV{ORT_MIGRAPHX_BUILD})
20+
add_compile_definitions(ORT_MIGRAPHX_BUILD=$ENV{ORT_MIGRAPHX_BUILD})
21+
endif()
22+
if(DEFINED ENV{ORT_TENSORRT_BUILD})
23+
add_compile_definitions(ORT_TENSORRT_BUILD=$ENV{ORT_TENSORRT_BUILD})
24+
endif()
25+
1226
o2_add_library(ML
13-
SOURCES src/ort_interface.cxx
27+
SOURCES src/OrtInterface.cxx
1428
TARGETVARNAME targetName
1529
PRIVATE_LINK_LIBRARIES O2::Framework ONNXRuntime::ONNXRuntime)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
// granted to it by virtue of its status as an Intergovernmental Organization
1010
// or submit itself to any jurisdiction.
1111

12-
/// \file ort_interface.h
12+
/// \file OrtInterface.h
1313
/// \author Christian Sonnabend <christian.sonnabend@cern.ch>
1414
/// \brief A header library for loading ONNX models and inferencing them on CPU and GPU
1515

16-
#ifndef O2_ML_ONNX_INTERFACE_H
17-
#define O2_ML_ONNX_INTERFACE_H
16+
#ifndef O2_ML_ORTINTERFACE_H
17+
#define O2_ML_ORTINTERFACE_H
1818

1919
// C++ and system includes
2020
#include <vector>
@@ -89,4 +89,4 @@ class OrtModel
8989

9090
} // namespace o2
9191

92-
#endif // O2_ML_ORT_INTERFACE_H
92+
#endif // O2_ML_ORTINTERFACE_H
Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
// granted to it by virtue of its status as an Intergovernmental Organization
1010
// or submit itself to any jurisdiction.
1111

12-
/// \file ort_interface.cxx
12+
/// \file OrtInterface.cxx
1313
/// \author Christian Sonnabend <christian.sonnabend@cern.ch>
1414
/// \brief A header library for loading ONNX models and inferencing them on CPU and GPU
1515

16-
#include "ML/ort_interface.h"
16+
#include "ML/OrtInterface.h"
1717
#include "ML/3rdparty/GPUORTFloat16.h"
1818

1919
// ONNX includes
@@ -50,29 +50,35 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
5050
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
5151
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
5252
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
53-
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
53+
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2);
5454
enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
5555
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
5656

5757
std::string dev_mem_str = "Hip";
58-
#ifdef ORT_ROCM_BUILD
58+
#if defined(ORT_ROCM_BUILD)
59+
#if ORT_ROCM_BUILD == 1
5960
if (device == "ROCM") {
6061
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId));
6162
LOG(info) << "(ORT) ROCM execution provider set";
6263
}
6364
#endif
64-
#ifdef ORT_MIGRAPHX_BUILD
65+
#endif
66+
#if defined(ORT_MIGRAPHX_BUILD)
67+
#if ORT_MIGRAPHX_BUILD == 1
6568
if (device == "MIGRAPHX") {
6669
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId));
6770
LOG(info) << "(ORT) MIGraphX execution provider set";
6871
}
6972
#endif
70-
#ifdef ORT_CUDA_BUILD
73+
#endif
74+
#if defined(ORT_CUDA_BUILD)
75+
#if ORT_CUDA_BUILD == 1
7176
if (device == "CUDA") {
7277
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId));
7378
LOG(info) << "(ORT) CUDA execution provider set";
7479
dev_mem_str = "Cuda";
7580
}
81+
#endif
7682
#endif
7783

7884
if (allocateDeviceMemory) {
@@ -106,7 +112,27 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
106112
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
107113
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
108114

109-
pImplOrt->env = std::make_shared<Ort::Env>(OrtLoggingLevel(loggingLevel), (optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()));
115+
pImplOrt->env = std::make_shared<Ort::Env>(
116+
OrtLoggingLevel(loggingLevel),
117+
(optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
118+
// Integrate ORT logging into Fairlogger
119+
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
120+
if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
121+
LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
122+
} else if (severity == ORT_LOGGING_LEVEL_INFO) {
123+
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
124+
} else if (severity == ORT_LOGGING_LEVEL_WARNING) {
125+
LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
126+
} else if (severity == ORT_LOGGING_LEVEL_ERROR) {
127+
LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
128+
} else if (severity == ORT_LOGGING_LEVEL_FATAL) {
129+
LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
130+
} else {
131+
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
132+
}
133+
},
134+
(void*)3);
135+
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
110136
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
111137

112138
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
@@ -130,16 +156,14 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
130156
[&](const std::string& str) { return str.c_str(); });
131157

132158
// Print names
133-
if (loggingLevel > 1) {
134-
LOG(info) << "Input Nodes:";
135-
for (size_t i = 0; i < mInputNames.size(); i++) {
136-
LOG(info) << "\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
137-
}
159+
LOG(info) << "\tInput Nodes:";
160+
for (size_t i = 0; i < mInputNames.size(); i++) {
161+
LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
162+
}
138163

139-
LOG(info) << "Output Nodes:";
140-
for (size_t i = 0; i < mOutputNames.size(); i++) {
141-
LOG(info) << "\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
142-
}
164+
LOG(info) << "\tOutput Nodes:";
165+
for (size_t i = 0; i < mOutputNames.size(); i++) {
166+
LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
143167
}
144168
}
145169

0 commit comments

Comments
 (0)