@@ -44,17 +44,19 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
4444 if (!optionsMap.contains (" model-path" )) {
4545 LOG (fatal) << " (ORT) Model path cannot be empty!" ;
4646 }
47- modelPath = optionsMap[" model-path" ];
48- device = (optionsMap.contains (" device" ) ? optionsMap[" device" ] : " CPU" );
49- dtype = (optionsMap.contains (" dtype" ) ? optionsMap[" dtype" ] : " float" );
50- deviceId = (optionsMap.contains (" device-id" ) ? std::stoi (optionsMap[" device-id" ]) : 0 );
51- allocateDeviceMemory = (optionsMap.contains (" allocate-device-memory" ) ? std::stoi (optionsMap[" allocate-device-memory" ]) : 0 );
52- intraOpNumThreads = (optionsMap.contains (" intra-op-num-threads" ) ? std::stoi (optionsMap[" intra-op-num-threads" ]) : 0 );
53- loggingLevel = (optionsMap.contains (" logging-level" ) ? std::stoi (optionsMap[" logging-level" ]) : 2 );
54- enableProfiling = (optionsMap.contains (" enable-profiling" ) ? std::stoi (optionsMap[" enable-profiling" ]) : 0 );
55- enableOptimizations = (optionsMap.contains (" enable-optimizations" ) ? std::stoi (optionsMap[" enable-optimizations" ]) : 0 );
56-
57- std::string dev_mem_str = " Hip" ;
47+
48+ if (!optionsMap[" model-path" ].empty ()) {
49+ modelPath = optionsMap[" model-path" ];
50+ device = (optionsMap.contains (" device" ) ? optionsMap[" device" ] : " CPU" );
51+ dtype = (optionsMap.contains (" dtype" ) ? optionsMap[" dtype" ] : " float" );
52+ deviceId = (optionsMap.contains (" device-id" ) ? std::stoi (optionsMap[" device-id" ]) : 0 );
53+ allocateDeviceMemory = (optionsMap.contains (" allocate-device-memory" ) ? std::stoi (optionsMap[" allocate-device-memory" ]) : 0 );
54+ intraOpNumThreads = (optionsMap.contains (" intra-op-num-threads" ) ? std::stoi (optionsMap[" intra-op-num-threads" ]) : 0 );
55+ loggingLevel = (optionsMap.contains (" logging-level" ) ? std::stoi (optionsMap[" logging-level" ]) : 0 );
56+ enableProfiling = (optionsMap.contains (" enable-profiling" ) ? std::stoi (optionsMap[" enable-profiling" ]) : 0 );
57+ enableOptimizations = (optionsMap.contains (" enable-optimizations" ) ? std::stoi (optionsMap[" enable-optimizations" ]) : 0 );
58+
59+ std::string dev_mem_str = " Hip" ;
5860#if defined(ORT_ROCM_BUILD)
5961#if ORT_ROCM_BUILD == 1
6062 if (device == " ROCM" ) {
@@ -81,89 +83,85 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
8183#endif
8284#endif
8385
84- if (allocateDeviceMemory) {
85- pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
86- LOG (info) << " (ORT) Memory info set to on-device memory" ;
87- }
86+ if (allocateDeviceMemory) {
87+ pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
88+ LOG (info) << " (ORT) Memory info set to on-device memory" ;
89+ }
8890
89- if (device == " CPU" ) {
90- (pImplOrt->sessionOptions ).SetIntraOpNumThreads (intraOpNumThreads);
91- if (intraOpNumThreads > 1 ) {
92- (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_PARALLEL);
93- } else if (intraOpNumThreads == 1 ) {
94- (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_SEQUENTIAL);
91+ if (device == " CPU" ) {
92+ (pImplOrt->sessionOptions ).SetIntraOpNumThreads (intraOpNumThreads);
93+ if (intraOpNumThreads > 1 ) {
94+ (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_PARALLEL);
95+ } else if (intraOpNumThreads == 1 ) {
96+ (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_SEQUENTIAL);
97+ }
98+ if (loggingLevel < 2 ) {
99+ LOG (info) << " (ORT) CPU execution provider set with " << intraOpNumThreads << " threads" ;
100+ }
95101 }
96- LOG (info) << " (ORT) CPU execution provider set with " << intraOpNumThreads << " threads" ;
97- }
98102
99- (pImplOrt->sessionOptions ).DisableMemPattern ();
100- (pImplOrt->sessionOptions ).DisableCpuMemArena ();
103+ (pImplOrt->sessionOptions ).DisableMemPattern ();
104+ (pImplOrt->sessionOptions ).DisableCpuMemArena ();
101105
102- if (enableProfiling) {
103- if (optionsMap.contains (" profiling-output-path" )) {
104- (pImplOrt->sessionOptions ).EnableProfiling ((optionsMap[" profiling-output-path" ] + " /ORT_LOG_" ).c_str ());
106+ if (enableProfiling) {
107+ if (optionsMap.contains (" profiling-output-path" )) {
108+ (pImplOrt->sessionOptions ).EnableProfiling ((optionsMap[" profiling-output-path" ] + " /ORT_LOG_" ).c_str ());
109+ } else {
110+ LOG (warning) << " (ORT) If profiling is enabled, optionsMap[\" profiling-output-path\" ] should be set. Disabling profiling for now." ;
111+ (pImplOrt->sessionOptions ).DisableProfiling ();
112+ }
105113 } else {
106- LOG (warning) << " (ORT) If profiling is enabled, optionsMap[\" profiling-output-path\" ] should be set. Disabling profiling for now." ;
107114 (pImplOrt->sessionOptions ).DisableProfiling ();
108115 }
109- } else {
110- (pImplOrt->sessionOptions ).DisableProfiling ();
111- }
112- (pImplOrt->sessionOptions ).SetGraphOptimizationLevel (GraphOptimizationLevel (enableOptimizations));
113- (pImplOrt->sessionOptions ).SetLogSeverityLevel (OrtLoggingLevel (loggingLevel));
114-
115- pImplOrt->env = std::make_shared<Ort::Env>(
116- OrtLoggingLevel (loggingLevel),
117- (optionsMap[" onnx-environment-name" ].empty () ? " onnx_model_inference" : optionsMap[" onnx-environment-name" ].c_str ()),
118- // Integrate ORT logging into Fairlogger
119- [](void * param, OrtLoggingLevel severity, const char * category, const char * logid, const char * code_location, const char * message) {
120- if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
121- LOG (debug) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
122- } else if (severity == ORT_LOGGING_LEVEL_INFO) {
123- LOG (info) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
124- } else if (severity == ORT_LOGGING_LEVEL_WARNING) {
125- LOG (warning) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
126- } else if (severity == ORT_LOGGING_LEVEL_ERROR) {
127- LOG (error) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
128- } else if (severity == ORT_LOGGING_LEVEL_FATAL) {
129- LOG (fatal) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
130- } else {
131- LOG (info) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
132- }
133- },
134- (void *)3 );
135- (pImplOrt->env )->DisableTelemetryEvents (); // Disable telemetry events
136- pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env ), modelPath.c_str (), pImplOrt->sessionOptions );
137116
138- for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
139- mInputNames .push_back ((pImplOrt->session )->GetInputNameAllocated (i, pImplOrt->allocator ).get ());
140- }
141- for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
142- mInputShapes .emplace_back ((pImplOrt->session )->GetInputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
143- }
144- for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
145- mOutputNames .push_back ((pImplOrt->session )->GetOutputNameAllocated (i, pImplOrt->allocator ).get ());
146- }
147- for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
148- mOutputShapes .emplace_back ((pImplOrt->session )->GetOutputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
149- }
117+ mInitialized = true ;
150118
151- inputNamesChar.resize (mInputNames .size (), nullptr );
152- std::transform (std::begin (mInputNames ), std::end (mInputNames ), std::begin (inputNamesChar),
153- [&](const std::string& str) { return str.c_str (); });
154- outputNamesChar.resize (mOutputNames .size (), nullptr );
155- std::transform (std::begin (mOutputNames ), std::end (mOutputNames ), std::begin (outputNamesChar),
156- [&](const std::string& str) { return str.c_str (); });
157-
158- // Print names
159- LOG (info) << " \t Input Nodes:" ;
160- for (size_t i = 0 ; i < mInputNames .size (); i++) {
161- LOG (info) << " \t\t " << mInputNames [i] << " : " << printShape (mInputShapes [i]);
162- }
119+ (pImplOrt->sessionOptions ).SetGraphOptimizationLevel (GraphOptimizationLevel (enableOptimizations));
120+ (pImplOrt->sessionOptions ).SetLogSeverityLevel (OrtLoggingLevel (loggingLevel));
121+
122+ pImplOrt->env = std::make_shared<Ort::Env>(
123+ OrtLoggingLevel (loggingLevel),
124+ (optionsMap[" onnx-environment-name" ].empty () ? " onnx_model_inference" : optionsMap[" onnx-environment-name" ].c_str ()),
125+ // Integrate ORT logging into Fairlogger
126+ [](void * param, OrtLoggingLevel severity, const char * category, const char * logid, const char * code_location, const char * message) {
127+ if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
128+ LOG (debug) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
129+ } else if (severity == ORT_LOGGING_LEVEL_INFO) {
130+ LOG (info) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
131+ } else if (severity == ORT_LOGGING_LEVEL_WARNING) {
132+ LOG (warning) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
133+ } else if (severity == ORT_LOGGING_LEVEL_ERROR) {
134+ LOG (error) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
135+ } else if (severity == ORT_LOGGING_LEVEL_FATAL) {
136+ LOG (fatal) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
137+ } else {
138+ LOG (info) << " (ORT) [" << logid << " |" << category << " |" << code_location << " ]: " << message;
139+ }
140+ },
141+ (void *)3 );
142+ (pImplOrt->env )->DisableTelemetryEvents (); // Disable telemetry events
143+ pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env ), modelPath.c_str (), pImplOrt->sessionOptions );
144+
145+ for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
146+ mInputNames .push_back ((pImplOrt->session )->GetInputNameAllocated (i, pImplOrt->allocator ).get ());
147+ }
148+ for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
149+ mInputShapes .emplace_back ((pImplOrt->session )->GetInputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
150+ }
151+ for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
152+ mOutputNames .push_back ((pImplOrt->session )->GetOutputNameAllocated (i, pImplOrt->allocator ).get ());
153+ }
154+ for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
155+ mOutputShapes .emplace_back ((pImplOrt->session )->GetOutputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
156+ }
157+
158+ inputNamesChar.resize (mInputNames .size (), nullptr );
159+ std::transform (std::begin (mInputNames ), std::end (mInputNames ), std::begin (inputNamesChar),
160+ [&](const std::string& str) { return str.c_str (); });
161+ outputNamesChar.resize (mOutputNames .size (), nullptr );
162+ std::transform (std::begin (mOutputNames ), std::end (mOutputNames ), std::begin (outputNamesChar),
163+ [&](const std::string& str) { return str.c_str (); });
163164
164- LOG (info) << " \t Output Nodes:" ;
165- for (size_t i = 0 ; i < mOutputNames .size (); i++) {
166- LOG (info) << " \t\t " << mOutputNames [i] << " : " << printShape (mOutputShapes [i]);
167165 }
168166}
169167
@@ -301,4 +299,4 @@ std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t,
301299
302300} // namespace ml
303301
304- } // namespace o2
302+ } // namespace o2
0 commit comments