feat(runtime): support binding TRTEngine execution to an external CUDA stream #4232
shoumikhin wants to merge 10 commits into pytorch:main
Conversation
Long-term plan: upstream PyTorch PR. Opened pytorch/pytorch#182149 to add `AOTIModelPackageLoader::get_custom_objs()`, which would let AOTI users bind streams per engine:

```cpp
torch::inductor::AOTIModelPackageLoader loader("model.pt2");
for (auto& [name, ivalue] : loader.get_custom_objs()) {
  if (auto e = ivalue.toCustomClass<torch_tensorrt::TRTEngine>()) {
    e->set_external_stream(reinterpret_cast<int64_t>(my_green_stream));
  }
}
loader.run(inputs);
```

This PR (#4232) ships the process-wide passthrough as the interim mechanism until that accessor lands.
feat(runtime): support binding TRTEngine execution to an external CUDA stream
Adds opt-in support for binding torch-tensorrt's TRT engine execution to
externally-managed CUDA streams -- typically streams created via
`cuGreenCtxStreamCreate` for SM partitioning with CUDA Green Contexts
(CUDA 12.4+).
Currently, the runtime in `core/runtime/execute_engine.cpp` lazily pulls a
stream from torch's global stream pool on first execute. That pool is bound
to the primary CUDA context, so even when a caller sets a green-context-bound
stream as current, the TRT engine bypasses it and uses a primary-context pool
stream -- defeating any SM partitioning the caller set up.
Purely additive: no behavior change for callers that don't opt in.
This change adds two complementary mechanisms:
(1) Per-engine binding (Python / dynamo / Exported Program path):
- C++ API on `TRTEngine` (exposed via torchbind):
void set_external_stream(int64_t stream_handle);
void clear_external_stream();
int64_t get_external_stream() const;
The handle is `reinterpret_cast<int64_t>(cudaStream_t)`. Reachable from
Python and external C++ via `torch.classes.tensorrt.Engine`.
- Python facade in `torch_tensorrt.runtime.set_external_stream(module, ...)`
with optional per-engine binding via `Dict[submodule_name, StreamLike]`
and RAII context-manager semantics. Walks `named_modules()` so deeply
nested TRT submodules (e.g. HF blocks under wrapper GraphModules) are
reachable.
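  Illustrative usage (a sketch; `trt_module`, `x`, and the stream are
  placeholders, and a plain `torch.cuda.Stream` stands in for a
  green-context-bound stream):

      import torch
      import torch_tensorrt

      s = torch.cuda.Stream()  # stand-in for a green-context-bound stream
      with torch_tensorrt.runtime.set_external_stream(trt_module, s):
          out = trt_module(x)  # engine launches ride on `s`
      # prior binding restored on exit (RAII semantics)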
(2) Process-wide stream passthrough (AOTI / .pt2 C++ path):
- New global flag `ENGINE_STREAM_PASSTHROUGH` and accessors:
bool get_engine_stream_passthrough();
void set_engine_stream_passthrough(bool);
When enabled, `execute_engine` honors the caller's *current* CUDA stream
(`c10::cuda::getCurrentCUDAStream`) instead of acquiring a pool stream.
This unblocks `output_format="aot_inductor"` users whose `TRTEngine`
torchbind constants live inside `OSSProxyExecutor::custom_objs_`
(private, no public PyTorch accessor) and so are unreachable for the
per-engine API. Users wrap `loader.run(...)` in a `CUDAStreamGuard`
bound to e.g. a Green Context stream and the engine inherits it.
- Python facade: `torch_tensorrt.runtime.set_engine_stream_passthrough(bool)`
/ `get_engine_stream_passthrough()`.
Mutual exclusion with CUDA Graphs is enforced for both mechanisms (throws at
execute time). Setter and clearer also invalidate any captured graph so a
subsequent recapture happens cleanly (avoids replaying against a stale stream
identity).
Multi-GPU correctness: `engine_stream` and `caller_stream` are now pinned to
the engine's actual `device_info.id` in the constructor body (the in-class
initializers default to device 0; without this, the lazy pool re-acquire in
`execute_engine` skipped firing on `cuda:N` for `N>0` because the
`engine_stream == getDefaultCUDAStream(current_device_id)` check was always
false).
Same code path serves both the C++ AOTI runtime (model.so dispatch into
`execute_engine.cpp` via the C-shim) and the dynamo Python runtime
(`PythonTorchTensorRTModule`). Per-engine binding lets callers map distinct
green contexts to distinct TRT subgraphs in one compiled model. The
process-wide passthrough is the alternative for callers who can't reach the
engines individually (AOTI's private custom_objs_ map being the canonical
case).
Files changed:
- core/runtime/TRTEngine.{h,cpp} setter / clearer / getter, `external_stream` and `engine_stream_is_external` fields, cudagraph invalidation in setter & clearer, multi-GPU default-stream init in ctor
- core/runtime/execute_engine.cpp stream-resolve sites in both lambdas (regular + output-allocator paths) with per-engine + passthrough + pool fallback precedence; cudagraph mutual-exclusion guards
- core/runtime/runtime.{h,cpp} `ENGINE_STREAM_PASSTHROUGH` global + accessors
- core/runtime/register_jit_hooks.cpp torchbind exposure for all three per-engine methods + the two passthrough globals
- py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py Python runtime parity
- py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py passthrough on the C++-backed runtime
- py/torch_tensorrt/runtime/_external_stream.py top-level facade + context manager + passthrough toggles
- py/torch_tensorrt/runtime/__init__.py re-export
- tests/py/dynamo/runtime/test_006_external_stream.py covers both runtimes (swap, clear, per-engine binding, cudagraph guard, validation, passthrough routing, passthrough+cudagraph mutex)
Test plan:
- pytest tests/py/dynamo/runtime/test_006_external_stream.py -- both
PythonTorchTensorRTModule and TorchTensorRTModule runtime classes.
- GPU runtime test on H100 / Hopper: register a green-context-bound stream,
run a small TRT engine, verify via nsys profile that kernel launches are
confined to the green context's SM partition.
- GPU runtime test on Jetson Thor (Blackwell): same as above with sm_110.
- AOTI C++ test: `set_engine_stream_passthrough(True)`, wrap
`AOTIModelPackageLoader::run()` with a `CUDAStreamGuard` on a green-context
stream, verify SM-partitioned execution.
Open item (deliberately not in this commit, can land separately):
- Device-affinity validation in `set_external_stream`. The current sanity
check (`cudaStreamGetFlags`) confirms the handle is real but does not
validate the stream's device against `device_info.id`. A multi-GPU caller
could silently bind a wrong-device stream. Clean fix uses `cuStreamGetCtx`
+ `cuCtxGetDevice` (driver API) or `cudaStreamGetDevice` (CUDA 12.5+).
Bundle 1 (must-fix from reviewers):
1. Device-affinity validation in set_external_stream
- cuStreamGetCtx + cuCtxPushCurrent + cuCtxGetDevice resolves the stream's
device and asserts it matches engine.device_info.id. Catches the silent
cross-device launch (cuda:1 stream bound to cuda:0 engine) before any
enqueueV3, where the failure would otherwise surface as a confusing
CUDA error far from the bind site (see the sketch after this list).
2. Reject magic stream values
- cudaStreamLegacy / cudaStreamPerThread are now explicitly rejected.
Binding them latches engine_stream_is_external onto a non-isolated
stream that defeats the whole point of the API.
3. Atomic rollback on partial multi-engine bind (Python facade)
- The set_external_stream loop now records each successful application
and reverses them on any failure, so an engine's per-handle validation
throwing midway through a Dict-shaped binding can no longer leave
earlier engines in a half-bound state.
4. Re-entrancy / deadlock fix
- mu is now std::recursive_mutex everywhere on TRTEngine. Allows TRT
plugin -> Python -> set_external_stream re-entry on the same thread
without self-deadlock. Zero downside for the non-reentrant path.
5. Cudagraph mutual-exclusion check moved to set time
- set_external_stream now asserts CUDAGRAPHS_MODE == STANDARD up front
instead of waiting until next execute. Faster failure, clearer call
site, no wasted input migration etc. before the throw. The execute-
time guard remains as defense-in-depth (covers cudagraphs being
enabled AFTER an external stream is bound).
6. is_external_stream_set() companion accessor
- Avoids the ambiguous get_external_stream() == 0 sentinel pattern.
ABI-safe, cheap, exposed via torchbind.
7. Error message typo fix
- 'wraps a non-null CUDA stream is required' -> 'must wrap a non-null
CUDA stream'.
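For reference, an illustrative shape of the item-1 device resolution via the
driver API (a sketch, not the PR's exact code; error checks omitted):

    #include <cuda.h>
    #include <cuda_runtime.h>

    // Resolve which device ordinal a cudaStream_t belongs to.
    static int stream_device_ordinal(cudaStream_t stream) {
      CUcontext ctx = nullptr;
      cuStreamGetCtx(reinterpret_cast<CUstream>(stream), &ctx);
      cuCtxPushCurrent(ctx);   // enter the stream's context
      CUdevice dev = 0;
      cuCtxGetDevice(&dev);    // device of the now-current context
      cuCtxPopCurrent(&ctx);   // restore the caller's context
      return static_cast<int>(dev);
    }

    // set_external_stream then asserts:
    //   stream_device_ordinal(stream) == engine.device_info.id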
Defer to follow-up: Python torch.cuda.default_stream(self.device) one-char
fix, additional tests (green-context smoke, restore-non-zero-prior,
serialize round-trip), passthrough relocation, NCCL+external_stream
LOG_WARNING.
… avoid libcuda link on Jetpack
cudaStreamGetDevice was added in CUDA 12.8 (not 12.5 as the guard claimed) and is still missing from Jetpack aarch64 toolchains even on cu126/cu128. Restrict the device-affinity check to x86_64 + CUDA >= 12.8 (CUDA_VERSION >= 12080); fall through to the existing LOG_WARNING fallback elsewhere.
Friendly ping @narendasan @cehongwang. This has been open for a week. The upstream PyTorch dependency (pytorch/pytorch#182149) is landing today, which unblocks the full end-to-end path. Could one of you take a first pass when you get a chance? Happy to address feedback or split it up if that helps review.
Summary
Adds opt-in support for binding torch-tensorrt's TRT engine execution to externally-managed CUDA streams — typically streams created via `cuGreenCtxStreamCreate` for SM partitioning with CUDA Green Contexts (CUDA 12.4+). The motivating workload is edge / on-device multi-tenant inference on Jetson-class hardware where a vision encoder + policy net + diffusion head all share one process and need disjoint SM partitions to avoid time-slicing.

Currently, `core/runtime/execute_engine.cpp` lazily pulls a stream from torch's global stream pool on first execute. That pool is bound to the primary CUDA context, so even when a caller sets a green-context-bound stream as current (via `c10::cuda::CUDAStreamGuard` / `torch.cuda.stream(...)`), the TRT engine bypasses it and uses a primary-context pool stream — defeating any SM partitioning the caller set up.

Purely additive: no behavior change for callers that don't opt in.
Two complementary mechanisms
The PR ships two ways to bind a stream, sized to two different deployment shapes:
(1) Per-engine binding — for Python / dynamo / `output_format="exported_program"`

Reach the `TRTEngine` torchbind object through the wrapping `nn.Module`'s `named_modules()` and bind a stream per engine. This is the canonical multi-engine SM-partitioning case, where one compiled model contains several TRT subgraphs that should each run on a distinct green context.

New C++ API on `TRTEngine` (exposed via torchbind):
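```cpp
void set_external_stream(int64_t stream_handle);  // handle = reinterpret_cast<int64_t>(cudaStream_t)
void clear_external_stream();
int64_t get_external_stream() const;
```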
Reachable from Python and external C++ via `torch.classes.tensorrt.Engine`.

New Python facade with RAII context-manager semantics:
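A sketch of intended usage (illustrative; the `_run_on_acc_*` submodule names follow torch-tensorrt's partitioner convention, and the facade accepts either a single stream or a per-submodule dict, as described below):

```python
import torch
import torch_tensorrt

s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()  # stand-ins for green-context streams

# Bind distinct streams to distinct TRT subgraphs for the duration of the block.
with torch_tensorrt.runtime.set_external_stream(
    compiled_module, {"_run_on_acc_0": s1, "_run_on_acc_1": s2}
):
    out = compiled_module(x)
# Previous bindings are restored on exit.
```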
`set_external_stream` walks `named_modules()` recursively, so deeply nested TRT submodules (e.g. HF blocks under wrapper `GraphModule`s) are reachable. Submodule names are dotted paths, validated up front so a bad value cannot leave a partially-bound module. The setter validates the stream's device affinity against the engine's target device (via `cuStreamGetCtx` + `cuCtxGetDevice`) and rejects the legacy / per-thread magic stream IDs; the binding is applied atomically across multiple engines (any per-engine failure rolls back successfully-applied bindings before re-raising).

(2) Process-wide stream passthrough — for AOTI / `.pt2` C++ deployments
output_format="aot_inductor"and consumed in pure C++ viatorch::inductor::AOTIModelPackageLoader, the liveTRTEnginetorchbind instances live insideOSSProxyExecutor::custom_objs_— private with no public PyTorch accessor. Re-parsing the.pt2only yields independentIValuecopies that the running.sonever invokes, so the per-engine API in (1) is unreachable.The fix: a process-wide opt-in flag that makes
execute_enginehonor the caller's current CUDA stream instead of the lazy pool stream. Users wraploader.run(...)in aCUDAStreamGuardand the engine inherits it.New globals (C++ + Python):
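```cpp
// core/runtime/runtime.{h,cpp}
bool get_engine_stream_passthrough();
void set_engine_stream_passthrough(bool);
```

Python: `torch_tensorrt.runtime.set_engine_stream_passthrough(bool)` / `torch_tensorrt.runtime.get_engine_stream_passthrough()`.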
C++ usage after merge:
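(Sketch — `green_stream` is assumed to be a `cudaStream_t` from `cuGreenCtxStreamCreate`, `inputs` a prepared `std::vector<at::Tensor>`, and the setter's namespace follows `core/runtime/runtime.h`.)

```cpp
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>

// Opt in: execute_engine will honor the caller's current CUDA stream.
torch_tensorrt::core::runtime::set_engine_stream_passthrough(true);

torch::inductor::AOTIModelPackageLoader loader("model.pt2");

// Make the green-context stream current for the duration of run();
// the TRT engines inside the .pt2 inherit it through the passthrough.
c10::cuda::CUDAStreamGuard guard(
    c10::cuda::getStreamFromExternal(green_stream, /*device_index=*/0));
std::vector<at::Tensor> outputs = loader.run(inputs);
```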
Python usage after merge (also valid for AOTI-loaded models via `torch._inductor.aoti_load_package`):
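(Sketch — `model.pt2` and `x` are placeholders; a plain `torch.cuda.Stream` stands in for a green-context stream, which would be wrapped via `torch.cuda.ExternalStream`.)

```python
import torch
import torch_tensorrt

torch_tensorrt.runtime.set_engine_stream_passthrough(True)

model = torch._inductor.aoti_load_package("model.pt2")
stream = torch.cuda.Stream()  # stand-in for a green-context-bound stream

with torch.cuda.stream(stream):  # engines inherit the current stream
    out = model(x)
```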
Precedence

When multiple sources are configured, the resolver picks in this order, on every call (so set/clear take effect immediately without recreating the engine):

1. Per-engine `external_stream` (set via `TRTEngine::set_external_stream`)
2. `ENGINE_STREAM_PASSTHROUGH` → caller's current CUDA stream
3. Pool stream (`getStreamFromPool`) — unchanged default behavior
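An illustrative sketch of that resolution (not the PR's exact code; field names per this description):

```cpp
#include <c10/cuda/CUDAStream.h>

// Hypothetical per-call stream resolution inside execute_engine.
c10::cuda::CUDAStream resolve_stream(const TRTEngine& engine) {
  auto device = static_cast<c10::DeviceIndex>(engine.device_info.id);
  if (engine.engine_stream_is_external) {            // 1. per-engine binding
    return c10::cuda::getStreamFromExternal(
        reinterpret_cast<cudaStream_t>(engine.external_stream), device);
  }
  if (get_engine_stream_passthrough()) {             // 2. process-wide passthrough
    return c10::cuda::getCurrentCUDAStream(device);
  }
  return c10::cuda::getStreamFromPool(false, device);  // 3. pool fallback
}
```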
Mutual exclusion with CUDA Graphs

Both mechanisms are mutually exclusive with CUDA Graphs. The check fires at bind time (`set_external_stream` / `set_engine_stream_passthrough(true)` throw if cudagraphs are currently enabled) and again at execute time as defense-in-depth (covers cudagraphs being enabled after the binding):
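```cpp
// Illustrative shape of the bind-time guard (exact enum/message spelling is an
// assumption; "asserts CUDAGRAPHS_MODE == STANDARD" per the commit message):
TORCHTRT_CHECK(
    CUDAGRAPHS_MODE == CudaGraphsMode::STANDARD,
    "External stream binding is mutually exclusive with CUDA Graphs");
```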
The setter and clearer also invalidate any captured graph (`cudagraph.reset()`) so a subsequent recapture happens cleanly and never replays against a stale stream identity.

Multi-GPU correctness fix folded in
`TRTEngine::engine_stream` and `TRTEngine::caller_stream` are now pinned to the engine's actual `device_info.id` in the constructor body. The in-class initializers at `TRTEngine.h:211-212` default to device 0 (no device arg). Without this fix, the lazy pool re-acquire in `execute_engine` checked `engine_stream == getDefaultCUDAStream(current_device_id)` — always false on `cuda:N` for `N>0` — so the engine ran on `cuda:0`'s default stream regardless of the input device. Pre-existing bug; fixed here while we were in the area.

Files changed
- `core/runtime/TRTEngine.{h,cpp}`: setter / clearer / getter, `external_stream` + `engine_stream_is_external` fields, cudagraph invalidation in setter & clearer, multi-GPU default-stream init in ctor
- `core/runtime/execute_engine.cpp`: stream-resolve sites in both lambdas (regular + output-allocator paths) with per-engine + passthrough + pool-fallback precedence; cudagraph mutual-exclusion guards
- `core/runtime/runtime.{h,cpp}`: `ENGINE_STREAM_PASSTHROUGH` global + `get_`/`set_engine_stream_passthrough()` accessors
- `core/runtime/register_jit_hooks.cpp`: torchbind exposure for the per-engine methods + the two passthrough globals
- `py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py`: Python runtime parity
- `py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py`: passthrough on the C++-backed runtime
- `py/torch_tensorrt/runtime/_external_stream.py`: top-level facade + context manager + passthrough toggles
- `py/torch_tensorrt/runtime/__init__.py`: re-export
- `tests/py/dynamo/runtime/test_006_external_stream.py`: covers both runtimes (swap, clear, per-engine binding, cudagraph guard, validation, passthrough routing, passthrough+cudagraph mutex)

Test plan
- `pytest tests/py/dynamo/runtime/test_006_external_stream.py` — covers both `PythonTorchTensorRTModule` and `TorchTensorRTModule` runtime classes, including the new passthrough tests.
- GPU runtime test (H100 / Hopper; Jetson Thor / Blackwell): register a green-context-bound stream, run a small TRT engine, verify via `nsys profile` that kernel launches are confined to the green context's SM partition.
- AOTI C++ test: `set_engine_stream_passthrough(true)` + wrap `AOTIModelPackageLoader::run()` with a `CUDAStreamGuard` on a green-context stream, verify SM-partitioned execution in `nsys`.

Out of scope (future PRs)
- `AOTIModelPackageLoader::get_custom_objs()` upstream (pytorch/pytorch#182149), so AOTI users can also use the per-engine API (when they want different streams per submodule inside one `.pt2`). The passthrough flag in this PR is the interim mechanism while that lands and reaches stable.
- `torch_tensorrt::aoti::TRTAOTILoader` C++ wrapper behind a `TORCH_TRT_HAVE_AOTI_CUSTOM_OBJS` CMake probe — depends on the upstream PR.
- The `requires_native_multidevice` path may need follow-up if a user combines both.