// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
// All rights not expressly granted are reserved.
//
// This software is distributed under the terms of the GNU General Public
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
//
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

/// \file ort_interface.cxx
/// \author Christian Sonnabend <christian.sonnabend@cern.ch>
/// \brief A library for loading ONNX models and running inference on CPU and GPU

#include "ML/ort_interface.h"
#include "ML/3rdparty/GPUORTFloat16.h"

// ONNX includes
#include <onnxruntime_cxx_api.h>

namespace o2
{

namespace ml
{

struct OrtModel::OrtVariables { // The actual implementation is hidden in the .cxx file
  // ORT runtime objects
  Ort::RunOptions runOptions;
  std::shared_ptr<Ort::Env> env = nullptr;
  std::shared_ptr<Ort::Session> session = nullptr; ///< ONNX session
  Ort::SessionOptions sessionOptions;
  Ort::AllocatorWithDefaultOptions allocator;
  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo("Cpu", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
};

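/// (Re-)initializes the ORT environment, session options and session from a key/value option map.
/// Recognized keys: "model-path" (required), "device" (CPU / ROCM / MIGRAPHX / CUDA), "dtype",
/// "device-id", "allocate-device-memory", "intra-op-num-threads", "logging-level",
/// "enable-profiling", "profiling-output-path", "enable-optimizations", "onnx-environment-name".
///
/// Illustrative usage (a sketch only, assuming the OrtModel interface declared in ML/ort_interface.h):
/// \code
///   o2::ml::OrtModel model;
///   model.reset({{"model-path", "net.onnx"}, {"device", "CPU"}, {"intra-op-num-threads", "1"}});
///   std::vector<float> in(nRows * inputWidth);
///   std::vector<float> out = model.inference<float, float>(in);
/// \endcode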
void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
{
  pImplOrt = new OrtVariables();

  // Load from options map
  if (!optionsMap.contains("model-path")) {
    LOG(fatal) << "(ORT) Model path cannot be empty!";
  }
  modelPath = optionsMap["model-path"];
  device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
  dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float");
  deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
  allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
  intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
  loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
  enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
  enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);

  std::string dev_mem_str = "Hip";
#ifdef ORT_ROCM_BUILD
  if (device == "ROCM") {
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId));
    LOG(info) << "(ORT) ROCM execution provider set";
  }
#endif
#ifdef ORT_MIGRAPHX_BUILD
  if (device == "MIGRAPHX") {
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId));
    LOG(info) << "(ORT) MIGraphX execution provider set";
  }
#endif
#ifdef ORT_CUDA_BUILD
  if (device == "CUDA") {
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId));
    LOG(info) << "(ORT) CUDA execution provider set";
    dev_mem_str = "Cuda";
  }
#endif

  if (allocateDeviceMemory) {
    pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
    LOG(info) << "(ORT) Memory info set to on-device memory";
  }

  if (device == "CPU") {
    (pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
    if (intraOpNumThreads > 1) {
      (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
    } else if (intraOpNumThreads == 1) {
      (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
    }
    LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " threads";
  }

  (pImplOrt->sessionOptions).DisableMemPattern();
  (pImplOrt->sessionOptions).DisableCpuMemArena();

  if (enableProfiling) {
    if (optionsMap.contains("profiling-output-path")) {
      (pImplOrt->sessionOptions).EnableProfiling((optionsMap["profiling-output-path"] + "/ORT_LOG_").c_str());
    } else {
      LOG(warning) << "(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
      (pImplOrt->sessionOptions).DisableProfiling();
    }
  } else {
    (pImplOrt->sessionOptions).DisableProfiling();
  }
  (pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
  (pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));

  pImplOrt->env = std::make_shared<Ort::Env>(OrtLoggingLevel(loggingLevel), ((optionsMap.contains("onnx-environment-name") && !optionsMap["onnx-environment-name"].empty()) ? optionsMap["onnx-environment-name"].c_str() : "onnx_model_inference"));
  pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);

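  // Query the input/output node names and shapes from the loaded model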
  for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
    mInputNames.push_back((pImplOrt->session)->GetInputNameAllocated(i, pImplOrt->allocator).get());
  }
  for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
    mInputShapes.emplace_back((pImplOrt->session)->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
  }
  for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
    mOutputNames.push_back((pImplOrt->session)->GetOutputNameAllocated(i, pImplOrt->allocator).get());
  }
  for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
    mOutputShapes.emplace_back((pImplOrt->session)->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
  }

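  // Ort::Session::Run() expects raw C-string name arrays, so keep const char* views of the cached names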
  inputNamesChar.resize(mInputNames.size(), nullptr);
  std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(inputNamesChar),
                 [&](const std::string& str) { return str.c_str(); });
  outputNamesChar.resize(mOutputNames.size(), nullptr);
  std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
                 [&](const std::string& str) { return str.c_str(); });

  // Print names
  if (loggingLevel > 1) {
    LOG(info) << "Input Nodes:";
    for (size_t i = 0; i < mInputNames.size(); i++) {
      LOG(info) << "\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
    }

    LOG(info) << "Output Nodes:";
    for (size_t i = 0; i < mOutputNames.size(); i++) {
      LOG(info) << "\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
    }
  }
}

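/// Recreates the ONNX session with the environment and session options prepared in reset()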
void OrtModel::resetSession()
{
  pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
}

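/// Element-wise conversion of a vector from type I to type O; if the types are identical the input
/// is returned as a copy, otherwise the input can optionally be cleared after the conversion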
template <class I, class O>
std::vector<O> OrtModel::v2v(std::vector<I>& input, bool clearInput)
{
  if constexpr (std::is_same_v<I, O>) {
    return input;
  } else {
    std::vector<O> output(input.size());
    std::transform(std::begin(input), std::end(input), std::begin(output), [](I f) { return O(f); });
    if (clearInput) {
      input.clear();
    }
    return output;
  }
}

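/// Runs inference on a single flat input vector. The batch size is derived from the vector length
/// and the second dimension of the model's first input; the first output tensor is copied into the
/// returned vector.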
template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h
std::vector<O> OrtModel::inference(std::vector<I>& input)
{
  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  std::vector<Ort::Value> inputTensor;
  inputTensor.emplace_back(Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, reinterpret_cast<O*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
  std::vector<O> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}

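/// Runs inference on a batch of input vectors: one ORT tensor is created per inner vector and all
/// of them are passed to Run(); the first output tensor is copied into a single flat vector.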
template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input)
{
  std::vector<Ort::Value> inputTensor;
  for (auto& i : input) { // take a reference: CreateTensor wraps the caller's buffer without copying, so a loop copy would leave the tensor pointing at freed memory
    std::vector<int64_t> inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
    inputTensor.emplace_back(Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, reinterpret_cast<O*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
  }
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
  std::vector<O> outputValuesVec{outputValues, outputValues + outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount()}; // size the copy from the actual output tensor
  outputTensors.clear();
  return outputValuesVec;
}

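/// Pretty-prints a tensor shape as "d0xd1x...xdn" for logging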
std::string OrtModel::printShape(const std::vector<int64_t>& v)
{
  std::stringstream ss("");
  if (v.empty()) { // guard against scalar tensors with an empty shape
    return "";
  }
  for (size_t i = 0; i < v.size() - 1; i++) {
    ss << v[i] << "x";
  }
  ss << v[v.size() - 1];
  return ss.str();
}

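// Explicit specializations for the supported float / OrtDataType::Float16_t input and output
// combinations; the reinterpret_casts below rely on OrtDataType::Float16_t (from GPUORTFloat16.h)
// being layout-compatible with Ort::Float16_t.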
template <>
std::vector<float> OrtModel::inference<float, float>(std::vector<float>& input)
{
  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  std::vector<Ort::Value> inputTensor;
  inputTensor.emplace_back(Ort::Value::CreateTensor<float>(pImplOrt->memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size()));
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  float* outputValues = outputTensors[0].template GetTensorMutableData<float>();
  std::vector<float> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}

template <>
std::vector<float> OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>& input)
{
  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  std::vector<Ort::Value> inputTensor;
  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  float* outputValues = outputTensors[0].template GetTensorMutableData<float>();
  std::vector<float> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}

template <>
std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>& input)
{
  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  std::vector<Ort::Value> inputTensor;
  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}

template <>
std::vector<OrtDataType::Float16_t> OrtModel::inference<float, OrtDataType::Float16_t>(std::vector<float>& input)
{
  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  std::vector<Ort::Value> inputTensor;
  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}

template <>
std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>& input)
{
  std::vector<Ort::Value> inputTensor;
  for (auto& i : input) { // take a reference: CreateTensor wraps the caller's buffer without copying, so a loop copy would leave the tensor pointing at freed memory
    std::vector<int64_t> inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
    inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
  }
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount()}; // size the copy from the actual output tensor
  outputTensors.clear();
  return outputValuesVec;
}

} // namespace ml

} // namespace o2