Skip to content

Commit afbdade

Browse files
committed
Convert the pImpl pointer to a std::shared_ptr so it can be released explicitly via reset()
1 parent cd6ceb5 commit afbdade

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

Common/ML/include/ML/OrtInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class OrtModel
113113
private:
114114
// ORT variables -> need to be hidden as pImpl
115115
struct OrtVariables;
116-
OrtVariables* mPImplOrt;
116+
std::shared_ptr<OrtVariables> mPImplOrt = nullptr;
117117

118118
// Input & Output specifications of the loaded network
119119
std::vector<const char*> mInputNamesChar, mOutputNamesChar;

Common/ML/src/OrtInterface.cxx

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ struct OrtModel::OrtVariables { // The actual implementation is hidden in the .c
4141
// General purpose
4242
void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsMap)
4343
{
44-
mPImplOrt = new OrtVariables();
44+
mPImplOrt = std::make_shared<OrtVariables>();
4545

4646
// Load from options map
4747
if (!optionsMap.contains("model-path")) {
@@ -147,8 +147,8 @@ void OrtModel::memoryOnDevice(int32_t deviceIndex)
147147
(mPImplOrt->sessionOptions).AddConfigEntry("session.use_env_allocators", "1"); // This should enable to use the volatile memory allocation defined in O2/GPU/GPUTracking/TPCClusterFinder/GPUTPCNNClusterizerHost.cxx; not working yet: ONNX still assigns new memory at init time
148148
(mPImplOrt->sessionOptions).AddConfigEntry("session_options.enable_cpu_mem_arena", "0"); // This should enable to use the volatile memory allocation defined in O2/GPU/GPUTracking/TPCClusterFinder/GPUTPCNNClusterizerHost.cxx; not working yet: ONNX still assigns new memory at init time
149149
// Arena memory shrinkage comes at performance cost
150-
/// For now prefer to use single allocation, enabled by O2/GPU/GPUTracking/Base/cuda/GPUReconstructionCUDA.cu -> SetONNXGPUStream -> rocm_options.arena_extend_strategy = 0;
151-
// (mPImplOrt->runOptions).AddConfigEntry("memory.enable_memory_arena_shrinkage", ("gpu:" + std::to_string(deviceIndex)).c_str()); // See kOrtRunOptionsConfigEnableMemoryArenaShrinkage, https://github.com/microsoft/onnxruntime/blob/90c263f471bbce724e77d8e62831d3a9fa838b2f/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h#L27
150+
// For now prefer to use single allocation, enabled by O2/GPU/GPUTracking/Base/cuda/GPUReconstructionCUDA.cu -> SetONNXGPUStream -> rocm_options.arena_extend_strategy = 0;
151+
(mPImplOrt->runOptions).AddConfigEntry("memory.enable_memory_arena_shrinkage", ("gpu:" + std::to_string(deviceIndex)).c_str()); // See kOrtRunOptionsConfigEnableMemoryArenaShrinkage, https://github.com/microsoft/onnxruntime/blob/90c263f471bbce724e77d8e62831d3a9fa838b2f/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h#L27
152152

153153
std::string dev_mem_str = "";
154154
if (mDeviceType == "ROCM") {
@@ -308,6 +308,14 @@ void OrtModel::inference(I* input, int64_t input_size, O* output)
308308
(mPImplOrt->ioBinding)->BindOutput(mOutputNames[0].c_str(), outputTensor);
309309

310310
(mPImplOrt->session)->Run(mPImplOrt->runOptions, *mPImplOrt->ioBinding);
311+
// mPImplOrt->session->Run(
312+
// mPImplOrt->runOptions,
313+
// mInputNamesChar.data(),
314+
// &inputTensor,
315+
// mInputNamesChar.size(),
316+
// mOutputNamesChar.data(),
317+
// &outputTensor,
318+
// mOutputNamesChar.size());
311319
}
312320

313321
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t*, int64_t, OrtDataType::Float16_t*);
@@ -427,10 +435,7 @@ template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Fl
427435
// Release session
428436
void OrtModel::release(bool profilingEnabled)
429437
{
430-
// if (profilingEnabled) {
431-
// mPImplOrt->session->EndProfiling();
432-
// }
433-
LOG(info) << "(ORT) Size of mPImplOrt: " << sizeof(*mPImplOrt) << " bytes";
438+
mPImplOrt.reset();
434439
}
435440

436441
// private

0 commit comments

Comments
 (0)