Skip to content

Commit 9a71e7f

Browse files
authored
Additional changes for newer versions of ONNXRuntime (#13072)
* Additional changes for newer versions of ONNXRuntime * Clang-format
1 parent 1570db6 commit 9a71e7f

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

Detectors/TRD/pid/include/TRDPID/ML.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,14 @@ class ML : public PIDBase
7070
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
7171
LOG(warn) << "Ort " << severity << ": [" << logid << "|" << category << "|" << code_location << "]: " << message << ((intptr_t)param == 3 ? " [valid]" : " [error]");
7272
},
73-
(void*)3}; ///< ONNX enviroment
74-
const OrtApi& mApi{Ort::GetApi()}; ///< ONNX api
73+
(void*)3}; ///< ONNX enviroment
74+
const OrtApi& mApi{Ort::GetApi()}; ///< ONNX api
75+
#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
7576
std::unique_ptr<Ort::Experimental::Session> mSession; ///< ONNX session
76-
Ort::SessionOptions mSessionOptions; ///< ONNX session options
77+
#else
78+
std::unique_ptr<Ort::Session> mSession; ///< ONNX session
79+
#endif
80+
Ort::SessionOptions mSessionOptions; ///< ONNX session options
7781
Ort::AllocatorWithDefaultOptions mAllocator;
7882

7983
// Input/Output

Detectors/TRD/pid/src/ML.cxx

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ void ML::init(o2::framework::ProcessingContext& pc)
6868
LOG(info) << "Set GraphOptimizationLevel to " << mParams.graphOptimizationLevel;
6969

7070
// create actual session
71+
#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
7172
mSession = std::make_unique<Ort::Experimental::Session>(mEnv, reinterpret_cast<void*>(model_data.data()), model_data.size(), mSessionOptions);
73+
#else
74+
mSession = std::make_unique<Ort::Session>(mEnv, reinterpret_cast<void*>(model_data.data()), model_data.size(), mSessionOptions);
75+
#endif
7276
LOG(info) << "ONNX runtime session created";
7377

7478
// print name/shape of inputs
@@ -104,8 +108,15 @@ float ML::process(const TrackTRD& trk, const o2::globaltracking::RecoContainer&
104108
try {
105109
auto input = prepareModelInput(trk, inputTracks);
106110
// create memory mapping to vector above
111+
#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
107112
auto inputTensor = Ort::Experimental::Value::CreateTensor<float>(input.data(), input.size(),
108113
{static_cast<int64_t>(input.size()) / mInputShapes[0][1], mInputShapes[0][1]});
114+
#else
115+
Ort::MemoryInfo mem_info =
116+
Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
117+
auto inputTensor = Ort::Value::CreateTensor<float>(mem_info, input.data(), input.size(),
118+
{static_cast<int64_t>(input.size()) / mInputShapes[0][1], mInputShapes[0][1]});
119+
#endif
109120
std::vector<Ort::Value> ortTensor;
110121
ortTensor.push_back(std::move(inputTensor));
111122
auto outTensor = mSession->Run(mInputNames, ortTensor, mOutputNames);

dependencies/FindONNXRuntime.cmake

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,7 @@ endif()
1717

1818
if (NOT ONNXRuntime::ONNXRuntime_FOUND)
1919
find_package(onnxruntime CONFIG)
20+
if (onnxruntime_FOUND)
21+
add_library(ONNXRuntime::ONNXRuntime ALIAS onnxruntime::onnxruntime)
22+
endif()
2023
endif()

0 commit comments

Comments
 (0)