@@ -35,16 +35,7 @@ struct OrtModel::OrtVariables { // The actual implementation is hidden in the .c
3535 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo(" Cpu" , OrtAllocatorType::OrtDeviceAllocator, 0 , OrtMemType::OrtMemTypeDefault);
3636};
3737
38- Ort::SessionOptions& OrtModel::updateSessionOptions ()
39- {
40- return pImplOrt->sessionOptions ;
41- }
42-
43- Ort::MemoryInfo& OrtModel::updateMemoryInfo ()
44- {
45- return pImplOrt->memoryInfo ;
46- }
47-
38+ // General purpose
4839void OrtModel::initOptions (std::unordered_map<std::string, std::string> optionsMap)
4940{
5041 pImplOrt = new OrtVariables ();
@@ -56,7 +47,8 @@ void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsM
5647
5748 if (!optionsMap[" model-path" ].empty ()) {
5849 modelPath = optionsMap[" model-path" ];
59- device = (optionsMap.contains (" device" ) ? optionsMap[" device" ] : " CPU" );
50+ deviceType = (optionsMap.contains (" device-type" ) ? optionsMap[" device-type" ] : " CPU" );
51+ deviceId = (optionsMap.contains (" device-id" ) ? std::stoi (optionsMap[" device-id" ]) : -1 );
6052 allocateDeviceMemory = (optionsMap.contains (" allocate-device-memory" ) ? std::stoi (optionsMap[" allocate-device-memory" ]) : 0 );
6153 intraOpNumThreads = (optionsMap.contains (" intra-op-num-threads" ) ? std::stoi (optionsMap[" intra-op-num-threads" ]) : 0 );
6254 interOpNumThreads = (optionsMap.contains (" inter-op-num-threads" ) ? std::stoi (optionsMap[" inter-op-num-threads" ]) : 0 );
@@ -65,7 +57,7 @@ void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsM
6557 enableOptimizations = (optionsMap.contains (" enable-optimizations" ) ? std::stoi (optionsMap[" enable-optimizations" ]) : 0 );
6658 envName = (optionsMap.contains (" onnx-environment-name" ) ? optionsMap[" onnx-environment-name" ] : " onnx_model_inference" );
6759
68- if (device == " CPU" ) {
60+ if (deviceType == " CPU" ) {
6961 (pImplOrt->sessionOptions ).SetIntraOpNumThreads (intraOpNumThreads);
7062 (pImplOrt->sessionOptions ).SetInterOpNumThreads (interOpNumThreads);
7163 if (intraOpNumThreads > 1 || interOpNumThreads > 1 ) {
@@ -97,14 +89,18 @@ void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsM
9789
9890 (pImplOrt->sessionOptions ).SetGraphOptimizationLevel (GraphOptimizationLevel (enableOptimizations));
9991 (pImplOrt->sessionOptions ).SetLogSeverityLevel (OrtLoggingLevel (loggingLevel));
92+
93+ mInitialized = true ;
10094 } else {
10195 LOG (fatal) << " (ORT) Model path cannot be empty!" ;
10296 }
10397}
10498
10599void OrtModel::initEnvironment ()
106100{
107- mInitialized = true ;
101+ if (allocateDeviceMemory) {
102+ memoryOnDevice (deviceId);
103+ }
108104 pImplOrt->env = std::make_shared<Ort::Env>(
109105 OrtLoggingLevel (loggingLevel),
110106 (envName.empty () ? " ORT" : envName.c_str ()),
@@ -128,39 +124,48 @@ void OrtModel::initEnvironment()
128124 (pImplOrt->env )->DisableTelemetryEvents (); // Disable telemetry events
129125 pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env ), modelPath.c_str (), pImplOrt->sessionOptions );
130126
127+ if (loggingLevel < 2 ) {
128+ LOG (info) << " (ORT) Model loaded successfully! (input: " << printShape (mInputShapes [0 ]) << " , output: " << printShape (mOutputShapes [0 ]) << " )" ;
129+ }
130+
131131 setIO ();
132132}
133133
134- void OrtModel::setIO () {
135- for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
136- mInputNames .push_back ((pImplOrt->session )->GetInputNameAllocated (i, pImplOrt->allocator ).get ());
137- }
138- for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
139- mInputShapes .emplace_back ((pImplOrt->session )->GetInputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
140- }
141- for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
142- mOutputNames .push_back ((pImplOrt->session )->GetOutputNameAllocated (i, pImplOrt->allocator ).get ());
143- }
144- for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
145- mOutputShapes .emplace_back ((pImplOrt->session )->GetOutputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
146- }
147-
148- inputNamesChar.resize (mInputNames .size (), nullptr );
149- std::transform (std::begin (mInputNames ), std::end (mInputNames ), std::begin (inputNamesChar),
150- [&](const std::string& str) { return str.c_str (); });
151- outputNamesChar.resize (mOutputNames .size (), nullptr );
152- std::transform (std::begin (mOutputNames ), std::end (mOutputNames ), std::begin (outputNamesChar),
153- [&](const std::string& str) { return str.c_str (); });
154- if (loggingLevel < 2 ) {
155- LOG (info) << " (ORT) Model loaded successfully! (input: " << printShape (mInputShapes [0 ]) << " , output: " << printShape (mOutputShapes [0 ]) << " )" ;
134+ void OrtModel::memoryOnDevice (int32_t deviceIndex)
135+ {
136+ #if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
137+ if (deviceIndex >= 0 ) {
138+ std::string dev_mem_str = " " ;
139+ if (deviceType == " ROCM" ) {
140+ dev_mem_str = " Hip" ;
141+ }
142+ if (deviceType == " CUDA" ) {
143+ dev_mem_str = " Cuda" ;
144+ }
145+ pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
146+ if (loggingLevel < 2 ) {
147+ LOG (info) << " (ORT) Memory info set to on-device memory for device type " << deviceType << " with ID " << deviceIndex;
148+ }
156149 }
150+ #endif
157151}
158152
159153void OrtModel::resetSession ()
160154{
161155 pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env ), modelPath.c_str (), pImplOrt->sessionOptions );
162156}
163157
158+ // Getters
159+ Ort::SessionOptions& OrtModel::getSessionOptions ()
160+ {
161+ return pImplOrt->sessionOptions ;
162+ }
163+
164+ Ort::MemoryInfo& OrtModel::getMemoryInfo ()
165+ {
166+ return pImplOrt->memoryInfo ;
167+ }
168+
164169template <class I , class O >
165170std::vector<O> OrtModel::v2v (std::vector<I>& input, bool clearInput)
166171{
@@ -176,32 +181,32 @@ std::vector<O> OrtModel::v2v(std::vector<I>& input, bool clearInput)
176181 }
177182}
178183
179- std::string OrtModel::printShape (const std::vector<int64_t >& v)
180- {
181- std::stringstream ss (" " );
182- for (size_t i = 0 ; i < v.size () - 1 ; i++) {
183- ss << v[i] << " x" ;
184+ void OrtModel::setIO () {
185+ for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
186+ mInputNames .push_back ((pImplOrt->session )->GetInputNameAllocated (i, pImplOrt->allocator ).get ());
184187 }
185- ss << v[v.size () - 1 ];
186- return ss.str ();
188+ for (size_t i = 0 ; i < (pImplOrt->session )->GetInputCount (); ++i) {
189+ mInputShapes .emplace_back ((pImplOrt->session )->GetInputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
190+ }
191+ for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
192+ mOutputNames .push_back ((pImplOrt->session )->GetOutputNameAllocated (i, pImplOrt->allocator ).get ());
193+ }
194+ for (size_t i = 0 ; i < (pImplOrt->session )->GetOutputCount (); ++i) {
195+ mOutputShapes .emplace_back ((pImplOrt->session )->GetOutputTypeInfo (i).GetTensorTypeAndShapeInfo ().GetShape ());
196+ }
197+
198+ inputNamesChar.resize (mInputNames .size (), nullptr );
199+ std::transform (std::begin (mInputNames ), std::end (mInputNames ), std::begin (inputNamesChar),
200+ [&](const std::string& str) { return str.c_str (); });
201+ outputNamesChar.resize (mOutputNames .size (), nullptr );
202+ std::transform (std::begin (mOutputNames ), std::end (mOutputNames ), std::begin (outputNamesChar),
203+ [&](const std::string& str) { return str.c_str (); });
187204}
188205
206+ // Inference
189207template <class I , class O >
190- std::vector<O> OrtModel::inference (std::vector<I>& input, int32_t deviceIndex )
208+ std::vector<O> OrtModel::inference (std::vector<I>& input)
191209{
192- #if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
193- if (allocateDeviceMemory) {
194- std::string dev_mem_str = " " ;
195- if (device == " ROCM" ) {
196- dev_mem_str = " Hip" ;
197- }
198- if (device == " CUDA" ) {
199- dev_mem_str = " Cuda" ;
200- }
201- pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
202- LOG (info) << " (ORT) Memory info set to on-device memory for device " << device << " with ID " << deviceIndex;
203- }
204- #endif
205210 std::vector<int64_t > inputShape{(int64_t )(input.size () / mInputShapes [0 ][1 ]), (int64_t )mInputShapes [0 ][1 ]};
206211 std::vector<Ort::Value> inputTensor;
207212 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
@@ -217,32 +222,19 @@ std::vector<O> OrtModel::inference(std::vector<I>& input, int32_t deviceIndex)
217222 return outputValuesVec;
218223}
219224
220- template std::vector<float > OrtModel::inference<float , float >(std::vector<float >&, int32_t );
225+ template std::vector<float > OrtModel::inference<float , float >(std::vector<float >&);
221226
222- template std::vector<float > OrtModel::inference<OrtDataType::Float16_t, float >(std::vector<OrtDataType::Float16_t>&, int32_t );
227+ template std::vector<float > OrtModel::inference<OrtDataType::Float16_t, float >(std::vector<OrtDataType::Float16_t>&);
223228
224- template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&, int32_t );
229+ template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);
225230
226231template <class I , class O >
227- void OrtModel::inference (I* input, size_t input_size, O* output, int32_t deviceIndex )
232+ void OrtModel::inference (I* input, size_t input_size, O* output)
228233{
229234 // std::vector<std::string> providers = Ort::GetAvailableProviders();
230235 // for (const auto& provider : providers) {
231236 // LOG(info) << "Available Execution Provider: " << provider;
232237 // }
233- #if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
234- if (allocateDeviceMemory) {
235- std::string dev_mem_str = " " ;
236- if (device == " ROCM" ) {
237- dev_mem_str = " Hip" ;
238- }
239- if (device == " CUDA" ) {
240- dev_mem_str = " Cuda" ;
241- }
242- pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
243- LOG (info) << " (ORT) Memory info set to on-device memory for device " << device << " with ID " << deviceIndex;
244- }
245- #endif
246238 std::vector<int64_t > inputShape{input_size, (int64_t )mInputShapes [0 ][1 ]};
247239 Ort::Value inputTensor = Ort::Value (nullptr );
248240 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
@@ -257,26 +249,13 @@ void OrtModel::inference(I* input, size_t input_size, O* output, int32_t deviceI
257249 (pImplOrt->session )->Run (pImplOrt->runOptions , inputNamesChar.data (), &inputTensor, 1 , outputNamesChar.data (), &outputTensor, outputNamesChar.size ());
258250}
259251
260- template void OrtModel::inference<OrtDataType::Float16_t, float >(OrtDataType::Float16_t*, size_t , float *, int32_t );
252+ template void OrtModel::inference<OrtDataType::Float16_t, float >(OrtDataType::Float16_t*, size_t , float *);
261253
262- template void OrtModel::inference<float , float >(float *, size_t , float *, int32_t );
254+ template void OrtModel::inference<float , float >(float *, size_t , float *);
263255
264256template <class I , class O >
265- std::vector<O> OrtModel::inference (std::vector<std::vector<I>>& input, int32_t deviceIndex )
257+ std::vector<O> OrtModel::inference (std::vector<std::vector<I>>& input)
266258{
267- #if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
268- if (allocateDeviceMemory) {
269- std::string dev_mem_str = " " ;
270- if (device == " ROCM" ) {
271- dev_mem_str = " Hip" ;
272- }
273- if (device == " CUDA" ) {
274- dev_mem_str = " Cuda" ;
275- }
276- pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
277- LOG (info) << " (ORT) Memory info set to on-device memory for device " << device << " with ID " << deviceIndex;
278- }
279- #endif
280259 std::vector<Ort::Value> inputTensor;
281260 for (auto i : input) {
282261 std::vector<int64_t > inputShape{(int64_t )(i.size () / mInputShapes [0 ][1 ]), (int64_t )mInputShapes [0 ][1 ]};
@@ -294,6 +273,17 @@ std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input, int32_t d
294273 return outputValuesVec;
295274}
296275
276+ // private
277+ std::string OrtModel::printShape (const std::vector<int64_t >& v)
278+ {
279+ std::stringstream ss (" " );
280+ for (size_t i = 0 ; i < v.size () - 1 ; i++) {
281+ ss << v[i] << " x" ;
282+ }
283+ ss << v[v.size () - 1 ];
284+ return ss.str ();
285+ }
286+
297287} // namespace ml
298288
299289} // namespace o2
0 commit comments