Skip to content

Commit 4faaa4a

Browse files
committed
Improve readability and address review comments
1 parent 4ef35fc commit 4faaa4a

File tree

9 files changed

+197
-233
lines changed

9 files changed

+197
-233
lines changed

Common/ML/include/ML/OrtInterface.h

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,54 +41,51 @@ class OrtModel
4141
{
4242

4343
public:
44-
// Constructor
44+
// Constructors & destructors
4545
OrtModel() = default;
46-
OrtModel(std::unordered_map<std::string, std::string> optionsMap) {
47-
initOptions(optionsMap);
48-
initEnvironment();
49-
}
46+
OrtModel(std::unordered_map<std::string, std::string> optionsMap) { init(optionsMap); }
5047
void init(std::unordered_map<std::string, std::string> optionsMap) {
5148
initOptions(optionsMap);
5249
initEnvironment();
5350
}
51+
virtual ~OrtModel() = default;
52+
53+
// General purpose
5454
void initOptions(std::unordered_map<std::string, std::string> optionsMap);
5555
void initEnvironment();
56+
void memoryOnDevice(int32_t = 0);
5657
bool isInitialized() { return mInitialized; }
57-
Ort::SessionOptions& updateSessionOptions();
58-
Ort::MemoryInfo& updateMemoryInfo();
59-
void setIO();
58+
void resetSession();
6059

61-
virtual ~OrtModel() = default;
60+
// Getters
61+
std::vector<std::vector<int64_t>> getNumInputNodes() const { return mInputShapes; }
62+
std::vector<std::vector<int64_t>> getNumOutputNodes() const { return mOutputShapes; }
63+
std::vector<std::string> getInputNames() const { return mInputNames; }
64+
std::vector<std::string> getOutputNames() const { return mOutputNames; }
65+
Ort::SessionOptions& getSessionOptions();
66+
Ort::MemoryInfo& getMemoryInfo();
67+
68+
// Setters
69+
void setDeviceId(int32_t id) { deviceId = id; }
70+
void setIO();
71+
void setActiveThreads(int threads) { intraOpNumThreads = threads; }
6272

6373
// Conversion
6474
template <class I, class O>
6575
std::vector<O> v2v(std::vector<I>&, bool = true);
6676

6777
// Inferencing
6878
template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h
69-
std::vector<O> inference(std::vector<I>&, int32_t = -1);
70-
71-
template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h
72-
std::vector<O> inference(std::vector<std::vector<I>>&, int32_t = -1);
73-
74-
template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h
75-
void inference(I*, size_t, O*, int32_t = -1);
76-
77-
// template<class I, class T, class O> // class I is the input data type, e.g. float, class T the throughput data type and class O is the output data type
78-
// std::vector<O> inference(std::vector<I>&);
79-
80-
// Reset session
81-
void resetSession();
79+
std::vector<O> inference(std::vector<I>&);
8280

83-
std::vector<std::vector<int64_t>> getNumInputNodes() const { return mInputShapes; }
84-
std::vector<std::vector<int64_t>> getNumOutputNodes() const { return mOutputShapes; }
85-
std::vector<std::string> getInputNames() const { return mInputNames; }
86-
std::vector<std::string> getOutputNames() const { return mOutputNames; }
81+
template <class I, class O>
82+
std::vector<O> inference(std::vector<std::vector<I>>&);
8783

88-
void setActiveThreads(int threads) { intraOpNumThreads = threads; }
84+
template <class I, class O>
85+
void inference(I*, size_t, O*);
8986

9087
private:
91-
// ORT variables -> need to be hidden as Pimpl
88+
// ORT variables -> need to be hidden as pImpl
9289
struct OrtVariables;
9390
OrtVariables* pImplOrt;
9491

@@ -99,8 +96,8 @@ class OrtModel
9996

10097
// Environment settings
10198
bool mInitialized = false;
102-
std::string modelPath, envName = "", device = "cpu", thread_affinity = ""; // device options should be cpu, rocm, migraphx, cuda
103-
int intraOpNumThreads = 1, interOpNumThreads = 1, deviceId = 0, enableProfiling = 0, loggingLevel = 0, allocateDeviceMemory = 0, enableOptimizations = 0;
99+
std::string modelPath, envName = "", deviceType = "CPU", thread_affinity = ""; // device options should be cpu, rocm, migraphx, cuda
100+
int32_t intraOpNumThreads = 1, interOpNumThreads = 1, deviceId = -1, enableProfiling = 0, loggingLevel = 0, allocateDeviceMemory = 0, enableOptimizations = 0;
104101

105102
std::string printShape(const std::vector<int64_t>&);
106103
};

