@@ -35,6 +35,16 @@ struct OrtModel::OrtVariables { // The actual implementation is hidden in the .c
3535 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo(" Cpu" , OrtAllocatorType::OrtDeviceAllocator, 0 , OrtMemType::OrtMemTypeDefault);
3636};
3737
38+ Ort::SessionOptions* OrtModel::updateSessionOptions ()
39+ {
40+ return &(pImplOrt->sessionOptions );
41+ }
42+
43+ Ort::MemoryInfo* OrtModel::updateMemoryInfo ()
44+ {
45+ return &(pImplOrt->memoryInfo );
46+ }
47+
3848void OrtModel::reset (std::unordered_map<std::string, std::string> optionsMap)
3949{
4050
@@ -56,39 +66,41 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
5666 enableProfiling = (optionsMap.contains (" enable-profiling" ) ? std::stoi (optionsMap[" enable-profiling" ]) : 0 );
5767 enableOptimizations = (optionsMap.contains (" enable-optimizations" ) ? std::stoi (optionsMap[" enable-optimizations" ]) : 0 );
5868
59- std::string dev_mem_str = " Hip" ;
60- #if defined(ORT_ROCM_BUILD)
61- #if ORT_ROCM_BUILD == 1
62- if (device == " ROCM" ) {
63- // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, streamId));
64- o2::gpu::SetONNXGPUStream (pImplOrt->sessionOptions , streamId);
65- LOG (info) << " (ORT) ROCM execution provider set" ;
66- }
67- #endif
68- #endif
69- #if defined(ORT_MIGRAPHX_BUILD)
70- #if ORT_MIGRAPHX_BUILD == 1
71- if (device == " MIGRAPHX" ) {
72- Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_MIGraphX (pImplOrt->sessionOptions , streamId));
73- LOG (info) << " (ORT) MIGraphX execution provider set" ;
74- }
75- #endif
76- #endif
77- #if defined(ORT_CUDA_BUILD)
78- #if ORT_CUDA_BUILD == 1
79- if (device == " CUDA" ) {
80- // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, streamId));
81- o2::gpu::SetONNXGPUStream (pImplOrt->sessionOptions , streamId);
82- LOG (info) << " (ORT) CUDA execution provider set" ;
83- dev_mem_str = " Cuda" ;
84- }
85- #endif
86- #endif
87-
69+ // #if defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1
70+ // if (device == "ROCM") {
71+ // // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, streamId));
72+ // SetONNXGPUStream(pImplOrt->sessionOptions, streamId);
73+ // LOG(info) << "(ORT) ROCM execution provider set";
74+ // }
75+ // #endif
76+ // #if defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1
77+ // if (device == "MIGRAPHX") {
78+ // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, streamId));
79+ // LOG(info) << "(ORT) MIGraphX execution provider set";
80+ // }
81+ // #endif
82+ // #if defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1
83+ // if (device == "CUDA") {
84+ // // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, streamId));
85+ // SetONNXGPUStream(pImplOrt->sessionOptions, streamId);
86+ // LOG(info) << "(ORT) CUDA execution provider set";
87+ // dev_mem_str = "Cuda";
88+ // }
89+ // #endif
90+
91+ #if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
8892 if (allocateDeviceMemory) {
93+ std::string dev_mem_str = " " ;
94+ if (device == " ROCM" ) {
95+ dev_mem_str = " Hip" ;
96+ }
97+ if (device == " CUDA" ) {
98+ dev_mem_str = " Cuda" ;
99+ }
89100 pImplOrt->memoryInfo = Ort::MemoryInfo (dev_mem_str.c_str (), OrtAllocatorType::OrtDeviceAllocator, streamId, OrtMemType::OrtMemTypeDefault);
90101 LOG (info) << " (ORT) Memory info set to on-device memory" ;
91102 }
103+ #endif
92104
93105 if (device == " CPU" ) {
94106 (pImplOrt->sessionOptions ).SetIntraOpNumThreads (intraOpNumThreads);
0 commit comments