@@ -27,11 +27,20 @@ namespace o2
2727namespace ml
2828{
2929
30+ OrtModel::OrtModel () = default ;
31+ OrtModel::OrtModel (std::unordered_map<std::string, std::string> optionsMap) { init (optionsMap); }
32+ OrtModel::~OrtModel () = default ;
33+ void OrtModel::init (std::unordered_map<std::string, std::string> optionsMap)
34+ {
35+ initOptions (optionsMap);
36+ initEnvironment ();
37+ }
38+
3039struct OrtModel ::OrtVariables { // The actual implementation is hidden in the .cxx file
3140 // ORT runtime objects
3241 Ort::RunOptions runOptions;
33- std::shared_ptr <Ort::Env> env = nullptr ;
34- std::shared_ptr <Ort::Session> session = nullptr ; // /< ONNX session
42+ std::unique_ptr <Ort::Env> env = nullptr ;
43+ std::unique_ptr <Ort::Session> session = nullptr ; // /< ONNX session
3544 Ort::SessionOptions sessionOptions;
3645 Ort::AllocatorWithDefaultOptions allocator;
3746 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo(" Cpu" , OrtAllocatorType::OrtDeviceAllocator, 0 , OrtMemType::OrtMemTypeDefault);
@@ -41,7 +50,7 @@ struct OrtModel::OrtVariables { // The actual implementation is hidden in the .c
4150// General purpose
4251void OrtModel::initOptions (std::unordered_map<std::string, std::string> optionsMap)
4352{
44- mPImplOrt = std::make_shared <OrtVariables>();
53+ mPImplOrt = std::make_unique <OrtVariables>();
4554
4655 // Load from options map
4756 if (!optionsMap.contains (" model-path" )) {
@@ -101,7 +110,7 @@ void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsM
101110
102111void OrtModel::initEnvironment ()
103112{
104- mPImplOrt ->env = std::make_shared <Ort::Env>(
113+ mPImplOrt ->env = std::make_unique <Ort::Env>(
105114 OrtLoggingLevel (mLoggingLevel ),
106115 (mEnvName .empty () ? " ORT" : mEnvName .c_str ()),
107116 // Integrate ORT logging into Fairlogger
@@ -129,7 +138,7 @@ void OrtModel::initSession()
129138 if (mAllocateDeviceMemory ) {
130139 memoryOnDevice (mDeviceId );
131140 }
132- mPImplOrt ->session = std::make_shared <Ort::Session>(*mPImplOrt ->env , mModelPath .c_str (), mPImplOrt ->sessionOptions );
141+ mPImplOrt ->session = std::make_unique <Ort::Session>(*mPImplOrt ->env , mModelPath .c_str (), mPImplOrt ->sessionOptions );
133142 mPImplOrt ->ioBinding = std::make_unique<Ort::IoBinding>(*mPImplOrt ->session );
134143
135144 setIO ();
@@ -152,7 +161,7 @@ void OrtModel::memoryOnDevice(int32_t deviceIndex)
152161
153162 std::string dev_mem_str = " " ;
154163 if (mDeviceType == " ROCM" ) {
155- dev_mem_str = " Hip " ;
164+ dev_mem_str = " HipPinned " ;
156165 }
157166 if (mDeviceType == " CUDA" ) {
158167 dev_mem_str = " Cuda" ;
@@ -166,7 +175,7 @@ void OrtModel::memoryOnDevice(int32_t deviceIndex)
166175
167176void OrtModel::resetSession ()
168177{
169- mPImplOrt ->session = std::make_shared <Ort::Session>(*(mPImplOrt ->env ), mModelPath .c_str (), mPImplOrt ->sessionOptions );
178+ mPImplOrt ->session = std::make_unique <Ort::Session>(*(mPImplOrt ->env ), mModelPath .c_str (), mPImplOrt ->sessionOptions );
170179}
171180
172181// Getters
@@ -252,7 +261,7 @@ void OrtModel::setIO()
252261
253262void OrtModel::setEnv (Ort::Env* env)
254263{
255- mPImplOrt ->env = std::shared_ptr<Ort::Env> (env);
264+ mPImplOrt ->env . reset (env);
256265}
257266
258267// Inference
0 commit comments