
feat(runtime): support binding TRTEngine execution to an external CUDA stream#4232

Open
shoumikhin wants to merge 10 commits into pytorch:main from shoumikhin:green-context-external-stream-upstream

Conversation


shoumikhin commented May 1, 2026

Summary

Adds opt-in support for binding torch-tensorrt's TRT engine execution to externally-managed CUDA streams, typically streams created with cuGreenCtxStreamCreate for SM partitioning via CUDA Green Contexts (CUDA 12.4+). The motivating workload is edge / on-device multi-tenant inference on Jetson-class hardware, where a vision encoder + policy net + diffusion head all share one process and need disjoint SM partitions to avoid time-slicing.

Currently, core/runtime/execute_engine.cpp lazily pulls a stream from torch's global stream pool on first execute. That pool is bound to the primary CUDA context, so even when a caller sets a green-context-bound stream as current (via c10::cuda::CUDAStreamGuard / torch.cuda.stream(...)), the TRT engine bypasses it and uses a primary-context pool stream — defeating any SM partitioning the caller set up.

Purely additive: no behavior change for callers that don't opt in.

Two complementary mechanisms

The PR ships two ways to bind a stream, sized to two different deployment shapes:

(1) Per-engine binding — for Python / dynamo / output_format="exported_program"

Reach the TRTEngine torchbind through the wrapping nn.Module's named_modules() and bind a stream per engine. This is the canonical multi-engine SM-partitioning case where one compiled model contains several TRT subgraphs that should each run on a distinct green context.

New C++ API on TRTEngine (exposed via torchbind):

void TRTEngine::set_external_stream(int64_t stream_handle);  // reinterpret_cast<int64_t>(cudaStream_t)
void TRTEngine::clear_external_stream();
int64_t TRTEngine::get_external_stream() const;
bool TRTEngine::is_external_stream_set() const;

Reachable from Python and external C++ via torch.classes.tensorrt.Engine.
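For orientation, a minimal external-C++ call-site sketch (assuming `engine` is a handle to a live TRTEngine torchbind instance, e.g. obtained via the get_custom_objs() follow-up discussed at the end of this thread, and `raw_green_stream` is a hypothetical live CUstream):

// Sketch only: `engine` and `raw_green_stream` are assumed to exist.
// The handle encoding is the raw cudaStream_t reinterpreted as int64_t.
engine->set_external_stream(reinterpret_cast<int64_t>(raw_green_stream));
// ... run inference; the engine enqueues onto the bound stream ...
if (engine->is_external_stream_set()) {
  engine->clear_external_stream();
}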

New Python facade with RAII context-manager semantics:

import torch_tensorrt
from torch_tensorrt.runtime import set_external_stream, clear_external_stream

# Single stream bound to every TRT submodule
with set_external_stream(model, my_stream):
    out = model(x)            # restored on exit

# Per-engine binding (the canonical green-context case)
with set_external_stream(model, {
    "_run_on_acc_0": vision_encoder_stream,    # SM partition A
    "_run_on_acc_1": policy_net_stream,        # SM partition B
    "_run_on_acc_2": diffusion_head_stream,    # SM partition C
}):
    out = model(x)

set_external_stream walks named_modules() recursively, so deeply nested TRT submodules (e.g. HF blocks under wrapper GraphModules) are reachable. Submodule names are dotted paths, validated up front so a bad value cannot leave a partially-bound module. The setter validates the stream's device affinity against the engine's target device (via cuStreamGetCtx + cuCtxGetDevice) and rejects the legacy / per-thread magic stream IDs (cudaStreamLegacy / cudaStreamPerThread). The binding is applied atomically across multiple engines: any per-engine failure rolls back the successfully-applied bindings before re-raising.
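In sketch form, the device-affinity check could look like this (illustrative only, using the driver-API sequence named above; device_of_stream is a hypothetical helper, not the actual TRTEngine code):

#include <cuda.h>
#include <stdexcept>

// Resolve the device owning `stream` so it can be compared against
// engine.device_info.id before binding. cuCtxGetDevice reports the
// *current* context's device, so the stream's context is pushed first.
static int device_of_stream(CUstream stream) {
  CUcontext ctx = nullptr;
  if (cuStreamGetCtx(stream, &ctx) != CUDA_SUCCESS) {
    throw std::runtime_error("cuStreamGetCtx failed");
  }
  CUdevice dev = 0;
  cuCtxPushCurrent(ctx);
  CUresult res = cuCtxGetDevice(&dev);
  cuCtxPopCurrent(&ctx);
  if (res != CUDA_SUCCESS) {
    throw std::runtime_error("cuCtxGetDevice failed");
  }
  return static_cast<int>(dev);
}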

(2) Process-wide stream passthrough — for AOTI / .pt2 C++ deployments

When the model is exported with output_format="aot_inductor" and consumed in pure C++ via torch::inductor::AOTIModelPackageLoader, the live TRTEngine torchbind instances live inside the private OSSProxyExecutor::custom_objs_ map, with no public PyTorch accessor. Re-parsing the .pt2 only yields independent IValue copies that the running .so never invokes, so the per-engine API in (1) is unreachable.

The fix: a process-wide opt-in flag that makes execute_engine honor the caller's current CUDA stream instead of the lazy pool stream. Users wrap loader.run(...) in a CUDAStreamGuard and the engine inherits it.

New globals (C++ + Python):

namespace torch_tensorrt::core::runtime {
  bool get_engine_stream_passthrough();
  void set_engine_stream_passthrough(bool);
}
torch_tensorrt.runtime.set_engine_stream_passthrough(True)
torch_tensorrt.runtime.get_engine_stream_passthrough()

C++ usage after merge:

#include <ATen/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include "torch/csrc/inductor/aoti_package/model_package_loader.h"

torch::inductor::AOTIModelPackageLoader loader("model.pt2");

// Process-wide opt-in — set once, applies to every loaded engine.
torch_tensorrt::core::runtime::set_engine_stream_passthrough(true);

// Carve out an SM partition with a Green Context and create a stream on it.
CUgreenCtx green_ctx;
CUstream raw_green_stream;
// ... cuDevSmResourceSplitByCount + cuGreenCtxCreate + cuGreenCtxStreamCreate ...

auto stream = c10::cuda::getStreamFromExternal(raw_green_stream, /*device=*/0);
{
  c10::cuda::CUDAStreamGuard guard(stream);
  auto out = loader.run(inputs);   // TRT engine inherits the guarded stream
}
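The elided green-context setup might look like the following, reusing the green_ctx / raw_green_stream declared above (a sketch against the CUDA 12.4+ driver API; the split size, flags, and omitted error handling are illustrative, not part of this PR):

CUdevice dev;
cuDeviceGet(&dev, 0);

// Query the device's SM resource and split off one fixed-size group.
CUdevResource sm_res, partition, remaining;
cuDeviceGetDevResource(dev, &sm_res, CU_DEV_RESOURCE_TYPE_SM);
unsigned int nb_groups = 1;
cuDevSmResourceSplitByCount(&partition, &nb_groups, &sm_res, &remaining,
                            0, /*minCount=*/16);

// Describe the partition, create the green context, and bind a stream to it.
CUdevResourceDesc desc;
cuDevResourceGenerateDesc(&desc, &partition, 1);
cuGreenCtxCreate(&green_ctx, desc, dev, CU_GREEN_CTX_DEFAULT_STREAM);
cuGreenCtxStreamCreate(&raw_green_stream, green_ctx,
                       CU_STREAM_NON_BLOCKING, /*priority=*/0);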

Python usage after merge (also valid for AOTI-loaded models via torch._inductor.aoti_load_package):

import torch
import torch_tensorrt

torch_tensorrt.runtime.set_engine_stream_passthrough(True)

green_stream = torch.cuda.Stream(device=0)   # or torch.cuda.ExternalStream(raw_ptr, device=0) to wrap a green-ctx CUstream
with torch.cuda.stream(green_stream):
    out = aoti_model(x)

Precedence

When multiple sources are configured, the resolver picks in this order on every call, so set / clear take effect immediately without recreating the engine (a sketch follows the list):

  1. Per-engine external_stream (set via TRTEngine::set_external_stream)
  2. Process-wide ENGINE_STREAM_PASSTHROUGH → caller's current CUDA stream
  3. Existing pool fallback (getStreamFromPool) — unchanged default behavior
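In sketch form (paraphrased; resolve_stream is an illustrative helper, not the literal execute_engine.cpp code):

#include <c10/cuda/CUDAStream.h>

c10::cuda::CUDAStream resolve_stream(TRTEngine& engine, c10::DeviceIndex device_id) {
  // 1. Per-engine binding wins.
  if (engine.is_external_stream_set()) {
    return c10::cuda::getStreamFromExternal(
        reinterpret_cast<cudaStream_t>(engine.get_external_stream()), device_id);
  }
  // 2. Process-wide passthrough: honor the caller's current stream.
  if (torch_tensorrt::core::runtime::get_engine_stream_passthrough()) {
    return c10::cuda::getCurrentCUDAStream(device_id);
  }
  // 3. Unchanged default: pool fallback.
  return c10::cuda::getStreamFromPool(false, device_id);
}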

Mutual exclusion with CUDA Graphs

Both mechanisms are mutually exclusive with CUDA Graphs. The check fires at bind time (set_external_stream / set_engine_stream_passthrough(true) throw if cudagraphs are currently enabled) and again at execute time as defense-in-depth (covers cudagraphs being enabled after the binding):

CUDA Graphs are not supported when an external stream is set on the engine.
Disable cudagraphs or call clear_external_stream() first.

CUDA Graphs are not supported while engine-stream passthrough is enabled.
Disable cudagraphs or call set_engine_stream_passthrough(False) first.

The setter and clearer also invalidate any captured graph (cudagraph.reset()) so a subsequent recapture happens cleanly and never replays against a stale stream identity.

Multi-GPU correctness fix folded in

TRTEngine::engine_stream and TRTEngine::caller_stream are now pinned to the engine's actual device_info.id in the constructor body. The in-class initializers at TRTEngine.h:211-212 default to device 0 (no device arg). Without this fix, the lazy pool re-acquire in execute_engine checked engine_stream == getDefaultCUDAStream(current_device_id) — always false on cuda:N for N>0 — so the engine ran on cuda:0's default stream regardless of the input device. Pre-existing bug; fixed here while we were in the area.
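The fix itself is small; in sketch form (field names as described above, not the literal constructor):

// In the TRTEngine constructor body: pin both streams to the engine's
// actual device, overriding the device-0 in-class defaults.
engine_stream = c10::cuda::getDefaultCUDAStream(device_info.id);
caller_stream = c10::cuda::getDefaultCUDAStream(device_info.id);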

Files changed

  • core/runtime/TRTEngine.{h,cpp}: per-engine setter / clearer / getter, external_stream + engine_stream_is_external fields, cudagraph invalidation in setter & clearer, multi-GPU default-stream init in ctor
  • core/runtime/execute_engine.cpp: stream-resolve sites in both lambdas (regular + output-allocator paths) with per-engine + passthrough + pool-fallback precedence; cudagraph mutual-exclusion guards for both
  • core/runtime/runtime.{h,cpp}: ENGINE_STREAM_PASSTHROUGH global + get_/set_engine_stream_passthrough() accessors
  • core/runtime/register_jit_hooks.cpp: torchbind exposure for the per-engine stream methods + the two passthrough globals
  • py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py: Python runtime parity
  • py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py: passthrough on the C++-backed runtime
  • py/torch_tensorrt/runtime/_external_stream.py: top-level facade + context manager + passthrough toggles
  • py/torch_tensorrt/runtime/__init__.py: re-export
  • tests/py/dynamo/runtime/test_006_external_stream.py: both runtimes (swap, clear, per-engine binding, cudagraph guard, validation, passthrough routing, passthrough+cudagraph mutex)

Test plan

  • pytest tests/py/dynamo/runtime/test_006_external_stream.py — covers both PythonTorchTensorRTModule and TorchTensorRTModule runtime classes, including the new passthrough tests.
  • GPU runtime test on H100 / Hopper: register a green-context-bound stream, run a small TRT engine, verify via nsys profile that kernel launches are confined to the green context's SM partition.
  • GPU runtime test on Jetson Thor (Blackwell, sm_110): same as above.
  • AOTI C++ test: set_engine_stream_passthrough(true) + wrap AOTIModelPackageLoader::run() with a CUDAStreamGuard on a green-context stream, verify SM-partitioned execution in nsys.

Out of scope (future PRs)

  • Upstream PyTorch PR to add AOTIModelPackageLoader::get_custom_objs() so AOTI users can also use the per-engine API (when they want different streams per submodule inside one .pt2). The passthrough flag in this PR is the interim mechanism while that lands and reaches stable.
  • torch_tensorrt::aoti::TRTAOTILoader C++ wrapper behind a TORCH_TRT_HAVE_AOTI_CUSTOM_OBJS CMake probe — depends on the upstream PR.
  • NCCL + green context interaction. Distributed (NCCL collectives) on green-context-partitioned streams is not validated; the existing requires_native_multidevice path may need follow-up if a user combines both.


shoumikhin commented May 1, 2026

Long-term plan: upstream PyTorch PR

Opened pytorch/pytorch#182149 to add AOTIModelPackageLoader::get_custom_objs(). Once it lands and reaches a tagged PyTorch release, AOTI / .pt2 C++ users can reach the live TRTEngine torchbind instances inside the loaded .so and use the per-engine set_external_stream API directly (no need for the process-wide set_engine_stream_passthrough flag).

torch::inductor::AOTIModelPackageLoader loader("model.pt2");
for (auto& [name, ivalue] : loader.get_custom_objs()) {
  if (auto e = ivalue.toCustomClass<torch_tensorrt::TRTEngine>()) {
    e->set_external_stream(reinterpret_cast<int64_t>(my_green_stream));
  }
}
loader.run(inputs);

This PR (#4232) ships set_engine_stream_passthrough as the interim mechanism so edge / on-device users are unblocked today. A follow-up torch_tensorrt::aoti::TRTAOTILoader wrapper (gated on a CMake TORCH_TRT_HAVE_AOTI_CUSTOM_OBJS probe) will provide the clean per-engine API once upstream lands.

shoumikhin and others added 10 commits May 4, 2026 14:47
feat(runtime): support binding TRTEngine execution to an external CUDA stream

Adds opt-in support for binding torch-tensorrt's TRT engine execution to
externally-managed CUDA streams -- typically streams created via
`cuGreenCtxStreamCreate` for SM partitioning via CUDA Green Contexts (CUDA
12.4+).

Currently, the runtime in `core/runtime/execute_engine.cpp` lazily pulls a
stream from torch's global stream pool on first execute. That pool is bound
to the primary CUDA context, so even when a caller sets a green-context-bound
stream as current, the TRT engine bypasses it and uses a primary-context pool
stream -- defeating any SM partitioning the caller set up.

Purely additive: no behavior change for callers that don't opt in.

This change adds two complementary mechanisms:

(1) Per-engine binding (Python / dynamo / Exported Program path):
    - C++ API on `TRTEngine` (exposed via torchbind):
        void set_external_stream(int64_t stream_handle);
        void clear_external_stream();
        int64_t get_external_stream() const;
      The handle is `reinterpret_cast<int64_t>(cudaStream_t)`. Reachable from
      Python and external C++ via `torch.classes.tensorrt.Engine`.
    - Python facade in `torch_tensorrt.runtime.set_external_stream(module, ...)`
      with optional per-engine binding via `Dict[submodule_name, StreamLike]`
      and RAII context-manager semantics. Walks `named_modules()` so deeply
      nested TRT submodules (e.g. HF blocks under wrapper GraphModules) are
      reachable.

(2) Process-wide stream passthrough (AOTI / .pt2 C++ path):
    - New global flag `ENGINE_STREAM_PASSTHROUGH` and accessors:
        bool get_engine_stream_passthrough();
        void set_engine_stream_passthrough(bool);
      When enabled, `execute_engine` honors the caller's *current* CUDA stream
      (`c10::cuda::getCurrentCUDAStream`) instead of acquiring a pool stream.
      This unblocks `output_format="aot_inductor"` users whose `TRTEngine`
      torchbind constants live inside `OSSProxyExecutor::custom_objs_`
      (private, no public PyTorch accessor) and so are unreachable for the
      per-engine API. Users wrap `loader.run(...)` in a `CUDAStreamGuard`
      bound to e.g. a Green Context stream and the engine inherits it.
    - Python facade: `torch_tensorrt.runtime.set_engine_stream_passthrough(bool)`
      / `get_engine_stream_passthrough()`.

Mutual exclusion with CUDA Graphs is enforced for both mechanisms (throws at
execute time). Setter and clearer also invalidate any captured graph so a
subsequent recapture happens cleanly (avoids replaying against a stale stream
identity).

Multi-GPU correctness: `engine_stream` and `caller_stream` are now pinned to
the engine's actual `device_info.id` in the constructor body (the in-class
initializers default to device 0; without this, the lazy pool re-acquire in
`execute_engine` skipped firing on `cuda:N` for `N>0` because the
`engine_stream == getDefaultCUDAStream(current_device_id)` check was always
false).

Same code path serves both the C++ AOTI runtime (model.so dispatch into
`execute_engine.cpp` via the C-shim) and the dynamo Python runtime
(`PythonTorchTensorRTModule`). Per-engine binding lets callers map distinct
green contexts to distinct TRT subgraphs in one compiled model. The
process-wide passthrough is the alternative for callers who can't reach the
engines individually (AOTI's private custom_objs_ map being the canonical
case).

Files changed:
- core/runtime/TRTEngine.{h,cpp}                                 setter / clearer / getter, `external_stream` and `engine_stream_is_external` fields, cudagraph invalidation in setter & clearer, multi-GPU default-stream init in ctor
- core/runtime/execute_engine.cpp                                stream-resolve sites in both lambdas (regular + output-allocator paths) with per-engine + passthrough + pool fallback precedence; cudagraph mutual-exclusion guards
- core/runtime/runtime.{h,cpp}                                   `ENGINE_STREAM_PASSTHROUGH` global + accessors
- core/runtime/register_jit_hooks.cpp                            torchbind exposure for all three per-engine methods + the two passthrough globals
- py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py Python runtime parity
- py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py       passthrough on the C++-backed runtime
- py/torch_tensorrt/runtime/_external_stream.py                  top-level facade + context manager + passthrough toggles
- py/torch_tensorrt/runtime/__init__.py                          re-export
- tests/py/dynamo/runtime/test_006_external_stream.py            covers both runtimes (swap, clear, per-engine binding, cudagraph guard, validation, passthrough routing, passthrough+cudagraph mutex)

Test plan:
- pytest tests/py/dynamo/runtime/test_006_external_stream.py -- both
  PythonTorchTensorRTModule and TorchTensorRTModule runtime classes.
- GPU runtime test on H100 / Hopper: register a green-context-bound stream,
  run a small TRT engine, verify via nsys profile that kernel launches are
  confined to the green context's SM partition.
- GPU runtime test on Jetson Thor (Blackwell): same as above with sm_110.
- AOTI C++ test: `set_engine_stream_passthrough(True)`, wrap
  `AOTIModelPackageLoader::run()` with a `CUDAStreamGuard` on a green-context
  stream, verify SM-partitioned execution.

Open item (deliberately not in this commit, can land separately):
- Device-affinity validation in `set_external_stream`. The current sanity
  check (`cudaStreamGetFlags`) confirms the handle is real but does not
  validate the stream's device against `device_info.id`. A multi-GPU caller
  could silently bind a wrong-device stream. Clean fix uses `cuStreamGetCtx`
  + `cuCtxGetDevice` (driver API) or `cudaStreamGetDevice` (CUDA 12.5+).

Bundle 1 (must-fix from reviewers):

1. Device-affinity validation in set_external_stream
   - cuStreamGetCtx + cuCtxPushCurrent + cuCtxGetDevice resolves the stream's
     device and asserts it matches engine.device_info.id. Catches the silent
     cross-device launch (cuda:1 stream bound to cuda:0 engine) before any
     enqueueV3, where the failure would otherwise surface as a confusing
     CUDA error far from the bind site.

2. Reject magic stream values
   - cudaStreamLegacy / cudaStreamPerThread are now explicitly rejected.
     Binding them latches engine_stream_is_external onto a non-isolated
     stream that defeats the whole point of the API.

3. Atomic rollback on partial multi-engine bind (Python facade)
   - The set_external_stream loop now records each successful application
     and reverses them on any failure, so an engine's per-handle validation
     throwing midway through a Dict-shaped binding can no longer leave
     earlier engines in a half-bound state.

4. Re-entrancy / deadlock fix
   - mu is now std::recursive_mutex everywhere on TRTEngine. Allows TRT
     plugin -> Python -> set_external_stream re-entry on the same thread
     without self-deadlock. Zero downside for the non-reentrant path.

5. Cudagraph mutual-exclusion check moved to set time
   - set_external_stream now asserts CUDAGRAPHS_MODE == STANDARD up front
     instead of waiting until next execute. Faster failure, clearer call
     site, no wasted input migration etc. before the throw. The execute-
     time guard remains as defense-in-depth (covers cudagraphs being
     enabled AFTER an external stream is bound).

6. is_external_stream_set() companion accessor
   - Avoids the ambiguous get_external_stream() == 0 sentinel pattern.
     ABI-safe, cheap, exposed via torchbind.

7. Error message typo fix
   - 'wraps a non-null CUDA stream is required' -> 'must wrap a non-null
     CUDA stream'.

Defer to follow-up: Python torch.cuda.default_stream(self.device) one-char
fix, additional tests (green-context smoke, restore-non-zero-prior,
serialize round-trip), passthrough relocation, NCCL+external_stream
LOG_WARNING.

cudaStreamGetDevice was added in CUDA 12.8 (not 12.5 as the guard claimed)
and is still missing from Jetpack aarch64 toolchains even on cu126/cu128.
Restrict the device-affinity check to x86_64 + CUDA >= 12080; fall through
to the existing LOG_WARNING fallback elsewhere.

shoumikhin commented:

Friendly ping @narendasan @cehongwang.

This has been open for a week. The upstream PyTorch dependency (pytorch/pytorch#182149) is landing today, which unblocks the full end-to-end path.

Could one of you take a first pass when you get a chance? Happy to address feedback or split it up if that helps review.

