Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ TRTEngine::TRTEngine(
multi_gpu_device_check();
set_rt_device(device_info);

// Pin the default-stream sentinels to the engine's actual device. The
// header-side initializers default to device 0; without this, the
// ``engine_stream == getDefaultCUDAStream(current_device_id)`` check in
// execute_engine never fires on cuda:N (N>0), so the lazy pool-stream
// re-acquire is skipped and the engine ends up running on device 0's
// default stream. Done before the exec_ctx is built so any later
// device-affinity assertion sees a coherent baseline.
this->engine_stream = c10::cuda::getDefaultCUDAStream(device_info.id);
this->caller_stream = c10::cuda::getDefaultCUDAStream(device_info.id);

rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger()));

name = slugify(mod_name);
Expand Down Expand Up @@ -334,6 +344,82 @@ bool TRTEngine::are_output_tensors_unowned() {
return this->output_tensors_are_unowned;
}

void TRTEngine::set_external_stream(int64_t stream_handle) {
TORCHTRT_CHECK(
stream_handle != 0,
"External stream handle must be non-zero. Use clear_external_stream() to revert to the default stream pool.");
// Reject the legacy / per-thread / null magic stream IDs. cuStreamGetCtx
// accepts these but binding them latches the engine onto a non-isolated
// stream that defeats the whole point of an external binding.
TORCHTRT_CHECK(
stream_handle != reinterpret_cast<int64_t>(cudaStreamLegacy) &&
stream_handle != reinterpret_cast<int64_t>(cudaStreamPerThread),
"set_external_stream: legacy/per-thread/default stream handles are not supported "
"(use clear_external_stream() to revert to the default stream pool).");
auto stream = reinterpret_cast<cudaStream_t>(stream_handle);
// Cheap sanity check that this is a real CUstream.
unsigned int flags = 0;
TORCHTRT_CHECK(
cudaStreamGetFlags(stream, &flags) == cudaSuccess, "set_external_stream: stream handle is not a valid CUstream");
// Device-affinity check: resolve the stream's device and compare against the
// engine's target. Catches cuda:1 stream bound to cuda:0 engine etc. -- without
// this the cross-device launch surfaces as a confusing CUDA error far from the
// bind site. Uses the runtime API (CUDA 12.8+ on x86_64; the symbol is missing
// on Jetpack 12.6 even though its CUDART_VERSION reports >= 12080) so we don't
// drag in libcuda.
#if CUDART_VERSION >= 12080 && !defined(__aarch64__)
int stream_dev = -1;
TORCHTRT_CHECK(
cudaStreamGetDevice(stream, &stream_dev) == cudaSuccess,
"set_external_stream: cudaStreamGetDevice failed for handle " << stream_handle);
TORCHTRT_CHECK(
static_cast<int64_t>(stream_dev) == device_info.id,
"External stream is on device " << stream_dev << " but engine targets device " << device_info.id);
#else
// CUDA < 12.8 (or aarch64 / Jetpack where the symbol isn't shipped even on
// 12.8+) lacks the runtime-API stream-device accessor and the alternative
// driver API would force a libcuda link, which we deliberately avoid for
// Jetpack builds. Warn loudly so callers know the cross-device misuse will
// only surface as a confusing CUDA error at enqueue time.
LOG_WARNING(
"set_external_stream: this build (CUDA < 12.8 or aarch64/Jetpack): "
"device-affinity validation is skipped. Caller is responsible for "
"ensuring the stream's device matches the engine's target device (id="
<< device_info.id << ").");
#endif
// Lock so the cudagraph-vs-external guard in execute_engine sees a
// consistent snapshot for the duration of one execute() call.
std::lock_guard<std::recursive_mutex> lock(mu);
// Fail fast if cudagraphs are currently active -- catches the misuse at
// bind time rather than next execute (faster failure, clearer call site,
// no wasted input migration etc. before the throw).
TORCHTRT_CHECK(
CUDAGRAPHS_MODE != SUBGRAPH_CUDAGRAPHS && CUDAGRAPHS_MODE != WHOLE_GRAPH_CUDAGRAPHS,
"Cannot bind an external stream while CUDA Graphs are enabled. "
"Disable cudagraphs first (torch_tensorrt.runtime.set_cudagraphs_mode(STANDARD)).");
external_stream = stream;
// A previously-captured graph records the prior engine_stream identity;
// replaying it on the new external stream would be UB.
cudagraph.reset();
}

void TRTEngine::clear_external_stream() {
  // Unbind any caller-owned stream; subsequent executes fall back to the
  // lazy pool-stream acquire path.
  std::scoped_lock<std::recursive_mutex> guard(mu);
  external_stream = nullptr;
  // Drop any captured graph: the next cudagraphs-enabled call has to
  // recapture against a pool stream rather than replay against the old one.
  cudagraph.reset();
}