Common/ML/src/OrtInterface.cxx

Lines changed: 79 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,7 @@ struct OrtModel::OrtVariables { // The actual implementation is hidden in the .c
3535
Ort::MemoryInfo memoryInfo = Ort::MemoryInfo("Cpu", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
3636
};
3737

38-
Ort::SessionOptions& OrtModel::updateSessionOptions()
39-
{
40-
return pImplOrt->sessionOptions;
41-
}
42-
43-
Ort::MemoryInfo& OrtModel::updateMemoryInfo()
44-
{
45-
return pImplOrt->memoryInfo;
46-
}
47-
38+
// General purpose
4839
void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsMap)
4940
{
5041
pImplOrt = new OrtVariables();
@@ -56,7 +47,8 @@ void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsM
5647

5748
if (!optionsMap["model-path"].empty()) {
5849
modelPath = optionsMap["model-path"];
59-
device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
50+
deviceType = (optionsMap.contains("device-type") ? optionsMap["device-type"] : "CPU");
51+
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : -1);
6052
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
6153
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
6254
interOpNumThreads = (optionsMap.contains("inter-op-num-threads") ? std::stoi(optionsMap["inter-op-num-threads"]) : 0);
@@ -65,7 +57,7 @@ void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsM
6557
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
6658
envName = (optionsMap.contains("onnx-environment-name") ? optionsMap["onnx-environment-name"] : "onnx_model_inference");
6759

68-
if (device == "CPU") {
60+
if (deviceType == "CPU") {
6961
(pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
7062
(pImplOrt->sessionOptions).SetInterOpNumThreads(interOpNumThreads);
7163
if (intraOpNumThreads > 1 || interOpNumThreads > 1) {
@@ -97,14 +89,18 @@ void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsM
9789

9890
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
9991
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
92+
93+
mInitialized = true;
10094
} else {
10195
LOG(fatal) << "(ORT) Model path cannot be empty!";
10296
}
10397
}
10498

10599
void OrtModel::initEnvironment()
106100
{
107-
mInitialized = true;
101+
if(allocateDeviceMemory) {
102+
memoryOnDevice(deviceId);
103+
}
108104
pImplOrt->env = std::make_shared<Ort::Env>(
109105
OrtLoggingLevel(loggingLevel),
110106
(envName.empty() ? "ORT" : envName.c_str()),
@@ -128,39 +124,48 @@ void OrtModel::initEnvironment()
128124
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
129125
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
130126

127+
if (loggingLevel < 2) {
128+
LOG(info) << "(ORT) Model loaded successfully! (input: " << printShape(mInputShapes[0]) << ", output: " << printShape(mOutputShapes[0]) << ")";
129+
}
130+
131131
setIO();
132132
}
133133

134-
void OrtModel::setIO() {
135-
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
136-
mInputNames.push_back((pImplOrt->session)->GetInputNameAllocated(i, pImplOrt->allocator).get());
137-
}
138-
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
139-
mInputShapes.emplace_back((pImplOrt->session)->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
140-
}
141-
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
142-
mOutputNames.push_back((pImplOrt->session)->GetOutputNameAllocated(i, pImplOrt->allocator).get());
143-
}
144-
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
145-
mOutputShapes.emplace_back((pImplOrt->session)->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
146-
}
147-
148-
inputNamesChar.resize(mInputNames.size(), nullptr);
149-
std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(inputNamesChar),
150-
[&](const std::string& str) { return str.c_str(); });
151-
outputNamesChar.resize(mOutputNames.size(), nullptr);
152-
std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
153-
[&](const std::string& str) { return str.c_str(); });
154-
if (loggingLevel < 2) {
155-
LOG(info) << "(ORT) Model loaded successfully! (input: " << printShape(mInputShapes[0]) << ", output: " << printShape(mOutputShapes[0]) << ")";
134+
void OrtModel::memoryOnDevice(int32_t deviceIndex)
135+
{
136+
#if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
137+
if (deviceIndex >= 0) {
138+
std::string dev_mem_str = "";
139+
if (deviceType == "ROCM") {
140+
dev_mem_str = "Hip";
141+
}
142+
if (deviceType == "CUDA") {
143+
dev_mem_str = "Cuda";
144+
}
145+
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
146+
if (loggingLevel < 2) {
147+
LOG(info) << "(ORT) Memory info set to on-device memory for device type " << deviceType << " with ID " << deviceIndex;
148+
}
156149
}
150+
#endif
157151
}
158152

159153
void OrtModel::resetSession()
160154
{
161155
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
162156
}
163157

158+
// Getters
159+
Ort::SessionOptions& OrtModel::getSessionOptions()
160+
{
161+
return pImplOrt->sessionOptions;
162+
}
163+
164+
Ort::MemoryInfo& OrtModel::getMemoryInfo()
165+
{
166+
return pImplOrt->memoryInfo;
167+
}
168+
164169
template <class I, class O>
165170
std::vector<O> OrtModel::v2v(std::vector<I>& input, bool clearInput)
166171
{
@@ -176,32 +181,32 @@ std::vector<O> OrtModel::v2v(std::vector<I>& input, bool clearInput)
176181
}
177182
}
178183

179-
std::string OrtModel::printShape(const std::vector<int64_t>& v)
180-
{
181-
std::stringstream ss("");
182-
for (size_t i = 0; i < v.size() - 1; i++) {
183-
ss << v[i] << "x";
184+
void OrtModel::setIO() {
185+
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
186+
mInputNames.push_back((pImplOrt->session)->GetInputNameAllocated(i, pImplOrt->allocator).get());
184187
}
185-
ss << v[v.size() - 1];
186-
return ss.str();
188+
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
189+
mInputShapes.emplace_back((pImplOrt->session)->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
190+
}
191+
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
192+
mOutputNames.push_back((pImplOrt->session)->GetOutputNameAllocated(i, pImplOrt->allocator).get());
193+
}
194+
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
195+
mOutputShapes.emplace_back((pImplOrt->session)->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
196+
}
197+
198+
inputNamesChar.resize(mInputNames.size(), nullptr);
199+
std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(inputNamesChar),
200+
[&](const std::string& str) { return str.c_str(); });
201+
outputNamesChar.resize(mOutputNames.size(), nullptr);
202+
std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
203+
[&](const std::string& str) { return str.c_str(); });
187204
}
188205

