99// granted to it by virtue of its status as an Intergovernmental Organization
1010// or submit itself to any jurisdiction.
1111
12- // / \file ort_interface .cxx
12+ // / \file OrtInterface .cxx
1313// / \author Christian Sonnabend <christian.sonnabend@cern.ch>
1414// / \brief A header library for loading ONNX models and inferencing them on CPU and GPU
1515
16- #include " ML/ort_interface .h"
16+ #include " ML/OrtInterface .h"
1717#include " ML/3rdparty/GPUORTFloat16.h"
1818
1919// ONNX includes
@@ -50,29 +50,35 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
5050 deviceId = (optionsMap.contains (" device-id" ) ? std::stoi (optionsMap[" device-id" ]) : 0 );
5151 allocateDeviceMemory = (optionsMap.contains (" allocate-device-memory" ) ? std::stoi (optionsMap[" allocate-device-memory" ]) : 0 );
5252 intraOpNumThreads = (optionsMap.contains (" intra-op-num-threads" ) ? std::stoi (optionsMap[" intra-op-num-threads" ]) : 0 );
53- loggingLevel = (optionsMap.contains (" logging-level" ) ? std::stoi (optionsMap[" logging-level" ]) : 0 );
53+ loggingLevel = (optionsMap.contains (" logging-level" ) ? std::stoi (optionsMap[" logging-level" ]) : 2 );
5454 enableProfiling = (optionsMap.contains (" enable-profiling" ) ? std::stoi (optionsMap[" enable-profiling" ]) : 0 );
5555 enableOptimizations = (optionsMap.contains (" enable-optimizations" ) ? std::stoi (optionsMap[" enable-optimizations" ]) : 0 );
5656
5757 std::string dev_mem_str = " Hip" ;
58- #ifdef ORT_ROCM_BUILD
58+ #if defined(ORT_ROCM_BUILD)
59+ #if ORT_ROCM_BUILD == 1
5960 if (device == " ROCM" ) {
6061 Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_ROCM (pImplOrt->sessionOptions , deviceId));
6162 LOG (info) << " (ORT) ROCM execution provider set" ;
6263 }
6364#endif
64- #ifdef ORT_MIGRAPHX_BUILD
65+ #endif
66+ #if defined(ORT_MIGRAPHX_BUILD)
67+ #if ORT_MIGRAPHX_BUILD == 1
6568 if (device == " MIGRAPHX" ) {
6669 Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_MIGraphX (pImplOrt->sessionOptions , deviceId));
6770 LOG (info) << " (ORT) MIGraphX execution provider set" ;
6871 }
6972#endif
70- #ifdef ORT_CUDA_BUILD
73+ #endif
74+ #if defined(ORT_CUDA_BUILD)
75+ #if ORT_CUDA_BUILD == 1
7176 if (device == " CUDA" ) {
7277 Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA (pImplOrt->sessionOptions , deviceId));
7378 LOG (info) << " (ORT) CUDA execution provider set" ;
7479 dev_mem_str = " Cuda" ;
7580 }
81+ #endif
7682#endif
7783
7884 if (allocateDeviceMemory) {
@@ -106,7 +112,27 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
106112 (pImplOrt->sessionOptions ).SetGraphOptimizationLevel (GraphOptimizationLevel (enableOptimizations));
107113 (pImplOrt->sessionOptions ).SetLogSeverityLevel (OrtLoggingLevel (loggingLevel));
108114
109- pImplOrt->env = std::make_shared<Ort::Env>(OrtLoggingLevel (loggingLevel), (optionsMap[" onnx-environment-name" ].empty () ? " onnx_model_inference" : optionsMap[" onnx-environment-name" ].c_str ()));
115+ pImplOrt->env = std::make_shared<Ort::Env>(
116+ OrtLoggingLevel (loggingLevel),
117+ (optionsMap[" onnx-environment-name" ].empty () ? " onnx_model_inference" : optionsMap[" onnx-environment-name" ].c_str ()),
118+ // Integrate ORT logging into Fairlogger
119+ [](void * param, OrtLoggingLevel severity, const char * category, const char * logid, const char * code_location, const char * message) {
120+ if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
121+ LOG (debug) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
122+ } else if (severity == ORT_LOGGING_LEVEL_INFO) {
123+ LOG (info) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
124+ } else if (severity == ORT_LOGGING_LEVEL_WARNING) {
125+ LOG (warning) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
126+ } else if (severity == ORT_LOGGING_LEVEL_ERROR) {
127+ LOG (error) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
128+ } else if (severity == ORT_LOGGING_LEVEL_FATAL) {
129+ LOG (fatal) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
130+ } else {
131+ LOG (info) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
132+ }
133+ },
134+ (void *)3 );
135+ (pImplOrt->env )->DisableTelemetryEvents (); // Disable telemetry events
110136 pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env ), modelPath.c_str (), pImplOrt->sessionOptions );
111137
112138 for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
@@ -130,16 +156,14 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
130156 [&](const std::string& str) { return str.c_str (); });
131157
132158 // Print names
133- if (loggingLevel > 1 ) {
134- LOG (info) << " Input Nodes:" ;
135- for (size_t i = 0 ; i < mInputNames .size (); i++) {
136- LOG (info) << " \t " << mInputNames [i] << " : " << printShape (mInputShapes [i]);
137- }
159+ LOG (info) << " \t Input Nodes:" ;
160+ for (size_t i = 0 ; i < mInputNames .size (); i++) {
161+ LOG (info) << " \t\t " << mInputNames [i] << " : " << printShape (mInputShapes [i]);
162+ }
138163
139- LOG (info) << " Output Nodes:" ;
140- for (size_t i = 0 ; i < mOutputNames .size (); i++) {
141- LOG (info) << " \t " << mOutputNames [i] << " : " << printShape (mOutputShapes [i]);
142- }
164+ LOG (info) << " \t Output Nodes:" ;
165+ for (size_t i = 0 ; i < mOutputNames .size (); i++) {
166+ LOG (info) << " \t\t " << mOutputNames [i] << " : " << printShape (mOutputShapes [i]);
143167 }
144168}
145169
0 commit comments