@@ -44,7 +44,7 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
4444 if (!optionsMap.contains (" model-path" )) {
4545 LOG (fatal) << " (ORT) Model path cannot be empty!" ;
4646 }
47-
47+
4848 if (!optionsMap[" model-path" ].empty ()) {
4949 modelPath = optionsMap[" model-path" ];
5050 device = (optionsMap.contains (" device" ) ? optionsMap[" device" ] : " CPU" );
@@ -83,85 +83,84 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
8383#endif
8484#endif
8585
86- if (allocateDeviceMemory) {
87- pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
88- LOG (info) << " (ORT) Memory info set to on-device memory" ;
89- }
86+ if (allocateDeviceMemory) {
87+ pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
88+ LOG (info) << " (ORT) Memory info set to on-device memory" ;
89+ }
9090
91- if (device == " CPU" ) {
92- (pImplOrt->sessionOptions ).SetIntraOpNumThreads (intraOpNumThreads);
93- if (intraOpNumThreads > 1 ) {
94- (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_PARALLEL);
95- } else if (intraOpNumThreads == 1 ) {
96- (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_SEQUENTIAL);
97- }
98- if (loggingLevel < 2 ) {
99- LOG (info) << " (ORT) CPU execution provider set with " << intraOpNumThreads << " threads" ;
100- }
91+ if (device == " CPU" ) {
92+ (pImplOrt->sessionOptions ).SetIntraOpNumThreads (intraOpNumThreads);
93+ if (intraOpNumThreads > 1 ) {
94+ (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_PARALLEL);
95+ } else if (intraOpNumThreads == 1 ) {
96+ (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_SEQUENTIAL);
97+ }
98+ if (loggingLevel < 2 ) {
99+ LOG (info) << " (ORT) CPU execution provider set with " << intraOpNumThreads << " threads" ;
101100 }
101+ }
102102
103- (pImplOrt->sessionOptions ).DisableMemPattern ();
104- (pImplOrt->sessionOptions ).DisableCpuMemArena ();
103+ (pImplOrt->sessionOptions ).DisableMemPattern ();
104+ (pImplOrt->sessionOptions ).DisableCpuMemArena ();
105105
106- if (enableProfiling) {
107- if (optionsMap.contains (" profiling-output-path" )) {
108- (pImplOrt->sessionOptions ).EnableProfiling ((optionsMap[" profiling-output-path" ] + " /ORT_LOG_" ).c_str ());
109- } else {
110- LOG (warning) << " (ORT) If profiling is enabled, optionsMap[\" profiling-output-path\" ] should be set. Disabling profiling for now." ;
111- (pImplOrt->sessionOptions ).DisableProfiling ();
112- }
106+ if (enableProfiling) {
107+ if (optionsMap.contains (" profiling-output-path" )) {
108+ (pImplOrt->sessionOptions ).EnableProfiling ((optionsMap[" profiling-output-path" ] + " /ORT_LOG_" ).c_str ());
113109 } else {
110+ LOG (warning) << " (ORT) If profiling is enabled, optionsMap[\" profiling-output-path\" ] should be set. Disabling profiling for now." ;
114111 (pImplOrt->sessionOptions ).DisableProfiling ();
115112 }
113+ } else {
114+ (pImplOrt->sessionOptions ).DisableProfiling ();
115+ }
116116
117- mInitialized = true ;
118-
119- (pImplOrt->sessionOptions ).SetGraphOptimizationLevel (GraphOptimizationLevel (enableOptimizations));
120- (pImplOrt->sessionOptions ).SetLogSeverityLevel (OrtLoggingLevel (loggingLevel));
121-
122- pImplOrt->env = std::make_shared<Ort::Env>(
123- OrtLoggingLevel (loggingLevel),
124- (optionsMap[" onnx-environment-name" ].empty () ? " onnx_model_inference" : optionsMap[" onnx-environment-name" ].c_str ()),
125- // Integrate ORT logging into Fairlogger
126- [](void * param, OrtLoggingLevel severity, const char * category, const char * logid, const char * code_location, const char * message) {
127- if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
128- LOG (debug) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
129- } else if (severity == ORT_LOGGING_LEVEL_INFO) {
130- LOG (info) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
131- } else if (severity == ORT_LOGGING_LEVEL_WARNING) {
132- LOG (warning) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
133- } else if (severity == ORT_LOGGING_LEVEL_ERROR) {
134- LOG (error) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
135- } else if (severity == ORT_LOGGING_LEVEL_FATAL) {
136- LOG (fatal) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
137- } else {
138- LOG (info) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
139- }
140- },
141- (void *)3 );
142- (pImplOrt->env )->DisableTelemetryEvents (); // Disable telemetry events
143- pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env ), modelPath.c_str (), pImplOrt->sessionOptions );
144-
145- for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
146- mInputNames .push_back ((pImplOrt->session )->GetInputNameAllocated (i, pImplOrt->allocator ).get ());
147- }
148- for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
149- mInputShapes .emplace_back ((pImplOrt->session )->GetInputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
150- }
151- for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
152- mOutputNames .push_back ((pImplOrt->session )->GetOutputNameAllocated (i, pImplOrt->allocator ).get ());
153- }
154- for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
155- mOutputShapes .emplace_back ((pImplOrt->session )->GetOutputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
156- }
117+ mInitialized = true ;
118+
119+ (pImplOrt->sessionOptions ).SetGraphOptimizationLevel (GraphOptimizationLevel (enableOptimizations));
120+ (pImplOrt->sessionOptions ).SetLogSeverityLevel (OrtLoggingLevel (loggingLevel));
121+
122+ pImplOrt->env = std::make_shared<Ort::Env>(
123+ OrtLoggingLevel (loggingLevel),
124+ (optionsMap[" onnx-environment-name" ].empty () ? " onnx_model_inference" : optionsMap[" onnx-environment-name" ].c_str ()),
125+ // Integrate ORT logging into Fairlogger
126+ [](void * param, OrtLoggingLevel severity, const char * category, const char * logid, const char * code_location, const char * message) {
127+ if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
128+ LOG (debug) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
129+ } else if (severity == ORT_LOGGING_LEVEL_INFO) {
130+ LOG (info) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
131+ } else if (severity == ORT_LOGGING_LEVEL_WARNING) {
132+ LOG (warning) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
133+ } else if (severity == ORT_LOGGING_LEVEL_ERROR) {
134+ LOG (error) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
135+ } else if (severity == ORT_LOGGING_LEVEL_FATAL) {
136+ LOG (fatal) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
137+ } else {
138+ LOG (info) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
139+ }
140+ },
141+ (void *)3 );
142+ (pImplOrt->env )->DisableTelemetryEvents (); // Disable telemetry events
143+ pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env ), modelPath.c_str (), pImplOrt->sessionOptions );
157144
158- inputNamesChar.resize (mInputNames .size (), nullptr );
159- std::transform (std::begin (mInputNames ), std::end (mInputNames ), std::begin (inputNamesChar),
160- [&](const std::string& str) { return str.c_str (); });
161- outputNamesChar.resize (mOutputNames .size (), nullptr );
162- std::transform (std::begin (mOutputNames ), std::end (mOutputNames ), std::begin (outputNamesChar),
163- [&](const std::string& str) { return str.c_str (); });
145+ for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
146+ mInputNames .push_back ((pImplOrt->session )->GetInputNameAllocated (i, pImplOrt->allocator ).get ());
147+ }
148+ for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
149+ mInputShapes .emplace_back ((pImplOrt->session )->GetInputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
150+ }
151+ for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
152+ mOutputNames .push_back ((pImplOrt->session )->GetOutputNameAllocated (i, pImplOrt->allocator ).get ());
153+ }
154+ for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
155+ mOutputShapes .emplace_back ((pImplOrt->session )->GetOutputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
156+ }
164157
158+ inputNamesChar.resize (mInputNames .size (), nullptr );
159+ std::transform (std::begin (mInputNames ), std::end (mInputNames ), std::begin (inputNamesChar),
160+ [&](const std::string& str) { return str.c_str (); });
161+ outputNamesChar.resize (mOutputNames .size (), nullptr );
162+ std::transform (std::begin (mOutputNames ), std::end (mOutputNames ), std::begin (outputNamesChar),
163+ [&](const std::string& str) { return str.c_str (); });
165164 }
166165}
167166
0 commit comments