[Feature][Performance] Triton backend for GRU / LSTM with intermediate resets#3738
Merged
Conversation
Adds ``recurrent_backend="triton"`` to :class:`~torchrl.modules.GRUModule`
and :class:`~torchrl.modules.LSTMModule`. The kernel fuses the whole time
loop into a single CUDA launch and applies the ``is_init`` reset cheaply
inside the loop, avoiding the ``pack_padded_sequence`` / ``pad_packed_sequence``
round trip that dominates the cuDNN ``pad`` backend whenever any reset
occurs mid-rollout.
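For intuition, the contract the fused kernel implements is equivalent to this per-step loop (a minimal PyTorch sketch of the semantics, not the kernel itself; ``cell`` would be e.g. ``torch.nn.GRUCell``):

```python
import torch

def rnn_with_resets(cell, x, h0, is_init):
    # x: [B, T, I] inputs, h0: [B, H] initial state,
    # is_init: [B, T, 1] bool, True where a new episode starts.
    h, outs = h0, []
    for t in range(x.shape[1]):
        # The reset is a cheap masked zeroing inside the loop,
        # not a pack/pad round trip around the whole sequence.
        h = torch.where(is_init[:, t], torch.zeros_like(h), h)
        h = cell(x[:, t], h)
        outs.append(h)
    return torch.stack(outs, dim=1), h
```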
Kernels live in ``torchrl/modules/tensordict_module/_rnn_triton.py`` and
are gated by a ``_has_triton`` optional-dep flag. Both forward and
backward are K-tiled along the recurrent contraction, so a single
``[H, H]`` weight slab fits in shared memory at any hidden size we care
about in RL. ``recurrent_compute_dtype`` selects fp32 (TF32 on H100,
default; matches cuDNN behavior) or bf16 (twice the SMEM margin, ~7-bit
mantissa).
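A minimal sketch of the optional-dep gating described above (the probe is refined later in this PR; the helper name here is illustrative, not from the diff):

```python
try:
    import triton  # noqa: F401

    _has_triton = True
except ImportError:
    _has_triton = False


def _require_triton() -> None:
    # Hypothetical helper: the PR promises a clear RuntimeError when the
    # backend is requested without triton installed.
    if not _has_triton:
        raise RuntimeError(
            "recurrent_backend='triton' requires triton to be installed; "
            "fall back to the 'pad' or 'scan' backend otherwise."
        )
```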
Limitations of this prototype:
* ``num_layers == 1`` only.
* No dropout, no projection (``proj_size``), no bidirectional.
* Hidden size internally padded to the next power of two (wrapper-side).
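The wrapper-side padding from the last bullet amounts to something like this (illustrative sketch; the helper name is not from the PR):

```python
import torch
import triton

def _pad_hidden(h: torch.Tensor) -> torch.Tensor:
    # Pad the hidden dimension up to the next power of two so the kernel
    # can assume a fixed power-of-two tile; the extra columns stay zero.
    H = h.shape[-1]
    H_pad = triton.next_power_of_2(H)
    return h if H_pad == H else torch.nn.functional.pad(h, (0, H_pad - H))
```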
Tests:
* ``test/test_tensordictmodules.py``: numerical agreement against ``pad``
for forward and backward at H ∈ {16, 64}, fp32 and bf16, multiple
intra-rollout resets. ``RuntimeError`` test for missing triton dep.
Benchmark:
* ``benchmarks/bench_gru_reset_backends.py``: adds ``triton_td`` mode.
H100 fwd+bwd at B=2048, T=256, H=128, fp32 (lower is better):
| reset | cudnn | scan | triton |
|-------|--------|--------|--------|
| 0% | 33.8 | 70.9 | 11.7 |
| 1% | 309.1 | 71.1 | 11.7 |
| 10% | 310.6 | 71.2 | 11.7 |
Triton is reset-invariant and 5x-25x faster than cuDNN at non-zero reset
probabilities. cuDNN still wins at reset=0% for LSTM at H=256+ (its
persistent-LSTM path is hard to beat); the gap is small for GRU.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Collaborator
Author
Remaining work before this is release-ready

Tracking what would need to land before flipping the triton backend from prototype to first-class:

* Correctness gaps
* Production-grade integration
* Coverage
* Rename ``I`` -> ``I_in`` to silence flake8 E741 (ambiguous variable name).
* Add ``dout`` to the codespell ignore-list — it's standard autograd shorthand for "grad of out", not a typo of "doubt".

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
… + dropout

* New ``--dropouts`` CLI sweep (defaults to ``[0.0]`` for backwards compat).
* Multi-layer is now exercised through the triton backend — the previous ``num_layers == 1`` gate was stale.
* Scan modes are skipped automatically when ``dropout > 0`` since they raise ``NotImplementedError``.
* The output CSV now includes ``num_layers`` and ``dropout`` columns so a single run can drive a 2D sweep.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Lift ``--num-layers`` from a single int to a sweep (``nargs='+'``) so one invocation can cover several stack depths. Includes the layer count in the output CSV. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
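Taken together, the two commits above presumably wire the sweep like this (sketch; only the flag names and CSV columns come from the commit messages):

```python
import argparse

parser = argparse.ArgumentParser()
# Defaults keep old single-value invocations working.
parser.add_argument("--dropouts", type=float, nargs="+", default=[0.0])
parser.add_argument("--num-layers", type=int, nargs="+", default=[1])
args = parser.parse_args()

# One invocation now drives a 2D (num_layers x dropout) sweep; each row of
# the output CSV carries its num_layers and dropout values.
for num_layers in args.num_layers:
    for dropout in args.dropouts:
        ...  # run the benchmark for this configuration
```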
Adds ``test_{lstm,gru}_module_three_backends_equivalent`` parametrized over
``num_layers in {1, 2}``. Compares pad / scan / triton outputs (feat and
``next.hidden*``) on a tensordict with intra-rollout resets.
This locks in the contract that all three backends agree wherever the
intersection of their supported features lies. The case is fixed at
``dropout=0`` because the scan backend raises on dropout; the
pad-vs-triton dropout path is already covered by
``test_*_triton_extended_forward_matches_pad``.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
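The shape of the new test, as a hedged sketch (the real parametrization, key names, and tolerances live in ``test/test_tensordictmodules.py``; this assumes GRUModule's documented ``in_key``/``out_key`` convenience arguments and the new ``recurrent_backend`` kwarg):

```python
import pytest
import torch
from tensordict import TensorDict
from torchrl.modules import GRUModule


@pytest.mark.parametrize("num_layers", [1, 2])
def test_three_backends_equivalent_sketch(num_layers):
    B, T, I_in, H = 4, 16, 8, 32
    torch.manual_seed(0)
    td = TensorDict(
        {
            "embed": torch.randn(B, T, I_in),
            "is_init": torch.rand(B, T, 1) < 0.1,  # intra-rollout resets
        },
        batch_size=[B, T],
    )
    weights, outs = None, {}
    for backend in ("pad", "scan", "triton"):
        mod = GRUModule(
            input_size=I_in,
            hidden_size=H,
            num_layers=num_layers,
            in_key="embed",
            out_key="feat",
            recurrent_backend=backend,  # new kwarg from this PR
        ).set_recurrent_mode(True)
        weights = weights or mod.state_dict()
        mod.load_state_dict(weights)  # identical parameters across backends
        outs[backend] = mod(td.clone())["feat"]
    torch.testing.assert_close(outs["pad"], outs["scan"])
    torch.testing.assert_close(outs["pad"], outs["triton"])
```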
Collaborator
Author
Follow-up PR:
…+training

* Older Triton versions read ``prune_configs_by`` as a plain dict and access the ``perf_model`` / ``top_k`` keys unconditionally. Add a ``_PRUNE_BY_TEMPLATE`` with those keys set to ``None`` and merge it into every autotune declaration. Fixes ``KeyError: 'perf_model'`` on ``tests-olddeps``.
* The extended forward tests parametrized ``dropout=0.3`` with ``training=True``, then asserted bit-exact equivalence between pad and triton outputs. That cannot hold: cuDNN's ``nn.LSTM`` / ``nn.GRU`` stores its dropout mask state in a cuDNN dropout descriptor that advances independently of torch's global RNG, while the triton wrapper applies ``F.dropout`` directly. Even with ``torch.manual_seed`` before each call the masks diverge. Under that combination, fall back to shape/dtype/finite checks instead. Fixes intermittent assertion failures on ``tests-optdeps``.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
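The shim plausibly reduces to this (sketch; only ``_PRUNE_BY_TEMPLATE`` and the two key names come from the commit message):

```python
# Older Triton indexes prune_configs_by["perf_model"] / ["top_k"] without
# checking for the keys, so every autotune declaration merges this template.
_PRUNE_BY_TEMPLATE = {"perf_model": None, "top_k": None}


def _prune_configs_by(**overrides):
    merged = dict(_PRUNE_BY_TEMPLATE)
    merged.update(overrides)
    return merged
```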
…ibility ``tl.extra.libdevice`` only exists on Triton >= 2.2; on the ``tests-olddeps (3.10, 11.8)`` runner the import raises ``AttributeError: module 'triton.language.extra' has no attribute 'libdevice'``. Replace ``tl.extra.libdevice.tanh(x)`` with a ``2*sigmoid(2x) - 1`` helper, which lowers to the same hardware ``ex2``/``rcp`` pair via ``tl.sigmoid`` and is available in every Triton version we care about. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
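The replacement described in that commit is the identity tanh(x) = 2*sigmoid(2x) - 1:

```python
import triton
import triton.language as tl


@triton.jit
def _tanh(x):
    # tanh(x) = 2*sigmoid(2x) - 1. tl.sigmoid is available on every Triton
    # version the PR targets, unlike tl.extra.libdevice.tanh (>= 2.2 only).
    return 2.0 * tl.sigmoid(2.0 * x) - 1.0
```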
Even with the sigmoid-tanh shim added in 9d72835, the ``tests-olddeps (3.10, 11.8)`` runner's Triton (< 2.2) cannot compile the backward kernel — it rejects ``tl.atomic_add`` with a 2-D mask. Rather than chase a moving compatibility target, restore the canonical ``tl.extra.libdevice.tanh`` path and instead gate ``_has_triton`` on the presence of the ``triton.language.extra.libdevice`` submodule (the proxy for Triton >= 2.2). Older Triton installs get ``_has_triton = False``, the Triton tests skip cleanly, and users transparently fall back to the ``scan`` / ``pad`` backends.

The probe uses ``importlib.util.find_spec`` rather than ``hasattr(tl.extra, 'libdevice')`` because ``libdevice`` is a lazily loaded submodule on modern Triton — Triton's JIT resolves the attribute access at kernel-compile time, but at module-import time the attribute isn't bound yet.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
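Per that description, the probe reduces to roughly:

```python
import importlib.util

# find_spec sees the libdevice submodule even though the attribute is lazily
# bound; hasattr(tl.extra, "libdevice") would report False at import time.
_has_triton = (
    importlib.util.find_spec("triton") is not None
    and importlib.util.find_spec("triton.language.extra.libdevice") is not None
)
```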
Re-format ``_rnn_triton.py``, ``rnn.py`` and ``test_tensordictmodules.py`` with ``ufmt 2.8.0`` (black 22.3.0 + usort 1.0.3) to satisfy the ``python-source-and-configs`` lint job. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
vmoens added a commit that referenced this pull request on May 12, 2026.
Summary
Adds ``recurrent_backend="triton"`` to ``GRUModule`` and ``LSTMModule``. The low-level kernels fuse the time loop for one layer and one direction into a single CUDA launch and apply the ``is_init`` reset cheaply inside the loop, avoiding the ``pack_padded_sequence`` / ``pad_packed_sequence`` round trip that dominates the cuDNN ``pad`` backend whenever any reset occurs mid-rollout. Sits alongside the existing ``pad`` and ``scan`` backends.

* ``torchrl/modules/tensordict_module/_rnn_triton.py`` — forward and backward kernels, both K-tiled along the recurrent contraction.
* Gated by ``_has_triton``. Clear ``RuntimeError`` if the backend is selected without triton installed.
* ``recurrent_compute_dtype`` selects fp32 (TF32 on Ampere/Hopper, default — matches cuDNN behavior) or bf16.
* ``num_layers > 1`` is handled by stacking per-layer Triton calls.
* LSTM backward exposes intermediate cell-state gradients (``dc_t``).

Limitations (prototype)
* Bidirectional and ``proj_size`` are not native Triton kernel paths yet. These configurations currently use a pad-compatible correctness path rather than getting the Triton speedup.
* ``reserve_space`` footprint.
* The backward uses ``tl.atomic_add`` for reset-state gradients, so deterministic mode is not implemented yet.

Numerical correctness
Forward + backward match the ``pad`` backend within TF32 noise (~5e-4 rel err for fp32, ~5e-3 for bf16) across multiple shape / reset_prob combinations. The test coverage now includes multilayer GRU/LSTM, dropout train/eval behavior, bidirectional/projection correctness fallback, and LSTM losses that backpropagate through ``("next", "hidden1")``.

Performance (H100, B=2048, T=256, H=128, fp32, fwd+bwd)
(Table columns: cudnn ``pad`` / ``scan_compile`` / ``triton``; see the benchmark table earlier in this PR for the figures.)

Triton is reset-invariant and 5–25x faster than cuDNN at non-zero reset probabilities. cuDNN still wins at reset=0% for LSTM at H>=256 (its persistent-LSTM path is hard to beat without TMA + warp specialisation). The gap is small for GRU.
Test plan
* ``test/test_tensordictmodules.py::TestLSTMModule::test_lstm_module_triton_backend_matches_pad`` — forward agreement vs ``pad`` at H in {16, 64} x {fp32, bf16}, with several mid-rollout resets.
* ``test/test_tensordictmodules.py::TestLSTMModule::test_lstm_module_triton_backward`` — gradient agreement vs ``pad`` at H=64.
* ``test/test_tensordictmodules.py::TestLSTMModule::test_lstm_module_triton_extended_forward_matches_pad`` — multilayer, dropout train/eval, bidirectional fallback, projection fallback, and combined projected bidirectional LSTM forward parity.
* ``test/test_tensordictmodules.py::TestLSTMModule::test_lstm_module_triton_extended_backward_matches_pad`` — multilayer, bidirectional fallback, projection fallback, and intermediate cell-state gradient parity.
* Mirrored for ``GRUModule`` where applicable.
* ``benchmarks/bench_gru_reset_backends.py`` extended with ``triton_td``, ``num_layers`` and dropout sweeps, and optional ``torch.compile`` flags.
* Native bidirectional and ``proj_size`` kernels.
* ``torch.compile`` first-class integration via ``torch.library.custom_op`` (right now the kernel runs inside compiled graphs, but as an opaque node around ``_GRUFn.apply`` / ``_LSTMFn.apply``).
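For the ``torch.library.custom_op`` follow-up, the registration would look roughly like this (a sketch assuming the modern custom-op API; the op and function names are illustrative):

```python
import torch


@torch.library.custom_op("torchrl::gru_triton_fwd", mutates_args=())
def gru_triton_fwd(x: torch.Tensor, h0: torch.Tensor) -> torch.Tensor:
    # The real implementation would launch the fused Triton kernel here.
    raise NotImplementedError


@gru_triton_fwd.register_fake
def _(x: torch.Tensor, h0: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation is what lets torch.compile trace through the
    # op as a first-class node instead of an opaque one.
    return x.new_empty(*x.shape[:-1], h0.shape[-1])
```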