-
Notifications
You must be signed in to change notification settings - Fork 395
feat(runtime): file-lock the TRT-RTX runtime cache #4237
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
tp5uiuc
wants to merge
3
commits into
pytorch:main
Choose a base branch
from
tp5uiuc:feat/runtime-cache-file-lock
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| #include <algorithm> | ||
| #include <filesystem> | ||
|
|
||
| #include <cuda_runtime.h> | ||
| #include "NvInfer.h" | ||
|
|
@@ -61,26 +62,28 @@ void DynamicOutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims | |
| } | ||
|
|
||
| TRTEngine::TRTEngine( | ||
| const std::string& serialized_engine, | ||
| std::string serialized_engine, | ||
| const RTDevice& cuda_device, | ||
| const std::vector<std::string>& _in_binding_names, | ||
| const std::vector<std::string>& _out_binding_names, | ||
| const Platform& target_platform, | ||
| bool hardware_compatible, | ||
| bool requires_output_allocator, | ||
| const std::string& serialized_metadata, | ||
| const ResourceAllocationStrategy resource_allocation_strategy) | ||
| std::string serialized_metadata, | ||
| const ResourceAllocationStrategy resource_allocation_strategy, | ||
| TRTRuntimeConfig runtime_cfg) | ||
| : TRTEngine( | ||
| "deserialized_trt", | ||
| serialized_engine, | ||
| std::move(serialized_engine), | ||
| cuda_device, | ||
| _in_binding_names, | ||
| _out_binding_names, | ||
| target_platform, | ||
| hardware_compatible, | ||
| requires_output_allocator, | ||
| serialized_metadata, | ||
| resource_allocation_strategy) {} | ||
| std::move(serialized_metadata), | ||
| resource_allocation_strategy, | ||
| std::move(runtime_cfg)) {} | ||
|
|
||
| TRTEngine::TRTEngine(std::vector<std::string> serialized_info) | ||
| : TRTEngine( | ||
|
|
@@ -95,24 +98,27 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info) | |
| serialized_info[SERIALIZED_METADATA_IDX], | ||
| (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) | ||
| ? ResourceAllocationStrategy::kDynamic | ||
| : ResourceAllocationStrategy::kStatic)) { | ||
| : ResourceAllocationStrategy::kStatic), | ||
| make_runtime_config_from_serialized(serialized_info)) { | ||
| this->requires_native_multidevice = std::stoi(serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]); | ||
| if (this->requires_native_multidevice) { | ||
| LOG_INFO("Loaded distributed TRT engine (contains NCCL collectives); NCCL comm will be bound on first execution"); | ||
| } | ||
| } | ||
|
|
||
| TRTEngine::TRTEngine( | ||
| const std::string& mod_name, | ||
| const std::string& serialized_engine, | ||
| std::string mod_name, | ||
| std::string serialized_engine, | ||
| const RTDevice& cuda_device, | ||
| const std::vector<std::string>& _in_binding_names, | ||
| const std::vector<std::string>& _out_binding_names, | ||
| const Platform& target_platform, | ||
| bool hardware_compatible, | ||
| bool requires_output_allocator, | ||
| const std::string& serialized_metadata, | ||
| const ResourceAllocationStrategy resource_allocation_strategy) { | ||
| std::string serialized_metadata, | ||
| const ResourceAllocationStrategy resource_allocation_strategy, | ||
| TRTRuntimeConfig runtime_cfg) { | ||
| this->runtime_cfg = std::move(runtime_cfg); | ||
| TORCHTRT_CHECK( | ||
| is_supported_on_current_platform(target_platform), | ||
| "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " | ||
|
|
@@ -123,15 +129,15 @@ TRTEngine::TRTEngine( | |
| auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); | ||
| TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); | ||
|
|
||
| this->serialized_metadata = serialized_metadata; | ||
| this->serialized_metadata = std::move(serialized_metadata); | ||
| this->requires_output_allocator = requires_output_allocator; | ||
| device_info = most_compatible_device.value(); | ||
| multi_gpu_device_check(); | ||
| set_rt_device(device_info); | ||
|
|
||
| rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger())); | ||
|
|
||
| name = slugify(mod_name); | ||
| name = slugify(std::move(mod_name)); | ||
|
|
||
| cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size())); | ||
| TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine"); | ||
|
|
@@ -146,13 +152,7 @@ TRTEngine::TRTEngine( | |
| LOG_DEBUG( | ||
| "Resource allocation strategy: " | ||
| << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static")); | ||
| if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { | ||
| this->exec_ctx = | ||
| make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); | ||
| } else { | ||
| this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); | ||
| } | ||
| TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context"); | ||
| recreate_execution_context(); | ||
|
|
||
| // Pre-allocate placeholder for empty tensors (TensorRT requires non-null addresses) | ||
| cudaMalloc(&empty_tensor_placeholder, 1); | ||
|
|
@@ -288,6 +288,9 @@ TRTEngine::TRTEngine( | |
| } | ||
|
|
||
| TRTEngine::~TRTEngine() { | ||
| // Marked noexcept so safe to invoke from a destructor without | ||
| // explicit try/catch; any I/O error is logged internally. | ||
| runtime_cfg.save_runtime_cache(); | ||
| trt_engine_profiler.reset(); | ||
| exec_ctx.reset(); | ||
| cuda_engine.reset(); | ||
|
|
@@ -301,8 +304,7 @@ void TRTEngine::disable_profiling() { | |
| torch::cuda::synchronize(device_info.id); | ||
| profile_execution = false; | ||
| trt_engine_profiler.reset(); | ||
| exec_ctx = make_trt(cuda_engine->createExecutionContext()); | ||
| TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context"); | ||
| recreate_execution_context(); | ||
| } | ||
|
|
||
| void TRTEngine::dump_engine_layer_info_to_file(const std::string& path) { | ||
|
|
@@ -399,10 +401,7 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) { | |
| trt_engine_profiler.reset(); | ||
| } | ||
| bool result = cuda_engine->setWeightStreamingBudgetV2(budget); | ||
| exec_ctx = make_trt(cuda_engine->createExecutionContext()); | ||
| TORCHTRT_CHECK( | ||
| (exec_ctx.get() != nullptr), | ||
| "Unable to recreate TensorRT execution context after setting new device memory budget"); | ||
| recreate_execution_context(); | ||
| if (profile_execution) { | ||
| enable_profiling(); | ||
| } | ||
|
|
@@ -459,6 +458,7 @@ std::string TRTEngine::to_str() const { | |
| ss << " Target Platform: " << target_platform << std::endl; | ||
| ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl; | ||
| ss << " Multi-Device Engine: " << (requires_native_multidevice) << std::endl; | ||
| ss << runtime_cfg.to_str(); | ||
| // clang-format on | ||
| return ss.str(); | ||
| } | ||
|
|
@@ -495,7 +495,14 @@ FlattenedState TRTEngine::__obj_flatten__() { | |
| std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), | ||
| std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), | ||
| std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]), | ||
| std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX])); | ||
| std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]) | ||
| #ifdef TRT_MAJOR_RTX | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See above comment |
||
| , | ||
| std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]), | ||
| std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), | ||
| std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX]) | ||
| #endif | ||
| ); | ||
| } | ||
|
|
||
| std::vector<std::string> TRTEngine::serialize() { | ||
|
|
@@ -522,6 +529,13 @@ std::vector<std::string> TRTEngine::serialize() { | |
| this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; | ||
| serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = this->requires_native_multidevice ? "1" : "0"; | ||
| // rank/world_size are runtime facts (may differ at load time); not serialized. | ||
| #ifdef TRT_MAJOR_RTX | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here |
||
| serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path; | ||
| serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string( | ||
| static_cast<std::underlying_type_t<DynamicShapesKernelStrategy>>(runtime_cfg.dynamic_shapes_kernel_strategy)); | ||
| serialized_info[CUDA_GRAPH_STRATEGY_IDX] = | ||
| std::to_string(static_cast<std::underlying_type_t<CudaGraphStrategyOption>>(runtime_cfg.cuda_graph_strategy)); | ||
| #endif | ||
|
|
||
| return serialized_info; | ||
| } | ||
|
|
@@ -533,14 +547,11 @@ void TRTEngine::reset_captured_graph() { | |
| void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) { | ||
| if (new_strategy != this->resource_allocation_strategy) { | ||
| this->resource_allocation_strategy = new_strategy; | ||
| if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) { | ||
| LOG_DEBUG("Setting resource allocation strategy to dynamic"); | ||
| this->exec_ctx = | ||
| make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); | ||
| } else { | ||
| LOG_DEBUG("Setting resource allocation strategy to static"); | ||
| this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); | ||
| } | ||
| LOG_DEBUG( | ||
| "Setting resource allocation strategy to " | ||
| << (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic ? "dynamic" | ||
| : "static")); | ||
| recreate_execution_context(); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -637,19 +648,42 @@ void TRTEngine::release_nccl_comm() { | |
| LOG_INFO("Releasing NCCL communicator from engine '" << this->name << "'"); | ||
| torch::cuda::synchronize(device_info.id); | ||
| this->exec_ctx.reset(); | ||
| if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { | ||
| this->exec_ctx = | ||
| make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); | ||
| } else { | ||
| this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); | ||
| } | ||
| TORCHTRT_CHECK( | ||
| (exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context after releasing NCCL comm"); | ||
| recreate_execution_context(); | ||
| this->nccl_initialized = false; | ||
| LOG_INFO("NCCL communicator released from engine '" << this->name << "'"); | ||
| } | ||
| #endif // ENABLE_TRT_NCCL_COLLECTIVES | ||
|
|
||
| bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const { | ||
| return runtime_cfg.is_monolithic_capturable(exec_ctx.get(), stream); | ||
| } | ||
|
|
||
| void TRTEngine::disable_rtx_native_cudagraphs() { | ||
| bool was_disabled = runtime_cfg.rtx_native_cudagraphs_disabled; | ||
| runtime_cfg.disable_rtx_native_cudagraphs(name); | ||
| if (!was_disabled && runtime_cfg.rtx_native_cudagraphs_disabled) { | ||
| // The CUDA graph strategy on the IRuntimeConfig has been flipped; rebuild exec_ctx | ||
| // so the new strategy takes effect for subsequent enqueueV3 calls. | ||
| recreate_execution_context(); | ||
| } | ||
| } | ||
|
|
||
| void TRTEngine::recreate_execution_context() { | ||
| // Flush any kernels the previous execution context may have compiled into the | ||
| // runtime cache before creating the replacement. The destructor also saves, but | ||
| // doing it here guards against losing compiled kernels across profiling toggles, | ||
| // allocator changes, or process kills that happen between allocator changes and | ||
| // teardown. No-op on standard TensorRT or when no cache path is configured. | ||
| runtime_cfg.save_runtime_cache(); | ||
| runtime_cfg.ensure_initialized(cuda_engine.get()); | ||
| runtime_cfg.set_execution_context_allocation_strategy( | ||
| resource_allocation_strategy == ResourceAllocationStrategy::kDynamic | ||
| ? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED | ||
| : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC); | ||
| exec_ctx = make_trt(cuda_engine->createExecutionContext(runtime_cfg.config.get())); | ||
| TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); | ||
| } | ||
|
|
||
| } // namespace runtime | ||
| } // namespace core | ||
| } // namespace torch_tensorrt | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why does this need to be a deep copy?