
Commit 710993a

Adding I** inference, potentially needed for CNN + FC inference

1 parent b437e38 commit 710993a

File tree

2 files changed: +158 −23 lines

Common/ML/include/ML/OrtInterface.h

Lines changed: 7 additions & 1 deletion
@@ -88,6 +88,9 @@ class OrtModel
   template <class I, class O>
   void inference(I*, size_t, O*);
 
+  template <class I, class O>
+  void inference(I**, size_t, O*);
+
  private:
   // ORT variables -> need to be hidden as pImpl
   struct OrtVariables;
@@ -96,14 +99,17 @@ class OrtModel
   // Input & Output specifications of the loaded network
   std::vector<const char*> inputNamesChar, outputNamesChar;
   std::vector<std::string> mInputNames, mOutputNames;
-  std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes;
+  std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes, inputShapesCopy, outputShapesCopy; // shapes per input / output node
+  std::vector<int64_t> inputSizePerNode, outputSizePerNode; // flattened per-sample element count per node
+  int32_t mInputsTotal = 0, mOutputsTotal = 0; // total number of input / output elements per sample
 
   // Environment settings
   bool mInitialized = false;
   std::string modelPath, envName = "", deviceType = "CPU", thread_affinity = ""; // device options should be cpu, rocm, migraphx, cuda
   int32_t intraOpNumThreads = 1, interOpNumThreads = 1, deviceId = -1, enableProfiling = 0, loggingLevel = 0, allocateDeviceMemory = 0, enableOptimizations = 0;
 
   std::string printShape(const std::vector<int64_t>&);
+  std::string printShape(const std::vector<std::vector<int64_t>>&, std::vector<std::string>&);
 };
 
 } // namespace ml
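
The new inference(I**, size_t, O*) overload takes one contiguous buffer per input node rather than a single flat vector. A minimal usage sketch, assuming a hypothetical model with two input nodes (a CNN image branch and an FC feature branch) and one output node; the names, sizes, and setup below are illustrative, not part of this commit, and the size_t argument is the batch size (see the .cxx changes below):

    o2::ml::OrtModel model;                          // configured and initialized elsewhere
    size_t batchSize = 64;
    std::vector<float> cnnBranch(batchSize * 3072);  // input node 0: 3072 elements per sample
    std::vector<float> fcBranch(batchSize * 16);     // input node 1: 16 elements per sample
    float* inputs[2] = {cnnBranch.data(), fcBranch.data()};
    std::vector<float> out(batchSize * 10);          // single output node: 10 elements per sample
    model.inference<float, float>(inputs, batchSize, out.data());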

Common/ML/src/OrtInterface.cxx

Lines changed: 151 additions & 22 deletions
@@ -124,11 +124,11 @@ void OrtModel::initEnvironment()
   (pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
   pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
 
+  setIO();
+
   if (loggingLevel < 2) {
-    LOG(info) << "(ORT) Model loaded successfully! (input: " << printShape(mInputShapes[0]) << ", output: " << printShape(mOutputShapes[0]) << ")";
+    LOG(info) << "(ORT) Model loaded successfully! (inputs: " << printShape(mInputShapes, mInputNames) << ", outputs: " << printShape(mOutputShapes, mOutputNames) << ")";
   }
-
-  setIO();
 }
 
 void OrtModel::memoryOnDevice(int32_t deviceIndex)
@@ -201,13 +201,45 @@ void OrtModel::setIO() {
   outputNamesChar.resize(mOutputNames.size(), nullptr);
   std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
                  [&](const std::string& str) { return str.c_str(); });
+
+  inputShapesCopy = mInputShapes;
+  outputShapesCopy = mOutputShapes;
+  inputSizePerNode.resize(mInputShapes.size(), 1);
+  outputSizePerNode.resize(mOutputShapes.size(), 1);
+  mInputsTotal = 1;
+  for (size_t i = 0; i < mInputShapes.size(); ++i) {
+    if (mInputShapes[i].size() > 0) {
+      for (size_t j = 1; j < mInputShapes[i].size(); ++j) { // j = 0 is the batch dimension
+        if (mInputShapes[i][j] > 0) {                       // skip dynamic (negative) dimensions
+          mInputsTotal *= mInputShapes[i][j];
+          inputSizePerNode[i] *= mInputShapes[i][j];
+        }
+      }
+    }
+  }
+  mOutputsTotal = 1;
+  for (size_t i = 0; i < mOutputShapes.size(); ++i) {
+    if (mOutputShapes[i].size() > 0) {
+      for (size_t j = 1; j < mOutputShapes[i].size(); ++j) {
+        if (mOutputShapes[i][j] > 0) {
+          mOutputsTotal *= mOutputShapes[i][j];
+          outputSizePerNode[i] *= mOutputShapes[i][j];
+        }
+      }
+    }
+  }
 }
 
 // Inference
 template <class I, class O>
 std::vector<O> OrtModel::inference(std::vector<I>& input)
 {
-  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
+  std::vector<int64_t> inputShape = mInputShapes[0];
+  inputShape[0] = input.size();
+  for (size_t i = 1; i < mInputShapes[0].size(); ++i) {
+    inputShape[0] /= mInputShapes[0][i]; // derive the batch size from the flat buffer length
+  }
   std::vector<Ort::Value> inputTensor;
   if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
     inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
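
The setIO() additions above reduce each node's shape to a per-sample element count: the j = 1 start skips the batch dimension, non-positive (dynamic) dimensions are ignored, and mInputsTotal / mOutputsTotal accumulate the product across all nodes. A worked example under assumed shapes (illustrative, not from this commit):

    input node 0:  [-1, 3, 32, 32]  ->  inputSizePerNode[0] = 3 * 32 * 32 = 3072
    input node 1:  [-1, 16]         ->  inputSizePerNode[1] = 16
                                        mInputsTotal = 3072 * 16 = 49152 (product over both nodes)
    output node 0: [-1, 10]         ->  outputSizePerNode[0] = 10, mOutputsTotal = 10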
@@ -223,9 +255,7 @@ std::vector<O> OrtModel::inference(std::vector<I>& input)
 }
 
 template std::vector<float> OrtModel::inference<float, float>(std::vector<float>&);
-
 template std::vector<float> OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&);
-
 template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);
 
 template <class I, class O>
@@ -255,33 +285,119 @@ void OrtModel::inference(I* input, size_t input_size, O* output)
 }
 
 template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t*, size_t, OrtDataType::Float16_t*);
-
 template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, size_t, float*);
-
 template void OrtModel::inference<float, OrtDataType::Float16_t>(float*, size_t, OrtDataType::Float16_t*);
-
 template void OrtModel::inference<float, float>(float*, size_t, float*);
 
 template <class I, class O>
-std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input)
-{
-  std::vector<Ort::Value> inputTensor;
-  for (auto i : input) {
-    std::vector<int64_t> inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
+void OrtModel::inference(I** input, size_t input_size, O* output)
+{
+  std::vector<Ort::Value> inputTensors(inputShapesCopy.size());
+
+  for (size_t i = 0; i < inputShapesCopy.size(); ++i) {
+    inputShapesCopy[i][0] = input_size; // batch size
     if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
-      inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
+      inputTensors[i] = Ort::Value::CreateTensor<Ort::Float16_t>(
+        pImplOrt->memoryInfo,
+        reinterpret_cast<Ort::Float16_t*>(input[i]),
+        inputSizePerNode[i] * input_size,
+        inputShapesCopy[i].data(),
+        inputShapesCopy[i].size());
     } else {
-      inputTensor.emplace_back(Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, i.data(), i.size(), inputShape.data(), inputShape.size()));
+      inputTensors[i] = Ort::Value::CreateTensor<I>(
+        pImplOrt->memoryInfo,
+        input[i],
+        inputSizePerNode[i] * input_size,
+        inputShapesCopy[i].data(),
+        inputShapesCopy[i].size());
     }
   }
-  // input.clear();
-  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
-  std::vector<O> outputValuesVec{outputValues, outputValues + inputTensor.size() / mInputShapes[0][1] * mOutputShapes[0][1]};
-  outputTensors.clear();
-  return outputValuesVec;
+
+  outputShapesCopy[0][0] = input_size; // batch size; assumes that there is only one output node
+
+  Ort::Value outputTensor = Ort::Value(nullptr);
+  if constexpr (std::is_same_v<O, OrtDataType::Float16_t>) {
+    outputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(
+      pImplOrt->memoryInfo,
+      reinterpret_cast<Ort::Float16_t*>(output),
+      outputSizePerNode[0] * input_size, // assumes that there is only one output node
+      outputShapesCopy[0].data(),
+      outputShapesCopy[0].size());
+  } else {
+    outputTensor = Ort::Value::CreateTensor<O>(
+      pImplOrt->memoryInfo,
+      output,
+      outputSizePerNode[0] * input_size, // assumes that there is only one output node
+      outputShapesCopy[0].data(),
+      outputShapesCopy[0].size());
+  }
+
+  // === Run inference ===
+  pImplOrt->session->Run(
+    pImplOrt->runOptions,
+    inputNamesChar.data(),
+    inputTensors.data(),
+    inputNamesChar.size(),
+    outputNamesChar.data(),
+    &outputTensor,
+    outputNamesChar.size());
+}
+
+template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t**, size_t, OrtDataType::Float16_t*);
+template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, size_t, float*);
+template void OrtModel::inference<float, OrtDataType::Float16_t>(float**, size_t, OrtDataType::Float16_t*);
+template void OrtModel::inference<float, float>(float**, size_t, float*);
+
+template <class I, class O>
+std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs)
+{
+  std::vector<Ort::Value> input_tensors;
+
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    inputShapesCopy[i][0] = inputs[i].size() / inputSizePerNode[i]; // batch size per node
+    if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
+      input_tensors.emplace_back(
+        Ort::Value::CreateTensor<Ort::Float16_t>(
+          pImplOrt->memoryInfo,
+          reinterpret_cast<Ort::Float16_t*>(inputs[i].data()),
+          inputSizePerNode[i] * inputShapesCopy[i][0],
+          inputShapesCopy[i].data(),
+          inputShapesCopy[i].size()));
+    } else {
+      input_tensors.emplace_back(
+        Ort::Value::CreateTensor<I>(
+          pImplOrt->memoryInfo,
+          inputs[i].data(),
+          inputSizePerNode[i] * inputShapesCopy[i][0],
+          inputShapesCopy[i].data(),
+          inputShapesCopy[i].size()));
    }
+  }
+
+  int32_t totalOutputSize = mOutputsTotal * inputShapesCopy[0][0];
+
+  // === Run inference ===
+  auto output_tensors = pImplOrt->session->Run(
+    pImplOrt->runOptions,
+    inputNamesChar.data(),
+    input_tensors.data(),
+    input_tensors.size(),
+    outputNamesChar.data(),
+    outputNamesChar.size());
+
+  // === Extract output values ===
+  O* output_data = output_tensors[0].template GetTensorMutableData<O>();
+  std::vector<O> output_vec(output_data, output_data + totalOutputSize);
+  output_tensors.clear();
+  return output_vec;
 }
 
+template std::vector<float> OrtModel::inference<float, float>(std::vector<std::vector<float>>&);
+template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>&);
+
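Two observations on the functions added above: the I** overload writes directly into the caller-provided output buffer and, as its comments note, assumes a single output node; the std::vector<std::vector<I>> overload instead derives each node's batch size as inputs[i].size() / inputSizePerNode[i], so every inner vector should be an exact multiple of that node's per-sample size. A call sketch for the vector variant, continuing the illustrative two-input model from above:

    std::vector<std::vector<float>> in = {cnnBranch, fcBranch}; // one vector per input node
    std::vector<float> result = model.inference<float, float>(in);
    // result holds mOutputsTotal * batchSize values, read from output node 0
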
 // private
 std::string OrtModel::printShape(const std::vector<int64_t>& v)
 {
@@ -293,6 +409,19 @@ std::string OrtModel::printShape(const std::vector<int64_t>& v)
   return ss.str();
 }
 
+std::string OrtModel::printShape(const std::vector<std::vector<int64_t>>& v, std::vector<std::string>& n)
+{
+  std::stringstream ss("");
+  for (size_t i = 0; i < v.size(); i++) {
+    ss << n[i] << " -> (";
+    for (size_t j = 0; j < v[i].size() - 1; j++) {
+      ss << v[i][j] << "x";
+    }
+    ss << v[i][v[i].size() - 1] << "); ";
+  }
+  return ss.str();
+}
+
 } // namespace ml
 
 } // namespace o2
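
With the new printShape overload, the load-time log line lists every node by name. For the illustrative shapes used above (node names assumed, not from this commit) it would read roughly:

    (ORT) Model loaded successfully! (inputs: input_cnn -> (-1x3x32x32); input_fc -> (-1x16); , outputs: output -> (-1x10); )

Each node contributes a trailing "; ", which is why one precedes the comma.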
