Skip to content

Commit 007a4a1

Browse files
committed
This runs, but will eventually fill up the VRAM. Need to include a mem clean
1 parent 9d9267f commit 007a4a1

14 files changed

+242
-223
lines changed

Common/ML/CMakeLists.txt

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,10 @@
1010
# or submit itself to any jurisdiction.
1111

1212
# Pass ORT variables as a preprocessor definition
13-
if(DEFINED ENV{ORT_ROCM_BUILD})
14-
add_compile_definitions(ORT_ROCM_BUILD=$ENV{ORT_ROCM_BUILD})
15-
endif()
16-
if(DEFINED ENV{ORT_CUDA_BUILD})
17-
add_compile_definitions(ORT_CUDA_BUILD=$ENV{ORT_CUDA_BUILD})
18-
endif()
19-
if(DEFINED ENV{ORT_MIGRAPHX_BUILD})
20-
add_compile_definitions(ORT_MIGRAPHX_BUILD=$ENV{ORT_MIGRAPHX_BUILD})
21-
endif()
22-
if(DEFINED ENV{ORT_TENSORRT_BUILD})
23-
add_compile_definitions(ORT_TENSORRT_BUILD=$ENV{ORT_TENSORRT_BUILD})
24-
endif()
13+
add_compile_definitions(ORT_ROCM_BUILD=${ORT_ROCM_BUILD})
14+
add_compile_definitions(ORT_CUDA_BUILD=${ORT_CUDA_BUILD})
15+
add_compile_definitions(ORT_MIGRAPHX_BUILD=${ORT_MIGRAPHX_BUILD})
16+
add_compile_definitions(ORT_TENSORRT_BUILD=${ORT_TENSORRT_BUILD})
2517

