|
16 | 16 | #include "GPUReconstructionCUDAIncludesHost.h" |
17 | 17 |
|
18 | 18 | #include <cuda_profiler_api.h> |
| 19 | +#include "ML/OrtInterface.h" |
19 | 20 |
|
20 | 21 | #include "GPUReconstructionCUDA.h" |
21 | 22 | #include "GPUReconstructionCUDAInternals.h" |
|
35 | 36 | #undef GPUCA_KRNL |
36 | 37 | #endif |
37 | 38 |
|
| 39 | +#ifdef GPUCA_HAS_ONNX |
| 40 | +#include <onnxruntime_cxx_api.h> |
| 41 | +#endif |
| 42 | + |
// Minimum total GPU memory (1 GiB) required for the reconstruction — presumably
// checked at device selection; verify against the caller of these constants.
static constexpr size_t REQUIRE_MIN_MEMORY = 1024L * 1024 * 1024;
// 512 MiB of GPU memory kept back as a global reserve.
static constexpr size_t REQUIRE_MEMORY_RESERVED = 512L * 1024 * 1024;
// 40 MiB of free memory reserved per streaming multiprocessor.
static constexpr size_t REQUIRE_FREE_MEMORY_RESERVED_PER_SM = 40L * 1024 * 1024;
@@ -656,13 +661,50 @@ void GPUReconstructionCUDA::endGPUProfiling() |
656 | 661 | { |
657 | 662 | GPUChkErr(cudaProfilerStop()); |
658 | 663 | } |
| 664 | + |
#ifdef GPUCA_HAS_ONNX
// Attach this reconstruction's CUDA stream to an ONNX Runtime session, so that
// ORT inference is enqueued on the same stream as the GPU reconstruction work.
// session_options: ORT session options the CUDA execution provider is appended to.
// stream: index into mInternals->Streams selecting the user compute stream.
// Returns 0 on success; throws Ort::Exception if any ONNX Runtime call fails.
int32_t GPUReconstructionCUDA::SetONNXGPUStream(OrtSessionOptions* session_options, int32_t stream)
{
  // The V2 CUDA provider-option functions are members of OrtApi and return a
  // status that must be checked.
  const OrtApi& api = Ort::GetApi();
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));

  // std::vector<const char*> keys{"device_id", "gpu_mem_limit", "arena_extend_strategy", "cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", "cudnn_conv1d_pad_to_nc1d"};
  // std::vector<const char*> values{"0", "2147483648", "kSameAsRequested", "DEFAULT", "1", "1", "1"};
  // UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), keys.size());

  try {
    // This implicitly sets "has_user_compute_stream".
    Ort::ThrowOnError(api.UpdateCUDAProviderOptionsWithValue(cuda_options, "user_compute_stream", &mInternals->Streams[stream]));
    Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options));
  } catch (...) {
    // Release the provider options even when ThrowOnError throws, to avoid a leak.
    api.ReleaseCUDAProviderOptions(cuda_options);
    throw;
  }
  api.ReleaseCUDAProviderOptions(cuda_options);

  return 0;
}
#endif // GPUCA_HAS_ONNX
| 685 | + |
659 | 686 | #else // HIP |
660 | 687 | void* GPUReconstructionHIP::getGPUPointer(void* ptr) |
661 | 688 | { |
662 | 689 | void* retVal = nullptr; |
663 | 690 | GPUChkErr(hipHostGetDevicePointer(&retVal, ptr, 0)); |
664 | 691 | return retVal; |
665 | 692 | } |
| 693 | + |
#ifdef GPUCA_HAS_ONNX
// Attach this reconstruction's HIP stream to an ONNX Runtime session, so that
// ORT inference is enqueued on the same stream as the GPU reconstruction work.
// session_options: ORT session options the ROCm execution provider is appended to.
// stream: index into mInternals->Streams selecting the user compute stream.
// Returns 0 on success; throws Ort::Exception if the ONNX Runtime call fails.
// NOTE: qualified with GPUReconstructionHIP (not ...CUDA) to match the other
// definitions in this HIP branch, e.g. GPUReconstructionHIP::getGPUPointer.
int32_t GPUReconstructionHIP::SetONNXGPUStream(OrtSessionOptions* session_options, int32_t stream)
{
  // Create ROCm provider options carrying the user-supplied HIP stream.
  const auto& api = Ort::GetApi();
  OrtROCMProviderOptions rocm_options{};
  rocm_options.has_user_compute_stream = 1; // Indicate that we are passing a user stream
  rocm_options.user_compute_stream = &mInternals->Streams[stream];

  // Append the ROCm execution provider with the custom HIP stream
  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_ROCM(session_options, &rocm_options));
  return 0;
}
#endif // GPUCA_HAS_ONNX
666 | 708 | #endif // __HIPCC__ |
667 | 709 |
|
668 | 710 | namespace o2::gpu |
|
0 commit comments