@@ -35,19 +35,13 @@ struct OrtModel::OrtVariables { // The actual implementation is hidden in the .c
3535 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo(" Cpu" , OrtAllocatorType::OrtDeviceAllocator, 0 , OrtMemType::OrtMemTypeDefault);
3636};
3737
38- Ort::SessionOptions* OrtModel::updateSessionOptions ()
38+ Ort::SessionOptions& OrtModel::updateSessionOptions ()
3939{
40- return &( pImplOrt->sessionOptions ) ;
40+ return pImplOrt->sessionOptions ;
4141}
4242
43- Ort::MemoryInfo* OrtModel::updateMemoryInfo ( )
43+ void OrtModel::initOptions (std::unordered_map<std::string, std::string> optionsMap )
4444{
45- return &(pImplOrt->memoryInfo );
46- }
47-
48- void OrtModel::reset (std::unordered_map<std::string, std::string> optionsMap)
49- {
50-
5145 pImplOrt = new OrtVariables ();
5246
5347 // Load from options map
@@ -58,71 +52,57 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
5852 if (!optionsMap[" model-path" ].empty ()) {
5953 modelPath = optionsMap[" model-path" ];
6054 device = (optionsMap.contains (" device" ) ? optionsMap[" device" ] : " CPU" );
61- deviceId = (optionsMap.contains (" device-id" ) ? std::stoi (optionsMap[" device-id" ]) : 0 );
6255 allocateDeviceMemory = (optionsMap.contains (" allocate-device-memory" ) ? std::stoi (optionsMap[" allocate-device-memory" ]) : 0 );
6356 intraOpNumThreads = (optionsMap.contains (" intra-op-num-threads" ) ? std::stoi (optionsMap[" intra-op-num-threads" ]) : 0 );
6457 interOpNumThreads = (optionsMap.contains (" inter-op-num-threads" ) ? std::stoi (optionsMap[" inter-op-num-threads" ]) : 0 );
6558 loggingLevel = (optionsMap.contains (" logging-level" ) ? std::stoi (optionsMap[" logging-level" ]) : 0 );
6659 enableProfiling = (optionsMap.contains (" enable-profiling" ) ? std::stoi (optionsMap[" enable-profiling" ]) : 0 );
6760 enableOptimizations = (optionsMap.contains (" enable-optimizations" ) ? std::stoi (optionsMap[" enable-optimizations" ]) : 0 );
68-
69- // #if defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1
70- // if (device == "ROCM") {
71- // // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId));
72- // SetONNXGPUStream(pImplOrt->sessionOptions, deviceId);
73- // LOG(info) << "(ORT) ROCM execution provider set";
74- // }
75- // #endif
76- // #if defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1
77- // if (device == "MIGRAPHX") {
78- // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId));
79- // LOG(info) << "(ORT) MIGraphX execution provider set";
80- // }
81- // #endif
82- // #if defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1
83- // if (device == "CUDA") {
84- // // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId));
85- // SetONNXGPUStream(pImplOrt->sessionOptions, deviceId);
86- // LOG(info) << "(ORT) CUDA execution provider set";
87- // dev_mem_str = "Cuda";
88- // }
89- // #endif
90-
91- if (device == " CPU" ) {
92- (pImplOrt->sessionOptions ).SetIntraOpNumThreads (intraOpNumThreads);
93- (pImplOrt->sessionOptions ).SetInterOpNumThreads (interOpNumThreads);
94- if (intraOpNumThreads > 1 || interOpNumThreads > 1 ) {
95- (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_PARALLEL);
96- } else if (intraOpNumThreads == 1 ) {
97- (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_SEQUENTIAL);
98- }
99- if (loggingLevel < 2 ) {
100- LOG (info) << " (ORT) CPU execution provider set with " << intraOpNumThreads << " (intraOpNumThreads) and " << interOpNumThreads << " (interOpNumThreads) threads" ;
61+ envName = (optionsMap.contains (" onnx-environment-name" ) ? optionsMap[" onnx-environment-name" ] : " onnx_model_inference" );
62+
63+ if (device == " CPU" ) {
64+ (pImplOrt->sessionOptions ).SetIntraOpNumThreads (intraOpNumThreads);
65+ (pImplOrt->sessionOptions ).SetInterOpNumThreads (interOpNumThreads);
66+ if (intraOpNumThreads > 1 || interOpNumThreads > 1 ) {
67+ (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_PARALLEL);
68+ } else if (intraOpNumThreads == 1 ) {
69+ (pImplOrt->sessionOptions ).SetExecutionMode (ExecutionMode::ORT_SEQUENTIAL);
70+ }
71+ if (loggingLevel < 2 ) {
72+ LOG (info) << " (ORT) CPU execution provider set with " << intraOpNumThreads << " (intraOpNumThreads) and " << interOpNumThreads << " (interOpNumThreads) threads" ;
73+ }
10174 }
102- }
10375
104- (pImplOrt-> sessionOptions ). DisableMemPattern () ;
105- (pImplOrt->sessionOptions ).DisableCpuMemArena ( );
76+ // OrtROCMProviderOptions rocm_options{} ;
77+ // (pImplOrt->sessionOptions).AppendExecutionProvider_ROCM(rocm_options );
10678
107- if (enableProfiling) {
108- if (optionsMap.contains (" profiling-output-path" )) {
109- (pImplOrt->sessionOptions ).EnableProfiling ((optionsMap[" profiling-output-path" ] + " /ORT_LOG_" ).c_str ());
79+ (pImplOrt->sessionOptions ).DisableMemPattern ();
80+ (pImplOrt->sessionOptions ).DisableCpuMemArena ();
81+
82+ if (enableProfiling) {
83+ if (optionsMap.contains (" profiling-output-path" )) {
84+ (pImplOrt->sessionOptions ).EnableProfiling ((optionsMap[" profiling-output-path" ] + " /ORT_LOG_" ).c_str ());
85+ } else {
86+ LOG (warning) << " (ORT) If profiling is enabled, optionsMap[\" profiling-output-path\" ] should be set. Disabling profiling for now." ;
87+ (pImplOrt->sessionOptions ).DisableProfiling ();
88+ }
11089 } else {
111- LOG (warning) << " (ORT) If profiling is enabled, optionsMap[\" profiling-output-path\" ] should be set. Disabling profiling for now." ;
11290 (pImplOrt->sessionOptions ).DisableProfiling ();
11391 }
92+
93+ (pImplOrt->sessionOptions ).SetGraphOptimizationLevel (GraphOptimizationLevel (enableOptimizations));
94+ (pImplOrt->sessionOptions ).SetLogSeverityLevel (OrtLoggingLevel (loggingLevel));
11495 } else {
115- (pImplOrt-> sessionOptions ). DisableProfiling () ;
96+ LOG (fatal) << " (ORT) Model path cannot be empty! " ;
11697 }
98+ }
11799
100+ void OrtModel::initEnvironment ()
101+ {
118102 mInitialized = true ;
119-
120- (pImplOrt->sessionOptions ).SetGraphOptimizationLevel (GraphOptimizationLevel (enableOptimizations));
121- (pImplOrt->sessionOptions ).SetLogSeverityLevel (OrtLoggingLevel (loggingLevel));
122-
123103 pImplOrt->env = std::make_shared<Ort::Env>(
124104 OrtLoggingLevel (loggingLevel),
125- (optionsMap[ " onnx-environment-name " ] .empty () ? " onnx_model_inference " : optionsMap[ " onnx-environment-name " ] .c_str ()),
105+ (envName .empty () ? " ORT " : envName .c_str ()),
126106 // Integrate ORT logging into Fairlogger
127107 [](void * param, OrtLoggingLevel severity, const char * category, const char * logid, const char * code_location, const char * message) {
128108 if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
@@ -143,6 +123,10 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
143123 (pImplOrt->env )->DisableTelemetryEvents (); // Disable telemetry events
144124 pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env ), modelPath.c_str (), pImplOrt->sessionOptions );
145125
126+ setIO ();
127+ }
128+
129+ void OrtModel::setIO () {
146130 for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
147131 mInputNames .push_back ((pImplOrt->session )->GetInputNameAllocated (i, pImplOrt->allocator ).get ());
148132 }
@@ -162,7 +146,6 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
162146 outputNamesChar.resize (mOutputNames .size (), nullptr );
163147 std::transform (std::begin (mOutputNames ), std::end (mOutputNames ), std::begin (outputNamesChar),
164148 [&](const std::string& str) { return str.c_str (); });
165- }
166149 if (loggingLevel < 2 ) {
167150 LOG (info) << " (ORT) Model loaded successfully! (input: " << printShape (mInputShapes [0 ]) << " , output: " << printShape (mOutputShapes [0 ]) << " )" ;
168151 }
@@ -203,18 +186,15 @@ std::vector<O> OrtModel::inference(std::vector<I>& input, int32_t deviceIndex)
203186{
204187#if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
205188 if (allocateDeviceMemory) {
206- if (deviceIndex >= 0 ) {
207- deviceId = deviceIndex;
208- }
209189 std::string dev_mem_str = " " ;
210190 if (device == " ROCM" ) {
211191 dev_mem_str = " Hip" ;
212192 }
213193 if (device == " CUDA" ) {
214194 dev_mem_str = " Cuda" ;
215195 }
216- pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceId , OrtMemType::OrtMemTypeDefault);
217- LOG (info) << " (ORT) Memory info set to on-device memory" ;
196+ pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceIndex , OrtMemType::OrtMemTypeDefault);
197+ LOG (info) << " (ORT) Memory info set to on-device memory for device " << device << " with ID " << deviceIndex ;
218198 }
219199#endif
220200 std::vector<int64_t > inputShape{(int64_t )(input.size () / mInputShapes [0 ][1 ]), (int64_t )mInputShapes [0 ][1 ]};
@@ -241,20 +221,21 @@ template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Fl
241221template <class I , class O >
242222void OrtModel::inference (I* input, size_t input_size, O* output, int32_t deviceIndex)
243223{
224+ // std::vector<std::string> providers = Ort::GetAvailableProviders();
225+ // for (const auto& provider : providers) {
226+ // LOG(info) << "Available Execution Provider: " << provider;
227+ // }
244228#if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
245229 if (allocateDeviceMemory) {
246- if (deviceIndex >= 0 ) {
247- deviceId = deviceIndex;
248- }
249230 std::string dev_mem_str = " " ;
250231 if (device == " ROCM" ) {
251232 dev_mem_str = " Hip" ;
252233 }
253234 if (device == " CUDA" ) {
254235 dev_mem_str = " Cuda" ;
255236 }
256- pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceId , OrtMemType::OrtMemTypeDefault);
257- LOG (info) << " (ORT) Memory info set to on-device memory" ;
237+ pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceIndex , OrtMemType::OrtMemTypeDefault);
238+ LOG (info) << " (ORT) Memory info set to on-device memory for device " << device << " with ID " << deviceIndex ;
258239 }
259240#endif
260241 std::vector<int64_t > inputShape{input_size, (int64_t )mInputShapes [0 ][1 ]};
@@ -268,7 +249,7 @@ void OrtModel::inference(I* input, size_t input_size, O* output, int32_t deviceI
268249 std::vector<int64_t > outputShape{input_size, mOutputShapes [0 ][1 ]};
269250 Ort::Value outputTensor = Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo , output, input_size * mOutputShapes [0 ][1 ], outputShape.data (), outputShape.size ());
270251
271- (pImplOrt->session )->Run (pImplOrt->runOptions , inputNamesChar.data (), &inputTensor, 1 , outputNamesChar.data (), &outputTensor, outputNamesChar.size ()); // TODO: Not sure if 1 is always correct here
252+ (pImplOrt->session )->Run (pImplOrt->runOptions , inputNamesChar.data (), &inputTensor, 1 , outputNamesChar.data (), &outputTensor, outputNamesChar.size ());
272253}
273254
274255template void OrtModel::inference<OrtDataType::Float16_t, float >(OrtDataType::Float16_t*, size_t , float *, int32_t );
@@ -280,18 +261,15 @@ std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input, int32_t d
280261{
281262#if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
282263 if (allocateDeviceMemory) {
283- if (deviceIndex >= 0 ) {
284- deviceId = deviceIndex;
285- }
286264 std::string dev_mem_str = " " ;
287265 if (device == " ROCM" ) {
288266 dev_mem_str = " Hip" ;
289267 }
290268 if (device == " CUDA" ) {
291269 dev_mem_str = " Cuda" ;
292270 }
293- pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceId , OrtMemType::OrtMemTypeDefault);
294- LOG (info) << " (ORT) Memory info set to on-device memory" ;
271+ pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceIndex , OrtMemType::OrtMemTypeDefault);
272+ LOG (info) << " (ORT) Memory info set to on-device memory for device " << device << " with ID " << deviceIndex ;
295273 }
296274#endif
297275 std::vector<Ort::Value> inputTensor;
0 commit comments