2618
o2_add_library(ML
2719
SOURCES src/OrtInterface.cxx

Common/ML/include/ML/OrtInterface.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,19 @@ class OrtModel
4343
public:
4444
// Constructor
4545
OrtModel() = default;
46-
OrtModel(std::unordered_map<std::string, std::string> optionsMap) { reset(optionsMap); }
47-
void init(std::unordered_map<std::string, std::string> optionsMap) { reset(optionsMap); }
48-
void reset(std::unordered_map<std::string, std::string>);
46+
OrtModel(std::unordered_map<std::string, std::string> optionsMap) {
47+
initOptions(optionsMap);
48+
initEnvironment();
49+
}
50+
void init(std::unordered_map<std::string, std::string> optionsMap) {
51+
initOptions(optionsMap);
52+
initEnvironment();
53+
}
54+
void initOptions(std::unordered_map<std::string, std::string> optionsMap);
55+
void initEnvironment();
4956
bool isInitialized() { return mInitialized; }
50-
Ort::SessionOptions* updateSessionOptions();
51-
Ort::MemoryInfo* updateMemoryInfo();
57+
Ort::SessionOptions& updateSessionOptions();
58+
void setIO();
5259

5360
virtual ~OrtModel() = default;
5461

@@ -91,7 +98,7 @@ class OrtModel
9198

9299
// Environment settings
93100
bool mInitialized = false;
94-
std::string modelPath, device = "cpu", thread_affinity = ""; // device options should be cpu, rocm, migraphx, cuda
101+
std::string modelPath, envName = "", device = "cpu", thread_affinity = ""; // device options should be cpu, rocm, migraphx, cuda
95102
int intraOpNumThreads = 1, interOpNumThreads = 1, deviceId = 0, enableProfiling = 0, loggingLevel = 0, allocateDeviceMemory = 0, enableOptimizations = 0;
96103

97104
std::string printShape(const std::vector<int64_t>&);

Common/ML/src/OrtInterface.cxx

Lines changed: 51 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,13 @@ struct OrtModel::OrtVariables { // The actual implementation is hidden in the .c
3535
Ort::MemoryInfo memoryInfo = Ort::MemoryInfo("Cpu", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
3636
};
3737

38-
Ort::SessionOptions* OrtModel::updateSessionOptions()
38+
Ort::SessionOptions& OrtModel::updateSessionOptions()
3939
{
40-
return &(pImplOrt->sessionOptions);
40+
return pImplOrt->sessionOptions;
4141
}
4242

43-
Ort::MemoryInfo* OrtModel::updateMemoryInfo()
43+
void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsMap)
4444
{
45-
return &(pImplOrt->memoryInfo);
46-
}
47-
48-
void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
49-
{
50-
5145
pImplOrt = new OrtVariables();
5246

5347
// Load from options map
@@ -58,71 +52,57 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
5852
if (!optionsMap["model-path"].empty()) {
5953
modelPath = optionsMap["model-path"];
6054
device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
61-
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
6255
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
6356
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
6457
interOpNumThreads = (optionsMap.contains("inter-op-num-threads") ? std::stoi(optionsMap["inter-op-num-threads"]) : 0);
6558
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
6659
enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
6760
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
68-
69-
// #if defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1
70-
// if (device == "ROCM") {
71-
// // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId));
72-
// SetONNXGPUStream(pImplOrt->sessionOptions, deviceId);
73-
// LOG(info) << "(ORT) ROCM execution provider set";
74-
// }
75-
// #endif
76-
// #if defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1
77-
// if (device == "MIGRAPHX") {
78-
// Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId));
79-
// LOG(info) << "(ORT) MIGraphX execution provider set";
80-
// }
81-
// #endif
82-
// #if defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1
83-
// if (device == "CUDA") {
84-
// // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId));
85-
// SetONNXGPUStream(pImplOrt->sessionOptions, deviceId);
86-
// LOG(info) << "(ORT) CUDA execution provider set";
87-
// dev_mem_str = "Cuda";
88-
// }
89-
// #endif
90-
91-
if (device == "CPU") {
92-
(pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
93-
(pImplOrt->sessionOptions).SetInterOpNumThreads(interOpNumThreads);
94-
if (intraOpNumThreads > 1 || interOpNumThreads > 1) {
95-
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
96-
} else if (intraOpNumThreads == 1) {
97-
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
98-
}
99-
if (loggingLevel < 2) {
100-
LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " (intraOpNumThreads) and " << interOpNumThreads << " (interOpNumThreads) threads";
61+
envName = (optionsMap.contains("onnx-environment-name") ? optionsMap["onnx-environment-name"] : "onnx_model_inference");
62+
63+
if (device == "CPU") {
64+
(pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
65+
(pImplOrt->sessionOptions).SetInterOpNumThreads(interOpNumThreads);
66+
if (intraOpNumThreads > 1 || interOpNumThreads > 1) {
67+
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
68+
} else if (intraOpNumThreads == 1) {
69+
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
70+
}
71+
if (loggingLevel < 2) {
72+
LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " (intraOpNumThreads) and " << interOpNumThreads << " (interOpNumThreads) threads";
73+
}
10174
}
102-
}
10375

104-
(pImplOrt->sessionOptions).DisableMemPattern();
105-
(pImplOrt->sessionOptions).DisableCpuMemArena();
76+
// OrtROCMProviderOptions rocm_options{};
77+
// (pImplOrt->sessionOptions).AppendExecutionProvider_ROCM(rocm_options);
10678

107-
if (enableProfiling) {
108-
if (optionsMap.contains("profiling-output-path")) {
109-
(pImplOrt->sessionOptions).EnableProfiling((optionsMap["profiling-output-path"] + "/ORT_LOG_").c_str());
79+
(pImplOrt->sessionOptions).DisableMemPattern();
80+
(pImplOrt->sessionOptions).DisableCpuMemArena();
81+
82+
if (enableProfiling) {
83+
if (optionsMap.contains("profiling-output-path")) {
84+
(pImplOrt->sessionOptions).EnableProfiling((optionsMap["profiling-output-path"] + "/ORT_LOG_").c_str());
85+
} else {
86+
LOG(warning) << "(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
87+
(pImplOrt->sessionOptions).DisableProfiling();
88+
}
11089
} else {
111-
LOG(warning) << "(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
11290
(pImplOrt->sessionOptions).DisableProfiling();
11391
}
92+
93+
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
94+
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
11495
} else {
115-
(pImplOrt->sessionOptions).DisableProfiling();
96+
LOG(fatal) << "(ORT) Model path cannot be empty!";
11697
}
98+
}
11799

100+
void OrtModel::initEnvironment()
101+
{
118102
mInitialized = true;
119-
120-
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
121-
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
122-
123103
pImplOrt->env = std::make_shared<Ort::Env>(
124104
OrtLoggingLevel(loggingLevel),
125-
(optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
105+
(envName.empty() ? "ORT" : envName.c_str()),
126106
// Integrate ORT logging into Fairlogger
127107
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
128108
if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
@@ -143,6 +123,10 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
143123
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
144124
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
145125

126+
setIO();
127+
}
128+
129+
void OrtModel::setIO() {
146130
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
147131
mInputNames.push_back((pImplOrt->session)->GetInputNameAllocated(i, pImplOrt->allocator).get());
148132
}
@@ -162,7 +146,6 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
162146
outputNamesChar.resize(mOutputNames.size(), nullptr);
163147
std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
164148
[&](const std::string& str) { return str.c_str(); });
165-
}
166149
if (loggingLevel < 2) {
167150
LOG(info) << "(ORT) Model loaded successfully! (input: " << printShape(mInputShapes[0]) << ", output: " << printShape(mOutputShapes[0]) << ")";
168151
}
@@ -203,18 +186,15 @@ std::vector<O> OrtModel::inference(std::vector<I>& input, int32_t deviceIndex)
203186
{
204187
#if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
205188
if (allocateDeviceMemory) {
206-
if (deviceIndex >= 0) {
207-
deviceId = deviceIndex;
208-
}
209189
std::string dev_mem_str = "";
210190
if (device == "ROCM") {
211191
dev_mem_str = "Hip";
212192
}
213193
if (device == "CUDA") {
214194
dev_mem_str = "Cuda";
215195
}
216-
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
217-
LOG(info) << "(ORT) Memory info set to on-device memory";
196+
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
197+
LOG(info) << "(ORT) Memory info set to on-device memory for device " << device << " with ID "<< deviceIndex;
218198
}
219199
#endif
220200
std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
@@ -241,20 +221,21 @@ template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Fl
241221
template <class I, class O>
242222
void OrtModel::inference(I* input, size_t input_size, O* output, int32_t deviceIndex)
243223
{
224+
// std::vector<std::string> providers = Ort::GetAvailableProviders();
225+
// for (const auto& provider : providers) {
226+
// LOG(info) << "Available Execution Provider: " << provider;
227+
// }
244228
#if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
245229
if (allocateDeviceMemory) {
246-
if (deviceIndex >= 0) {
247-
deviceId = deviceIndex;
248-
}
249230
std::string dev_mem_str = "";
250231
if (device == "ROCM") {
251232
dev_mem_str = "Hip";
252233
}
253234
if (device == "CUDA") {
254235
dev_mem_str = "Cuda";
255236
}
256-
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
257-
LOG(info) << "(ORT) Memory info set to on-device memory";
237+
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
238+
LOG(info) << "(ORT) Memory info set to on-device memory for device " << device << " with ID "<< deviceIndex;
258239
}
259240
#endif
260241
std::vector<int64_t> inputShape{input_size, (int64_t)mInputShapes[0][1]};
@@ -268,7 +249,7 @@ void OrtModel::inference(I* input, size_t input_size, O* output, int32_t deviceI
268249
std::vector<int64_t> outputShape{input_size, mOutputShapes[0][1]};
269250
Ort::Value outputTensor = Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, output, input_size * mOutputShapes[0][1], outputShape.data(), outputShape.size());
270251

271-
(pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), &inputTensor, 1, outputNamesChar.data(), &outputTensor, outputNamesChar.size()); // TODO: Not sure if 1 is always correct here
252+
(pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), &inputTensor, 1, outputNamesChar.data(), &outputTensor, outputNamesChar.size());
272253
}
273254

274255
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, size_t, float*, int32_t);
@@ -280,18 +261,15 @@ std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input, int32_t d
280261
{
281262
#if (defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1) || (defined(ORT_MIGRAPHX_BUILD) && ORT_MIGRAPHX_BUILD == 1) || (defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1)
282263
if (allocateDeviceMemory) {
283-
if (deviceIndex >= 0) {
284-
deviceId = deviceIndex;
285-
}
286264
std::string dev_mem_str = "";
287265
if (device == "ROCM") {
288266
dev_mem_str = "Hip";
289267
}
290268
if (device == "CUDA") {
291269
dev_mem_str = "Cuda";
292270
}
293-
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
294-
LOG(info) << "(ORT) Memory info set to on-device memory";
271+
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
272+
LOG(info) << "(ORT) Memory info set to on-device memory for device " << device << " with ID " << deviceIndex;
295273
}
296274
#endif
297275
std::vector<Ort::Value> inputTensor;

GPU/GPUTracking/Base/GPUReconstructionCPU.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class GPUReconstructionCPU : public GPUReconstructionKernels<GPUReconstructionCP
116116
virtual size_t TransferMemoryInternal(GPUMemoryResource* res, int32_t stream, deviceEvent* ev, deviceEvent* evList, int32_t nEvents, bool toGPU, const void* src, void* dst);
117117

118118
// ONNX runtime
119-
virtual void SetONNXGPUStream(Ort::SessionOptions*, int32_t, int32_t*) {}
119+
virtual void SetONNXGPUStream(Ort::SessionOptions&, int32_t, int32_t*) {}
120120

121121
int32_t InitDevice() override;
122122
int32_t ExitDevice() override;

GPU/GPUTracking/Base/GPUReconstructionProcessing.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
#include <functional>
2323
#include <atomic>
2424

25-
struct OrtSessionOptions;
25+
namespace Ort {
26+
struct SessionOptions;
27+
}
2628

2729
namespace o2::gpu
2830
{
@@ -90,7 +92,7 @@ class GPUReconstructionProcessing : public GPUReconstruction
9092
void AddGPUEvents(T*& events);
9193

9294
virtual std::unique_ptr<gpu_reconstruction_kernels::threadContext> GetThreadContext() override;
93-
virtual void SetONNXGPUStream(OrtSessionOptions*, int32_t, int32_t*) {}
95+
virtual void SetONNXGPUStream(Ort::SessionOptions&, int32_t, int32_t*) {}
9496

9597
struct RecoStepTimerMeta {
9698
HighResTimer timerToGPU;

GPU/GPUTracking/Base/cuda/GPUReconstructionCUDA.cu

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -662,20 +662,19 @@ void GPUReconstructionCUDA::endGPUProfiling()
662662
}
663663

664664
#if defined(ORT_CUDA_BUILD) && ORT_CUDA_BUILD == 1
665-
void GPUReconstructionCUDA::SetONNXGPUStream(Ort::SessionOptions* session_options, int32_t stream, int32_t* deviceId)
665+
void GPUReconstructionCUDA::SetONNXGPUStream(Ort::SessionOptions& session_options, int32_t stream, int32_t* deviceId)
666666
{
667667
cudaGetDevice(deviceId);
668668
OrtCUDAProviderOptionsV2* cuda_options = nullptr;
669669
CreateCUDAProviderOptions(&cuda_options);
670-
OrtSessionOptions* raw_options = session_options->operator OrtSessionOptions*();
671670

672671
// std::vector<const char*> keys{"device_id", "gpu_mem_limit", "arena_extend_strategy", "cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", "cudnn_conv1d_pad_to_nc1d"};
673672
// std::vector<const char*> values{"0", "2147483648", "kSameAsRequested", "DEFAULT", "1", "1", "1"};
674673
// UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), keys.size());
675674

676675
// this implicitly sets "has_user_compute_stream"
677676
UpdateCUDAProviderOptionsWithValue(cuda_options, "user_compute_stream", &mInternals->Streams[stream]);
678-
Ort::ThrowOnError(SessionOptionsAppendExecutionProvider_CUDA_V2(raw_options, cuda_options));
677+
session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
679678

680679
// Finally, don't forget to release the provider options
681680
ReleaseCUDAProviderOptions(cuda_options);
@@ -691,20 +690,23 @@ void* GPUReconstructionHIP::getGPUPointer(void* ptr)
691690
}
692691

693692
#if defined(ORT_ROCM_BUILD) && ORT_ROCM_BUILD == 1
694-
void GPUReconstructionHIP::SetONNXGPUStream(Ort::SessionOptions* session_options, int32_t stream, int32_t* deviceId)
693+
void GPUReconstructionHIP::SetONNXGPUStream(Ort::SessionOptions& session_options, int32_t stream, int32_t* deviceId)
695694
{
696695
// Create ROCm provider options
697696
cudaGetDevice(deviceId);
698697
const auto& api = Ort::GetApi();
699-
OrtROCMProviderOptions rocm_options{};
700-
rocm_options.has_user_compute_stream = 1; // Indicate that we are passing a user stream
701-
rocm_options.user_compute_stream = &mInternals->Streams[stream];
702-
703-
// Get the raw OrtSessionOptions pointer from the Ort::SessionOptions wrapper
704-
OrtSessionOptions* raw_options = session_options->operator OrtSessionOptions*();
705-
706-
// Append the ROCm execution provider with the custom HIP stream
707-
Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_ROCM(raw_options, &rocm_options));
698+
// api.GetCurrentGpuDeviceId(deviceId);
699+
OrtROCMProviderOptions rocm_options;
700+
LOG(info) << "Creating ROCm provider options";
701+
// rocm_options.has_user_compute_stream = 1; // Indicate that we are passing a user stream
702+
// LOG(info) << "Setting user compute stream";
703+
// rocm_options.user_compute_stream = &(mInternals->Streams[stream]);
704+
// LOG(info) << "Stream is set with streamId " << stream << " and reference " << &(mInternals->Streams[stream]);
705+
session_options.AppendExecutionProvider_ROCM(rocm_options);
706+
LOG(info) << "Appending ROCm provider options";
707+
// OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, *deviceId);
708+
// api.ReleaseROCMProviderOptions(rocm_options);
709+
LOG(info) << "Releasing ROCm provider options";
708710
}
709711

710712
#endif // GPUCA_HAS_ONNX

GPU/GPUTracking/Base/cuda/GPUReconstructionCUDA.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class GPUReconstructionCUDA : public GPUReconstructionKernels<GPUReconstructionC
8383
size_t GPUMemCpy(void* dst, const void* src, size_t size, int32_t stream, int32_t toGPU, deviceEvent* ev = nullptr, deviceEvent* evList = nullptr, int32_t nEvents = 1) override;
8484
void ReleaseEvent(deviceEvent ev) override;
8585
void RecordMarker(deviceEvent* ev, int32_t stream) override;
86-
void SetONNXGPUStream(Ort::SessionOptions* session_options, int32_t stream, int32_t* deviceId) override;
86+
void SetONNXGPUStream(Ort::SessionOptions& session_options, int32_t stream, int32_t* deviceId) override;
8787

8888
void GetITSTraits(std::unique_ptr<o2::its::TrackerTraits>* trackerTraits, std::unique_ptr<o2::its::VertexerTraits>* vertexerTraits, std::unique_ptr<o2::its::TimeFrame>* timeFrame) override;
8989

0 commit comments

Comments
 (0)