@@ -124,11 +124,11 @@ void OrtModel::initEnvironment()
124124 (pImplOrt->env )->DisableTelemetryEvents (); // Disable telemetry events
125125 pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env ), modelPath.c_str (), pImplOrt->sessionOptions );
126126
127+ setIO ();
128+
127129 if (loggingLevel < 2 ) {
128- LOG (info) << " (ORT) Model loaded successfully! (input : " << printShape (mInputShapes [ 0 ] ) << " , output : " << printShape (mOutputShapes [ 0 ] ) << " )" ;
130+ LOG (info) << " (ORT) Model loaded successfully! (inputs : " << printShape (mInputShapes , mInputNames ) << " , outputs : " << printShape (mOutputShapes , mOutputNames ) << " )" ;
129131 }
130-
131- setIO ();
132132}
133133
134134void OrtModel::memoryOnDevice (int32_t deviceIndex)
@@ -201,13 +201,45 @@ void OrtModel::setIO() {
201201 outputNamesChar.resize (mOutputNames .size (), nullptr );
202202 std::transform (std::begin (mOutputNames ), std::end (mOutputNames ), std::begin (outputNamesChar),
203203 [&](const std::string& str) { return str.c_str (); });
204+
205+ inputShapesCopy = mInputShapes ;
206+ outputShapesCopy = mOutputShapes ;
207+ inputSizePerNode.resize (mInputShapes .size (), 1 );
208+ outputSizePerNode.resize (mOutputShapes .size (), 1 );
209+ mInputsTotal = 1 ;
210+ for (size_t i = 0 ; i < mInputShapes .size (); ++i) {
211+ if (mInputShapes [i].size () > 0 ) {
212+ for (size_t j = 1 ; j < mInputShapes [i].size (); ++j) {
213+ if (mInputShapes [i][j] > 0 ) {
214+ mInputsTotal *= mInputShapes [i][j];
215+ inputSizePerNode[i] *= mInputShapes [i][j];
216+ }
217+ }
218+ }
219+ }
220+ mOutputsTotal = 1 ;
221+ for (size_t i = 0 ; i < mOutputShapes .size (); ++i) {
222+ if (mOutputShapes [i].size () > 0 ) {
223+ for (size_t j = 1 ; j < mOutputShapes [i].size (); ++j) {
224+ if (mOutputShapes [i][j] > 0 ) {
225+ mOutputsTotal *= mOutputShapes [i][j];
226+ outputSizePerNode[i] *= mOutputShapes [i][j];
227+ }
228+ }
229+ }
230+ }
204231}
205232
206233// Inference
207234template <class I , class O >
208235std::vector<O> OrtModel::inference (std::vector<I>& input)
209236{
210- std::vector<int64_t > inputShape{(int64_t )(input.size () / mInputShapes [0 ][1 ]), (int64_t )mInputShapes [0 ][1 ]};
237+ std::vector<int64_t > inputShape = mInputShapes [0 ];
238+ inputShape[0 ] = input.size ();
239+ for (size_t i = 1 ; i < mInputShapes [0 ].size (); ++i)
240+ {
241+ inputShape[0 ] /= mInputShapes [0 ][i];
242+ }
211243 std::vector<Ort::Value> inputTensor;
212244 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
213245 inputTensor.emplace_back (Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo , reinterpret_cast <Ort::Float16_t*>(input.data ()), input.size (), inputShape.data (), inputShape.size ()));
@@ -223,9 +255,7 @@ std::vector<O> OrtModel::inference(std::vector<I>& input)
223255}
224256
// Explicit instantiations of the single-input vector overload.
template std::vector<float> OrtModel::inference<float, float>(std::vector<float>&);
template std::vector<float> OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&);
template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);
230260
231261template <class I , class O >
@@ -255,33 +285,119 @@ void OrtModel::inference(I* input, size_t input_size, O* output)
255285}
256286
// Explicit instantiations of the raw-pointer (single input node) overload.
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t*, size_t, OrtDataType::Float16_t*);
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, size_t, float*);
template void OrtModel::inference<float, OrtDataType::Float16_t>(float*, size_t, OrtDataType::Float16_t*);
template void OrtModel::inference<float, float>(float*, size_t, float*);
264291
265292template <class I , class O >
266- std::vector<O> OrtModel::inference (std::vector<std::vector<I>>& input)
267- {
268- std::vector<Ort::Value> inputTensor;
269- for (auto i : input) {
270- std::vector<int64_t > inputShape{(int64_t )(i.size () / mInputShapes [0 ][1 ]), (int64_t )mInputShapes [0 ][1 ]};
293+ void OrtModel::inference (I** input, size_t input_size, O* output) {
294+ std::vector<Ort::Value> inputTensors (inputShapesCopy.size ());
295+
296+ for (size_t i = 0 ; i < inputShapesCopy.size (); ++i) {
297+
298+ inputShapesCopy[i][0 ] = input_size; // batch-size
299+ outputShapesCopy[i][0 ] = input_size; // batch-size
300+
271301 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
272- inputTensor.emplace_back (Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo , reinterpret_cast <Ort::Float16_t*>(i.data ()), i.size (), inputShape.data (), inputShape.size ()));
302+ inputTensors[i] = Ort::Value::CreateTensor<Ort::Float16_t>(
303+ pImplOrt->memoryInfo ,
304+ reinterpret_cast <Ort::Float16_t*>(input[i]),
305+ inputSizePerNode[i]*input_size,
306+ inputShapesCopy[i].data (),
307+ inputShapesCopy[i].size ());
273308 } else {
274- inputTensor.emplace_back (Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo , i.data (), i.size (), inputShape.data (), inputShape.size ()));
309+ inputTensors[i] = Ort::Value::CreateTensor<I>(
310+ pImplOrt->memoryInfo ,
311+ input[i],
312+ inputSizePerNode[i]*input_size,
313+ inputShapesCopy[i].data (),
314+ inputShapesCopy[i].size ());
275315 }
276316 }
277- // input.clear();
278- auto outputTensors = (pImplOrt->session )->Run (pImplOrt->runOptions , inputNamesChar.data (), inputTensor.data (), inputTensor.size (), outputNamesChar.data (), outputNamesChar.size ());
279- O* outputValues = reinterpret_cast <O*>(outputTensors[0 ].template GetTensorMutableData <O>());
280- std::vector<O> outputValuesVec{outputValues, outputValues + inputTensor.size () / mInputShapes [0 ][1 ] * mOutputShapes [0 ][1 ]};
281- outputTensors.clear ();
282- return outputValuesVec;
317+
318+ Ort::Value outputTensor = Ort::Value (nullptr );
319+ if constexpr (std::is_same_v<O, OrtDataType::Float16_t>) {
320+ outputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(
321+ pImplOrt->memoryInfo ,
322+ reinterpret_cast <Ort::Float16_t*>(output),
323+ outputSizePerNode[0 ]*input_size, // assumes that there is only one output node
324+ outputShapesCopy[0 ].data (),
325+ outputShapesCopy[0 ].size ());
326+ } else {
327+ outputTensor = Ort::Value::CreateTensor<O>(
328+ pImplOrt->memoryInfo ,
329+ output,
330+ outputSizePerNode[0 ]*input_size, // assumes that there is only one output node
331+ outputShapesCopy[0 ].data (),
332+ outputShapesCopy[0 ].size ());
333+ }
334+
335+ // === Run inference ===
336+ pImplOrt->session ->Run (
337+ pImplOrt->runOptions ,
338+ inputNamesChar.data (),
339+ inputTensors.data (),
340+ inputNamesChar.size (),
341+ outputNamesChar.data (),
342+ &outputTensor,
343+ outputNamesChar.size ()
344+ );
345+ }
346+
// Explicit instantiations of the multi-input raw-pointer overload.
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t**, size_t, OrtDataType::Float16_t*);
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, size_t, float*);
template void OrtModel::inference<float, OrtDataType::Float16_t>(float**, size_t, OrtDataType::Float16_t*);
template void OrtModel::inference<float, float>(float**, size_t, float*);
351+
352+ template <class I , class O >
353+ std::vector<O> OrtModel::inference (std::vector<std::vector<I>>& inputs)
354+ {
355+ std::vector<Ort::Value> input_tensors;
356+
357+ for (size_t i = 0 ; i < inputs.size (); ++i) {
358+
359+ inputShapesCopy[i][0 ] = inputs[i].size () / inputSizePerNode[i]; // batch-size
360+
361+ if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
362+ input_tensors.emplace_back (
363+ Ort::Value::CreateTensor<Ort::Float16_t>(
364+ pImplOrt->memoryInfo ,
365+ reinterpret_cast <Ort::Float16_t*>(inputs[i].data ()),
366+ inputSizePerNode[i]*inputShapesCopy[i][0 ],
367+ inputShapesCopy[i].data (),
368+ inputShapesCopy[i].size ()));
369+ } else {
370+ input_tensors.emplace_back (
371+ Ort::Value::CreateTensor<I>(
372+ pImplOrt->memoryInfo ,
373+ inputs[i].data (),
374+ inputSizePerNode[i]*inputShapesCopy[i][0 ],
375+ inputShapesCopy[i].data (),
376+ inputShapesCopy[i].size ()));
377+ }
378+ }
379+
380+ int32_t totalOutputSize = mOutputsTotal *inputShapesCopy[0 ][0 ];
381+
382+ // === Run inference ===
383+ auto output_tensors = pImplOrt->session ->Run (
384+ pImplOrt->runOptions ,
385+ inputNamesChar.data (),
386+ input_tensors.data (),
387+ input_tensors.size (),
388+ outputNamesChar.data (),
389+ outputNamesChar.size ());
390+
391+ // === Extract output values ===
392+ O* output_data = output_tensors[0 ].template GetTensorMutableData <O>();
393+ std::vector<O> output_vec (output_data, output_data + totalOutputSize);
394+ output_tensors.clear ();
395+ return output_vec;
283396}
284397
// Explicit instantiations of the multi-input vector overload.
// NOTE(review): unlike the pointer overloads above, the mixed-precision
// combinations (<Float16_t, float> and <float, Float16_t>) are not
// instantiated here — confirm no caller needs them.
template std::vector<float> OrtModel::inference<float, float>(std::vector<std::vector<float>>&);
template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>&);
400+
285401// private
286402std::string OrtModel::printShape (const std::vector<int64_t >& v)
287403{
@@ -293,6 +409,19 @@ std::string OrtModel::printShape(const std::vector<int64_t>& v)
293409 return ss.str ();
294410}
295411
412+ std::string OrtModel::printShape (const std::vector<std::vector<int64_t >>& v, std::vector<std::string>& n)
413+ {
414+ std::stringstream ss (" " );
415+ for (size_t i = 0 ; i < v.size (); i++) {
416+ ss << n[i] << " -> (" ;
417+ for (size_t j = 0 ; j < v[i].size () - 1 ; j++) {
418+ ss << v[i][j] << " x" ;
419+ }
420+ ss << v[i][v[i].size () - 1 ] << " ); " ;
421+ }
422+ return ss.str ();
423+ }
424+
296425} // namespace ml
297426
298427} // namespace o2
0 commit comments