206+
// Inference
189207
template <class I, class O>
190-
std::vector<O> OrtModel::inference(std::vector<I>& input, int32_t deviceIndex)
208+
std::vector<O> OrtModel::inference(std::vector<I>& input)
191209
{
192-
#if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
193-
if (allocateDeviceMemory) {
194-
std::string dev_mem_str = "";
195-
if (device == "ROCM") {
196-
dev_mem_str = "Hip";
197-
}
198-
if (device == "CUDA") {
199-
dev_mem_str = "Cuda";
200-
}
201-
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
202-
LOG(info) << "(ORT) Memory info set to on-device memory for device " << device << " with ID "<< deviceIndex;
203-
}
204-
#endif
205210
std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
206211
std::vector<Ort::Value> inputTensor;
207212
if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
@@ -217,32 +222,19 @@ std::vector<O> OrtModel::inference(std::vector<I>& input, int32_t deviceIndex)
217222
return outputValuesVec;
218223
}
219224

220-
template std::vector<float> OrtModel::inference<float, float>(std::vector<float>&, int32_t);
225+
template std::vector<float> OrtModel::inference<float, float>(std::vector<float>&);
221226

222-
template std::vector<float> OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&, int32_t);
227+
template std::vector<float> OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&);
223228

224-
template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&, int32_t);
229+
template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);
225230

