Skip to content

Commit 9ff31cc

Browse files
committed
Adding Davids patch
1 parent cefe787 commit 9ff31cc

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

Common/ML/include/ML/OrtInterface.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,10 @@ class OrtModel
4545

4646
public:
4747
// Constructors & destructors
48-
OrtModel() = default;
49-
OrtModel(std::unordered_map<std::string, std::string> optionsMap) { init(optionsMap); }
50-
void init(std::unordered_map<std::string, std::string> optionsMap)
51-
{
52-
initOptions(optionsMap);
53-
initEnvironment();
54-
}
55-
virtual ~OrtModel() = default;
48+
OrtModel();
49+
OrtModel(std::unordered_map<std::string, std::string> optionsMap);
50+
void init(std::unordered_map<std::string, std::string> optionsMap);
51+
virtual ~OrtModel();
5652

5753
// General purpose
5854
void initOptions(std::unordered_map<std::string, std::string> optionsMap);
@@ -113,7 +109,7 @@ class OrtModel
113109
private:
114110
// ORT variables -> need to be hidden as pImpl
115111
struct OrtVariables;
116-
std::shared_ptr<OrtVariables> mPImplOrt = nullptr;
112+
std::unique_ptr<OrtVariables> mPImplOrt;
117113

118114
// Input & Output specifications of the loaded network
119115
std::vector<const char*> mInputNamesChar, mOutputNamesChar;

Common/ML/src/OrtInterface.cxx

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,20 @@ namespace o2
2727
namespace ml
2828
{
2929

30+
OrtModel::OrtModel() = default;
31+
OrtModel::OrtModel(std::unordered_map<std::string, std::string> optionsMap) { init(optionsMap); }
32+
OrtModel::~OrtModel() = default;
33+
void OrtModel::init(std::unordered_map<std::string, std::string> optionsMap)
34+
{
35+
initOptions(optionsMap);
36+
initEnvironment();
37+
}
38+
3039
struct OrtModel::OrtVariables { // The actual implementation is hidden in the .cxx file
3140
// ORT runtime objects
3241
Ort::RunOptions runOptions;
33-
std::shared_ptr<Ort::Env> env = nullptr;
34-
std::shared_ptr<Ort::Session> session = nullptr; ///< ONNX session
42+
std::unique_ptr<Ort::Env> env = nullptr;
43+
std::unique_ptr<Ort::Session> session = nullptr; ///< ONNX session
3544
Ort::SessionOptions sessionOptions;
3645
Ort::AllocatorWithDefaultOptions allocator;
3746
Ort::MemoryInfo memoryInfo = Ort::MemoryInfo("Cpu", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
@@ -41,7 +50,7 @@ struct OrtModel::OrtVariables { // The actual implementation is hidden in the .c
4150
// General purpose
4251
void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsMap)
4352
{
44-
mPImplOrt = std::make_shared<OrtVariables>();
53+
mPImplOrt = std::make_unique<OrtVariables>();
4554

4655
// Load from options map
4756
if (!optionsMap.contains("model-path")) {
@@ -101,7 +110,7 @@ void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsM
101110

102111
void OrtModel::initEnvironment()
103112
{
104-
mPImplOrt->env = std::make_shared<Ort::Env>(
113+
mPImplOrt->env = std::make_unique<Ort::Env>(
105114
OrtLoggingLevel(mLoggingLevel),
106115
(mEnvName.empty() ? "ORT" : mEnvName.c_str()),
107116
// Integrate ORT logging into Fairlogger
@@ -129,7 +138,7 @@ void OrtModel::initSession()
129138
if (mAllocateDeviceMemory) {
130139
memoryOnDevice(mDeviceId);
131140
}
132-
mPImplOrt->session = std::make_shared<Ort::Session>(*mPImplOrt->env, mModelPath.c_str(), mPImplOrt->sessionOptions);
141+
mPImplOrt->session = std::make_unique<Ort::Session>(*mPImplOrt->env, mModelPath.c_str(), mPImplOrt->sessionOptions);
133142
mPImplOrt->ioBinding = std::make_unique<Ort::IoBinding>(*mPImplOrt->session);
134143

135144
setIO();
@@ -152,7 +161,7 @@ void OrtModel::memoryOnDevice(int32_t deviceIndex)
152161

153162
std::string dev_mem_str = "";
154163
if (mDeviceType == "ROCM") {
155-
dev_mem_str = "Hip";
164+
dev_mem_str = "HipPinned";
156165
}
157166
if (mDeviceType == "CUDA") {
158167
dev_mem_str = "Cuda";
@@ -166,7 +175,7 @@ void OrtModel::memoryOnDevice(int32_t deviceIndex)
166175

167176
void OrtModel::resetSession()
168177
{
169-
mPImplOrt->session = std::make_shared<Ort::Session>(*(mPImplOrt->env), mModelPath.c_str(), mPImplOrt->sessionOptions);
178+
mPImplOrt->session = std::make_unique<Ort::Session>(*(mPImplOrt->env), mModelPath.c_str(), mPImplOrt->sessionOptions);
170179
}
171180

172181
// Getters
@@ -252,7 +261,7 @@ void OrtModel::setIO()
252261

253262
void OrtModel::setEnv(Ort::Env* env)
254263
{
255-
mPImplOrt->env = std::shared_ptr<Ort::Env>(env);
264+
mPImplOrt->env.reset(env);
256265
}
257266

258267
// Inference

0 commit comments

Comments
 (0)