Skip to content

Commit b742c50

Browse files
committed
Adjusting eval sizes. Makes code neater and avoids some calculations
1 parent 5be779c commit b742c50

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

Common/ML/src/OrtInterface.cxx

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -226,19 +226,18 @@ template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Fl
226226
template <class I, class O>
227227
void OrtModel::inference(I* input, size_t input_size, O* output)
228228
{
229-
std::vector<int64_t> inputShape{(int64_t)(input_size / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
229+
std::vector<int64_t> inputShape{input_size, (int64_t)mInputShapes[0][1]};
230230
Ort::Value inputTensor = Ort::Value(nullptr);
231231
if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
232-
inputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input), input_size, inputShape.data(), inputShape.size());
232+
inputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input), input_size * mInputShapes[0][1], inputShape.data(), inputShape.size());
233233
} else {
234-
inputTensor = Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, input, input_size, inputShape.data(), inputShape.size());
234+
inputTensor = Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, input, input_size * mInputShapes[0][1], inputShape.data(), inputShape.size());
235235
}
236236

237-
std::vector<int64_t> outputShape{inputShape[0], mOutputShapes[0][1]};
238-
size_t outputSize = (int64_t)(input_size * mOutputShapes[0][1] / mInputShapes[0][1]);
239-
Ort::Value outputTensor = Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, output, outputSize, outputShape.data(), outputShape.size());
237+
std::vector<int64_t> outputShape{input_size, mOutputShapes[0][1]};
238+
Ort::Value outputTensor = Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, output, input_size * mOutputShapes[0][1], outputShape.data(), outputShape.size());
240239

241-
(pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), &inputTensor, 1, outputNamesChar.data(), &outputTensor, outputNamesChar.size()); // TODO: Not sure if 1 is correct here
240+
(pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), &inputTensor, 1, outputNamesChar.data(), &outputTensor, outputNamesChar.size()); // TODO: Not sure if 1 is always correct here
242241
}
243242

244243
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, size_t, float*);

GPU/GPUTracking/TPCClusterFinder/GPUTPCNNClusterizerHost.cxx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@ GPUTPCNNClusterizerHost::GPUTPCNNClusterizerHost(const GPUSettingsProcessingNNcl
5757
}
5858
}
5959

60-
void GPUTPCNNClusterizerHost::networkInference(o2::ml::OrtModel model, GPUTPCNNClusterizer& clusterer, size_t size, float* output, int32_t dtype)
60+
void GPUTPCNNClusterizerHost::networkInference(o2::ml::OrtModel model, GPUTPCNNClusterizer& clustererNN, size_t size, float* output, int32_t dtype)
6161
{
6262
if (dtype == 0) {
63-
model.inference<OrtDataType::Float16_t, float>(clusterer.inputData16, size * clusterer.nnClusterizerElementSize, output);
63+
model.inference<OrtDataType::Float16_t, float>(clustererNN.inputData16, size, output);
6464
} else {
65-
model.inference<float, float>(clusterer.inputData32, size * clusterer.nnClusterizerElementSize, output);
65+
model.inference<float, float>(clustererNN.inputData32, size, output);
6666
}
6767
}

0 commit comments

Comments (0)