bool TRTEngine::is_external_stream_set() const {
  // True iff a caller-owned stream is currently bound to this engine.
  std::scoped_lock<std::recursive_mutex> guard(mu);
  return static_cast<bool>(external_stream);
}

int64_t TRTEngine::get_external_stream() const {
  // Returns the raw handle of the bound external stream, or 0 when none is
  // bound (prefer is_external_stream_set() over comparing against 0).
  std::scoped_lock<std::recursive_mutex> guard(mu);
  auto handle = reinterpret_cast<int64_t>(external_stream);
  return handle;
}

void TRTEngine::set_profile_format(std::string format) {
if (format == "trex") {
this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX);
Expand Down
14 changes: 13 additions & 1 deletion core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,15 @@ struct TRTEngine : torch::CustomClassHolder {
void set_pre_allocated_outputs(bool enable);
void set_output_tensors_as_unowned(bool enable);
bool are_output_tensors_unowned();
// External CUDA stream binding (e.g., for CUDA Green Contexts, cuda 12.4+).
// Caller owns the stream lifetime and must clear before destroying it.
// Mutually exclusive with CUDA Graphs (throws at execute time).
void set_external_stream(int64_t stream_handle);
void clear_external_stream();
int64_t get_external_stream() const;
// Returns true iff an external stream is currently bound. Avoids the
// ambiguous ``get_external_stream() == 0`` sentinel pattern.
bool is_external_stream_set() const;
TorchTRTRuntimeStates runtime_states;
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';
Expand All @@ -196,6 +205,9 @@ struct TRTEngine : torch::CustomClassHolder {
at::cuda::CUDAGraph cudagraph = {};
at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream();
// Runtime-only state, never serialized. Both fields accessed under `mu`.
cudaStream_t external_stream = nullptr;
bool engine_stream_is_external = false;
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key = "None";
Expand Down Expand Up @@ -252,7 +264,7 @@ struct TRTEngine : torch::CustomClassHolder {
std::string enqueue_profile_path;
std::string trt_engine_profile_path;
std::string cuda_graph_debug_path;
std::mutex mu;
mutable std::recursive_mutex mu;
std::unique_ptr<TRTEngineProfiler> trt_engine_profiler;
ResourceAllocationStrategy resource_allocation_strategy = kStatic;
void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy);
Expand Down
53 changes: 48 additions & 5 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,29 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}

compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
// Create a new stream if the engine stream is the default stream
// Re-resolve every call so set/clear take effect without recreating the engine.
// Precedence: per-engine external_stream > process-wide ENGINE_STREAM_PASSTHROUGH
// > existing pool fallback. The provenance flag forces a pool re-acquire when
// external is cleared, even if engine_stream is no longer the default (would
// otherwise hold a stale wrapper).
if (auto external = compiled_engine->external_stream) {
compiled_engine->engine_stream = c10::cuda::getStreamFromExternal(external, current_device_id);
compiled_engine->engine_stream_is_external = true;
} else if (ENGINE_STREAM_PASSTHROUGH) {
// Honor the caller's current CUDA stream verbatim. Intended for AOTI /
// pt2 C++ deployments where the engine torchbinds aren't reachable
// from outside the AOTIModelPackageLoader; users wrap loader.run() with
// a CUDAStreamGuard bound to e.g. a Green Context stream. Mark as
// ``external`` so that toggling passthrough off later forces a pool
// re-acquire on the next call (otherwise the stale caller-stream
// wrapper would be reused indefinitely).
compiled_engine->engine_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
compiled_engine->engine_stream_is_external = true;
} else if (
compiled_engine->engine_stream_is_external ||
compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
compiled_engine->engine_stream_is_external = false;
}

{ // Engine Execution (execute on engine stream)
Expand Down Expand Up @@ -425,9 +445,20 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}

compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
// Create a new stream if the engine stream is the default stream
// Re-resolve every call so set/clear take effect without recreating the engine.
// Precedence: per-engine external_stream > process-wide ENGINE_STREAM_PASSTHROUGH
// > existing pool fallback. See run_standard_execution above for details.
if (auto external = compiled_engine->external_stream) {
compiled_engine->engine_stream = c10::cuda::getStreamFromExternal(external, current_device_id);
compiled_engine->engine_stream_is_external = true;
} else if (ENGINE_STREAM_PASSTHROUGH) {
compiled_engine->engine_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
compiled_engine->engine_stream_is_external = true;
} else if (
compiled_engine->engine_stream_is_external ||
compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
compiled_engine->engine_stream_is_external = false;
}

{ // Engine Execution (execute on engine stream)
Expand Down Expand Up @@ -494,7 +525,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
<< "); Hardware Compatible: " << compiled_engine->hardware_compatible);
// nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it.
// Other IExecutionContext methods and runtime states should be in same scope as well
std::unique_lock<std::mutex> lock(compiled_engine->mu);
std::unique_lock<std::recursive_mutex> lock(compiled_engine->mu);
if (compiled_engine->profile_execution) {
std::stringstream ss;
ss << "Execution profiling is enabled, find results here:" << std::endl;
Expand All @@ -511,6 +542,18 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}
bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);

// Captured graphs record the engine_stream identity; replay after the
// caller-owned stream is destroyed is UB. The same applies to the process-wide
// passthrough mode, where every call may resolve to a different caller stream.
TORCHTRT_CHECK(
!(cudagraphs_enabled && compiled_engine->external_stream != nullptr),
"CUDA Graphs are not supported when an external stream is set on the engine. "
"Disable cudagraphs or call clear_external_stream() first.");
TORCHTRT_CHECK(
!(cudagraphs_enabled && ENGINE_STREAM_PASSTHROUGH),
"CUDA Graphs are not supported while engine-stream passthrough is enabled. "
"Disable cudagraphs or call set_engine_stream_passthrough(False) first.");

if (MULTI_DEVICE_SAFE_MODE) {
std::unique_ptr<torch::autograd::profiler::RecordProfile> device_profiler_guard;
if (compiled_engine->profile_execution) {
Expand Down
8 changes: 8 additions & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
.def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned)
.def("are_output_tensors_unowned", &TRTEngine::are_output_tensors_unowned)
.def("set_external_stream", &TRTEngine::set_external_stream)
.def("clear_external_stream", &TRTEngine::clear_external_stream)
.def("get_external_stream", &TRTEngine::get_external_stream)
.def("is_external_stream_set", &TRTEngine::is_external_stream_set)
.def(
"use_dynamically_allocated_resources",
[](const c10::intrusive_ptr<TRTEngine>& self, bool dynamic) -> void {
Expand Down Expand Up @@ -170,6 +174,10 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
});
m.def("get_engine_stream_passthrough", []() -> bool { return ENGINE_STREAM_PASSTHROUGH; });
m.def("set_engine_stream_passthrough", [](bool engine_stream_passthrough) -> void {
ENGINE_STREAM_PASSTHROUGH = engine_stream_passthrough;
});
m.def("get_cudagraphs_mode", []() -> int64_t { return CUDAGRAPHS_MODE; });
m.def("set_cudagraphs_mode", [](int64_t cudagraphs_mode) -> void {
CUDAGRAPHS_MODE = CudaGraphsMode(cudagraphs_mode);
Expand Down
9 changes: 9 additions & 0 deletions core/runtime/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace core {
namespace runtime {

bool MULTI_DEVICE_SAFE_MODE = false;
bool ENGINE_STREAM_PASSTHROUGH = false;
CudaGraphsMode CUDAGRAPHS_MODE = STANDARD;

c10::optional<RTDevice> get_most_compatible_device(
Expand Down Expand Up @@ -130,6 +131,14 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode) {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
}

// Read the process-wide engine-stream passthrough flag.
bool get_engine_stream_passthrough() {
  const bool enabled = ENGINE_STREAM_PASSTHROUGH;
  return enabled;
}

// Toggle the process-wide engine-stream passthrough flag; takes effect on the
// next execute_engine call (the engine stream is re-resolved every call).
void set_engine_stream_passthrough(bool enabled) {
  ENGINE_STREAM_PASSTHROUGH = enabled;
}

CudaGraphsMode get_cudagraphs_mode() {
return CUDAGRAPHS_MODE;
}
Expand Down
14 changes: 14 additions & 0 deletions core/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ namespace runtime {
using EngineID = int64_t;
const std::string ABI_VERSION = "9";
extern bool MULTI_DEVICE_SAFE_MODE;
// Process-wide opt-in: when true, ``execute_engine`` honors the caller's
// current CUDA stream (``c10::cuda::getCurrentCUDAStream``) instead of the
// usual lazy pool-stream acquire. Intended for AOTI / ``.pt2`` C++ deployments
// where the TRT engine torchbind constants cannot be reached from outside the
// ``AOTIModelPackageLoader`` (PyTorch does not expose them) and so the
// per-engine ``set_external_stream`` API is not callable. With this flag on,
// users wrap ``loader.run(...)`` in a ``c10::cuda::CUDAStreamGuard`` bound to
// e.g. a CUDA Green Context stream and the engine inherits it. Mutually
// exclusive with CUDA Graphs (asserted at execute time).
extern bool ENGINE_STREAM_PASSTHROUGH;

typedef enum {
STANDARD = 0,
Expand Down Expand Up @@ -61,6 +71,10 @@ bool get_multi_device_safe_mode();

void set_multi_device_safe_mode(bool multi_device_safe_mode);

bool get_engine_stream_passthrough();

void set_engine_stream_passthrough(bool engine_stream_passthrough);

CudaGraphsMode get_cudagraphs_mode();

void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode);
Expand Down
Loading
Loading