226231
template <class I, class O>
227-
void OrtModel::inference(I* input, size_t input_size, O* output, int32_t deviceIndex)
232+
void OrtModel::inference(I* input, size_t input_size, O* output)
228233
{
229234
// std::vector<std::string> providers = Ort::GetAvailableProviders();
230235
// for (const auto& provider : providers) {
231236
// LOG(info) << "Available Execution Provider: " << provider;
232237
// }
233-
#if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
234-
if (allocateDeviceMemory) {
235-
std::string dev_mem_str = "";
236-
if (device == "ROCM") {
237-
dev_mem_str = "Hip";
238-
}
239-
if (device == "CUDA") {
240-
dev_mem_str = "Cuda";
241-
}
242-
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
243-
LOG(info) << "(ORT) Memory info set to on-device memory for device " << device << " with ID "<< deviceIndex;
244-
}
245-
#endif
246238
std::vector<int64_t> inputShape{input_size, (int64_t)mInputShapes[0][1]};
247239
Ort::Value inputTensor = Ort::Value(nullptr);
248240
if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
@@ -257,26 +249,13 @@ void OrtModel::inference(I* input, size_t input_size, O* output, int32_t deviceI
257249
(pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), &inputTensor, 1, outputNamesChar.data(), &outputTensor, outputNamesChar.size());
258250
}
259251

260-
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, size_t, float*, int32_t);
252+
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, size_t, float*);
261253

262-
template void OrtModel::inference<float, float>(float*, size_t, float*, int32_t);
254+
template void OrtModel::inference<float, float>(float*, size_t, float*);
263255

264256
template <class I, class O>
265-
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input, int32_t deviceIndex)
257+
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input)
266258
{
267-
#if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
268-
if (allocateDeviceMemory) {
269-
std::string dev_mem_str = "";
270-
if (device == "ROCM") {
271-
dev_mem_str = "Hip";
272-
}
273-
if (device == "CUDA") {
274-
dev_mem_str = "Cuda";
275-
}
276-
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
277-
LOG(info) << "(ORT) Memory info set to on-device memory for device " << device << " with ID " << deviceIndex;
278-
}
279-
#endif
280259
std::vector<Ort::Value> inputTensor;
281260
for (auto i : input) {
282261
std::vector<int64_t> inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
@@ -294,6 +273,17 @@ std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input, int32_t d
294273
return outputValuesVec;
295274
}
296275

276+
// private
277+
std::string OrtModel::printShape(const std::vector<int64_t>& v)
278+
{
279+
std::stringstream ss("");
280+
for (size_t i = 0; i < v.size() - 1; i++) {
281+
ss << v[i] << "x";
282+
}
283+
ss << v[v.size() - 1];
284+
return ss.str();
285+
}
286+
297287
} // namespace ml
298288

299289
} // namespace o2

GPU/GPUTracking/Base/GPUReconstructionProcessing.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class GPUReconstructionProcessing : public GPUReconstruction
9292
void AddGPUEvents(T*& events);
9393

9494
virtual std::unique_ptr<gpu_reconstruction_kernels::threadContext> GetThreadContext() override;
95-
virtual void SetONNXGPUStream(Ort::SessionOptions&, int32_t, int32_t*) {}
95+
// virtual void SetONNXGPUStream(Ort::SessionOptions&, int32_t, int32_t*) {}
9696

9797
struct RecoStepTimerMeta {
9898
HighResTimer timerToGPU;

GPU/GPUTracking/Base/cuda/GPUReconstructionCUDA.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ void GPUReconstructionCUDA::SetONNXGPUStream(Ort::SessionOptions& session_option
673673
// UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), keys.size());
674674

675675
// this implicitly sets "has_user_compute_stream"
676-
UpdateCUDAProviderOptionsWithValue(cuda_options, "user_compute_stream", &mInternals->Streams[stream]);
676+
UpdateCUDAProviderOptionsWithValue(cuda_options, "user_compute_stream", mInternals->Streams[stream]);
677677
session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
678678

679679
// Finally, don't forget to release the provider options
@@ -694,7 +694,7 @@ void GPUReconstructionHIP::SetONNXGPUStream(Ort::SessionOptions& session_options
694694
{
695695
// Create ROCm provider options
696696
cudaGetDevice(deviceId);
697-
const auto& api = Ort::GetApi();
697+
// const auto& api = Ort::GetApi();
698698
// api.GetCurrentGpuDeviceId(deviceId);
699699
OrtROCMProviderOptions rocm_options;
700700
rocm_options.has_user_compute_stream = 1; // Indicate that we are passing a user stream

0 commit comments

Comments (0)