Skip to content

Commit e9b2d16

Browse files
committed
CUDA ORT: Must use api struct to call functions
1 parent 427e840 commit e9b2d16

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

GPU/GPUTracking/Base/cuda/GPUReconstructionCUDA.cu

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -621,24 +621,34 @@ void GPUReconstructionCUDA::loadKernelModules(bool perKernel)
621621
}
622622
}
623623

624+
#define ORTCHK(command) \
625+
{ \
626+
OrtStatus* status = command; \
627+
if (status != nullptr) { \
628+
const char* msg = api->GetErrorMessage(status); \
629+
GPUFatal("ONNXRuntime Error: %s", msg); \
630+
} \
631+
}
632+
624633
void GPUReconstructionCUDA::SetONNXGPUStream(Ort::SessionOptions& session_options, int32_t stream, int32_t* deviceId)
625634
{
626635
GPUChkErr(cudaGetDevice(deviceId));
627636
#if !defined(__HIPCC__) && defined(ORT_CUDA_BUILD)
637+
const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
628638
OrtCUDAProviderOptionsV2* cuda_options = nullptr;
629-
CreateCUDAProviderOptions(&cuda_options);
639+
ORTCHK(api->CreateCUDAProviderOptions(&cuda_options));
630640

631641
// std::vector<const char*> keys{"device_id", "gpu_mem_limit", "arena_extend_strategy", "cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", "cudnn_conv1d_pad_to_nc1d"};
632642
// std::vector<const char*> values{"0", "2147483648", "kSameAsRequested", "DEFAULT", "1", "1", "1"};
633643
// UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), keys.size());
634644

635645
// this implicitly sets "has_user_compute_stream"
636-
cuda_options.has_user_compute_stream = 1;
637-
UpdateCUDAProviderOptionsWithValue(cuda_options, "user_compute_stream", mInternals->Streams[stream]);
646+
cuda_options->has_user_compute_stream = 1;
647+
ORTCHK(api->UpdateCUDAProviderOptionsWithValue(cuda_options, "user_compute_stream", mInternals->Streams[stream]));
638648
session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
639649

640650
// Finally, don't forget to release the provider options
641-
ReleaseCUDAProviderOptions(cuda_options);
651+
api->ReleaseCUDAProviderOptions(cuda_options);
642652
#elif defined(ORT_ROCM_BUILD)
643653
// const auto& api = Ort::GetApi();
644654
// api.GetCurrentGpuDeviceId(deviceId);

0 commit comments

Comments
 (0)