2121#define TOOLS_ML_MODEL_H_
2222
2323// C++ and system includes
24- #if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
25- #include < onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>
26- #else
2724#include < onnxruntime_cxx_api.h>
28- #endif
2925#include < vector>
3026#include < string>
3127#include < memory>
@@ -62,9 +58,6 @@ class OnnxModel
6258 // assert(input[0].GetTensorTypeAndShapeInfo().GetShape() == getNumInputNodes()); --> Fails build in debug mode, TODO: assertion should be checked somehow
6359
6460 try {
65- #if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
66- auto outputTensors = mSession ->Run (mInputNames , input, mOutputNames );
67- #else
6861 Ort::RunOptions runOptions;
6962 std::vector<const char *> inputNamesChar (mInputNames .size (), nullptr );
7063 std::transform (std::begin (mInputNames ), std::end (mInputNames ), std::begin (inputNamesChar),
@@ -74,7 +67,6 @@ class OnnxModel
7467 std::transform (std::begin (mOutputNames ), std::end (mOutputNames ), std::begin (outputNamesChar),
7568 [&](const std::string& str) { return str.c_str (); });
7669 auto outputTensors = mSession ->Run (runOptions, inputNamesChar.data (), input.data (), input.size (), outputNamesChar.data (), outputNamesChar.size ());
77- #endif
7870 LOG (debug) << " Number of output tensors: " << outputTensors.size ();
7971 if (outputTensors.size () != mOutputNames .size ()) {
8072 LOG (fatal) << " Number of output tensors: " << outputTensors.size () << " does not agree with the model specified size: " << mOutputNames .size ();
@@ -100,13 +92,9 @@ class OnnxModel
10092 assert (size % mInputShapes [0 ][1 ] == 0 );
10193 std::vector<int64_t > inputShape{size / mInputShapes [0 ][1 ], mInputShapes [0 ][1 ]};
10294 std::vector<Ort::Value> inputTensors;
103- #if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
104- inputTensors.emplace_back (Ort::Experimental::Value::CreateTensor<T>(input.data (), size, inputShape));
105- #else
10695 Ort::MemoryInfo memInfo =
10796 Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
10897 inputTensors.emplace_back (Ort::Value::CreateTensor<T>(memInfo, input.data (), size, inputShape.data (), inputShape.size ()));
109- #endif
11098 LOG (debug) << " Input shape calculated from vector: " << printShape (inputShape);
11199 return evalModel<T>(inputTensors);
112100 }
@@ -117,9 +105,7 @@ class OnnxModel
117105 {
118106 std::vector<Ort::Value> inputTensors;
119107
120- #if !__has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
121108 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
122- #endif
123109
124110 for (size_t iinput = 0 ; iinput < input.size (); iinput++) {
125111 [[maybe_unused]] int totalSize = 1 ;
@@ -134,36 +120,24 @@ class OnnxModel
134120 inputShape.push_back (mInputShapes [iinput][idim]);
135121 }
136122
137- #if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
138- inputTensors.emplace_back (Ort::Experimental::Value::CreateTensor<T>(input[iinput].data (), size, inputShape));
139- #else
140123 inputTensors.emplace_back (Ort::Value::CreateTensor<T>(memInfo, input[iinput].data (), size, inputShape.data (), inputShape.size ()));
141- #endif
142124 }
143125
144126 return evalModel<T>(inputTensors);
145127 }
146128
147129 // Reset session
148- #if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
149- void resetSession () { mSession .reset (new Ort::Experimental::Session{*mEnv , modelPath, sessionOptions}); }
150- #else
151130 void resetSession ()
152131 {
153132 mSession .reset (new Ort::Session{*mEnv , modelPath.c_str (), sessionOptions});
154133 }
155- #endif
156134
157135 // Getters & Setters
158136 Ort::SessionOptions* getSessionOptions () { return &sessionOptions; } // For optimizations in post
159- #if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
160- std::shared_ptr<Ort::Experimental::Session> getSession () { return mSession ; }
161- #else
162137 std::shared_ptr<Ort::Session> getSession ()
163138 {
164139 return mSession ;
165140 }
166- #endif
167141 int getNumInputNodes () const { return mInputShapes [0 ][1 ]; }
168142 std::vector<std::vector<int64_t >> getInputShapes () const { return mInputShapes ; }
169143 int getNumOutputNodes () const { return mOutputShapes [0 ][1 ]; }
@@ -174,11 +148,7 @@ class OnnxModel
174148 private:
175149 // Environment variables for the ONNX runtime
176150 std::shared_ptr<Ort::Env> mEnv = nullptr ;
177- #if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
178- std::shared_ptr<Ort::Experimental::Session> mSession = nullptr ;
179- #else
180151 std::shared_ptr<Ort::Session> mSession = nullptr ;
181- #endif
182152 Ort::SessionOptions sessionOptions;
183153
184154 // Input & Output specifications of the loaded network
0 commit comments