
[Feature][Performance] Triton backend for GRU / LSTM with intermediate resets#3738

Merged
vmoens merged 11 commits into pytorch:main from vmoens:rnn-cuda-kernel on May 12, 2026

Conversation

@vmoens
Collaborator

@vmoens vmoens commented May 11, 2026

Summary

Adds recurrent_backend="triton" to GRUModule and LSTMModule. The low-level kernels fuse the time loop for one layer and one direction into a single CUDA launch and apply the is_init reset cheaply inside the loop, avoiding the pack_padded_sequence / pad_packed_sequence round trip that dominates the cuDNN pad backend whenever any reset occurs mid-rollout. The new backend sits alongside the existing pad and scan backends; a minimal usage sketch follows the list below.

  • New file: torchrl/modules/tensordict_module/_rnn_triton.py — forward and backward kernels, both K-tiled along the recurrent contraction.
  • Optional dep, guarded by _has_triton. Clear RuntimeError if the backend is selected without triton installed.
  • recurrent_compute_dtype selects fp32 (TF32 on Ampere/Hopper, default — matches cuDNN behavior) or bf16.
  • Hidden size is internally padded to the next power of two (Python-side; no in-kernel masking).
  • num_layers > 1 is handled by stacking per-layer Triton calls.
  • Dropout is applied between per-layer Triton calls, matching PyTorch's recurrent dropout semantics.
  • LSTM backward now propagates gradients flowing through intermediate cell states (dc_t).
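A minimal usage sketch of how the backend is selected (not taken from the PR itself; the in_keys/out_keys layout follows the usual LSTMModule convention, and anything beyond the recurrent_backend / recurrent_compute_dtype names introduced here is illustrative):

```python
import torch
from torchrl.modules import LSTMModule

# Hedged sketch: select the Triton backend on an LSTMModule. Argument
# spellings other than recurrent_backend / recurrent_compute_dtype are
# assumptions, not the PR's exact signature.
lstm = LSTMModule(
    input_size=32,
    hidden_size=128,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
    recurrent_backend="triton",       # raises RuntimeError if triton is not installed
    recurrent_compute_dtype="fp32",   # "fp32" (TF32) or "bf16"; accepted values assumed
    device="cuda",
)
```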

Limitations (prototype)

  • The low-level kernels operate on one layer and one direction at a time; multilayer execution launches one Triton call per layer.
  • Bidirectional and LSTM proj_size are not native Triton kernel paths yet. These configurations currently use a pad-compatible correctness path rather than getting the Triton speedup.
  • The autograd wrapper saves per-layer gate activations explicitly, so activation memory scales linearly with layer count and can exceed cuDNN's opaque reserve_space footprint.
  • Shape changes can trigger separate Triton compilation/autotune. Fixed training shapes are the intended fast path.
  • Backward uses tl.atomic_add for reset-state gradients, so deterministic mode is not implemented yet.

Numerical correctness

Forward + backward match the pad backend within TF32 noise (~5e-4 rel err for fp32, ~5e-3 for bf16) across multiple shape / reset_prob combinations. The test coverage now includes multilayer GRU/LSTM, dropout train/eval behavior, bidirectional/projection correctness fallback, and LSTM losses that backpropagate through ("next", "hidden1").
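A sketch of the kind of tolerance check this implies (illustrative, not the actual test code):

```python
import torch

def assert_backend_parity(out_triton: torch.Tensor, out_pad: torch.Tensor) -> None:
    # Tolerances mirror the numbers above: ~5e-4 relative error for fp32
    # (TF32 matmuls), ~5e-3 for bf16.
    tol = 5e-4 if out_triton.dtype is torch.float32 else 5e-3
    torch.testing.assert_close(out_triton, out_pad, rtol=tol, atol=tol)
```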

Performance (H100, B=2048, T=256, H=128, fp32, fwd+bwd)

| reset_prob | cuDNN (pad) | scan_compile | triton  |
|------------|-------------|--------------|---------|
| 0%         | 33.8 ms     | 70.9 ms      | 11.7 ms |
| 1%         | 309.1 ms    | 71.1 ms      | 11.7 ms |
| 10%        | 310.6 ms    | 71.2 ms      | 11.7 ms |

Triton is reset-invariant and 5–25x faster than cuDNN at non-zero reset probabilities. cuDNN still wins at reset=0% for LSTM at H>=256 (its persistent-LSTM path is hard to beat without TMA + warp specialisation). The gap is small for GRU.

Test plan

  • test/test_tensordictmodules.py::TestLSTMModule::test_lstm_module_triton_backend_matches_pad — forward agreement vs pad at H in {16, 64} x {fp32, bf16}, with several mid-rollout resets.
  • test/test_tensordictmodules.py::TestLSTMModule::test_lstm_module_triton_backward — gradient agreement vs pad at H=64.
  • test/test_tensordictmodules.py::TestLSTMModule::test_lstm_module_triton_extended_forward_matches_pad — multilayer, dropout train/eval, bidirectional fallback, projection fallback, and combined projected bidirectional LSTM forward parity.
  • test/test_tensordictmodules.py::TestLSTMModule::test_lstm_module_triton_extended_backward_matches_pad — multilayer, bidirectional fallback, projection fallback, and intermediate cell-state gradient parity.
  • Same coverage for GRUModule where applicable.
  • All tests skip cleanly when triton is unavailable or CUDA is missing.
  • benchmarks/bench_gru_reset_backends.py extended with triton_td, num_layers, dropout sweeps, and optional torch.compile flags.

Follow-ups (intentionally out of scope here)

  • Native Triton bidirectional and LSTM proj_size kernels.
  • torch.compile first-class integration via torch.library.custom_op (right now the kernel runs inside compiled graphs but as an opaque node).
  • vmap support for _GRUFn.apply / _LSTMFn.apply.
  • CPU fallback or explicit CPU tensor error.
  • In-kernel masking for non-power-of-two H to eliminate the wrapper pad/unpad copies.
  • Deterministic backward mode for reset-state gradient accumulation.
  • Docs/tutorial/SOTA integration once the prototype scope is accepted.

…e resets

Adds ``recurrent_backend="triton"`` to :class:`~torchrl.modules.GRUModule`
and :class:`~torchrl.modules.LSTMModule`. The kernel fuses the whole time
loop into a single CUDA launch and applies the ``is_init`` reset cheaply
inside the loop, avoiding the ``pack_padded_sequence`` / ``pad_packed_sequence``
round trip that dominates the cuDNN ``pad`` backend whenever any reset
occurs mid-rollout.

Kernels live in ``torchrl/modules/tensordict_module/_rnn_triton.py`` and
are gated by a ``_has_triton`` optional-dep flag. Both forward and
backward are K-tiled along the recurrent contraction, so a single
``[H, H]`` weight slab fits in shared memory at any hidden size we care
about in RL. ``recurrent_compute_dtype`` selects fp32 (TF32 on H100,
default; matches cuDNN behavior) or bf16 (twice the SMEM margin, ~7-bit
mantissa).

Limitations of this prototype:

* ``num_layers == 1`` only.
* No dropout, no projection (``proj_size``), no bidirectional.
* Hidden size internally padded to the next power of two (wrapper-side).
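A sketch of the wrapper-side padding this refers to, assuming the standard ``triton.next_power_of_2`` helper (the function name ``_pad_hidden_to_pow2`` is illustrative, not the PR's code):

```python
import torch
import torch.nn.functional as F
import triton

def _pad_hidden_to_pow2(h: torch.Tensor) -> tuple[torch.Tensor, int]:
    # Pad the last (hidden) dimension up to the next power of two before the
    # kernel launch; the caller slices back to the original size afterwards.
    hidden = h.shape[-1]
    padded = triton.next_power_of_2(hidden)
    if padded == hidden:
        return h, hidden
    return F.pad(h, (0, padded - hidden)), hidden
```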

Tests:
* ``test/test_tensordictmodules.py``: numerical agreement against ``pad``
  for forward and backward at H ∈ {16, 64}, fp32 and bf16, multiple
  intra-rollout resets. ``RuntimeError`` test for missing triton dep.

Benchmark:
* ``benchmarks/bench_gru_reset_backends.py``: adds ``triton_td`` mode.

H100 fwd+bwd at B=2048, T=256, H=128, fp32 (times in ms):

  | reset |  cudnn |   scan | triton |
  |-------|--------|--------|--------|
  |  0%   |  33.8  |  70.9  |  11.7  |
  |  1%   | 309.1  |  71.1  |  11.7  |
  | 10%   | 310.6  |  71.2  |  11.7  |

Triton is reset-invariant and 5x-25x faster than cuDNN at non-zero reset
probabilities. cuDNN still wins at reset=0% for LSTM at H=256+ (its
persistent-LSTM path is hard to beat); the gap is small for GRU.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@pytorch-bot

pytorch-bot Bot commented May 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3738

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit 14e2a61 with merge base 1f1f8bf:


👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label on May 11, 2026
@vmoens
Collaborator Author

vmoens commented May 11, 2026

Remaining work before this is release-ready

Tracking what would need to land before flipping the triton backend from prototype to first-class:

Correctness gaps

  • num_layers > 1 — supported by stacking per-layer Triton calls.
  • Dropout — applied between per-layer Triton calls, matching PyTorch's layer-dropout semantics.
  • Bidirectional and proj_size — not supported by the Triton kernels yet; current correctness path falls back to pad-compatible execution.
  • LSTM intermediate dc_t — the backward kernel now propagates gradients flowing through intermediate cell states.

Production-grade integration

  • torch.compile first-class support — wrap as @torch.library.custom_op + register_fake + register_autograd (see the sketch after this list). Today the kernel runs in compiled graphs but as an opaque node (no fusion with surrounding ops).
  • vmap — _GRUFn.apply / _LSTMFn.apply aren't vmap-aware.
  • CPU fallback — currently CUDA-only; need either a vanilla torch implementation or a hard error on CPU tensors.
  • In-kernel H masking — eliminate the wrapper-side pad/unpad copies.
  • Determinism — backward uses tl.atomic_add for dh0 / dc0, which can give tiny run-to-run gradient differences. Either expose a deterministic flag that switches to a single-pass reduction or document the non-determinism.
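A sketch of what that registration could look like (op name, argument list, and shapes are illustrative, not the PR's code):

```python
import torch
from torch import Tensor

@torch.library.custom_op("torchrl_proto::gru_triton_forward", mutates_args=())
def gru_triton_forward(x: Tensor, h0: Tensor, is_init: Tensor,
                       w_ih: Tensor, w_hh: Tensor) -> Tensor:
    # The real implementation would launch the fused Triton time-loop kernel;
    # this placeholder only marks the entry point.
    raise NotImplementedError("placeholder for the Triton launch")

@gru_triton_forward.register_fake
def _(x, h0, is_init, w_ih, w_hh):
    # Shape/dtype propagation so torch.compile can trace through the op
    # without running the kernel.
    return x.new_empty(*x.shape[:-1], h0.shape[-1])
```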

Coverage

  • Sota-implementations CI — add a recurrent SAC/PPO config that exercises recurrent_backend="triton" end-to-end (per CLAUDE.md §11).
  • Docs — extend docs/source/reference/modules.rst with a paragraph on the new backend.
  • Tutorial — short page contrasting pad / scan / triton on the resets-during-RL workload, with the benchmark table (per CLAUDE.md §9).
  • Wider numerical test coverage — current tests cover H ∈ {16, 64}. Extend to H ∈ {128, 256} and add a property-style test on randomized is_init masks.

vmoens and others added 2 commits May 11, 2026 12:55
* Rename ``I`` -> ``I_in`` to silence flake8 E741 (ambiguous single-letter name).
* Add ``dout`` to codespell ignore-list — it's standard autograd shorthand
  for ``grad of out``, not a typo of ``doubt``.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@vmoens vmoens force-pushed the rnn-cuda-kernel branch from bf79dd5 to ecd1087 on May 11, 2026 at 14:20
vmoens and others added 4 commits May 11, 2026 16:06
… + dropout

* New ``--dropouts`` CLI sweep (defaults to ``[0.0]`` for backwards compat).
* Multi-layer is now exercised through the triton backend — the previous
  ``num_layers == 1`` gate was stale.
* Scan modes are skipped automatically when ``dropout > 0`` since they raise
  ``NotImplementedError``.
* Output CSV now includes ``num_layers`` and ``dropout`` columns so a single
  run can drive a 2D sweep.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Lift ``--num-layers`` from a single int to a sweep (``nargs='+'``) so one
invocation can cover several stack depths. Includes the layer count in the
output CSV.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Adds ``test_{lstm,gru}_module_three_backends_equivalent`` parametrized over
``num_layers in {1, 2}``. Compares pad / scan / triton outputs (feat and
``next.hidden*``) on a tensordict with intra-rollout resets.

This locks in the contract that all three backends agree wherever the
intersection of their supported features lies. The case is fixed at
``dropout=0`` because the scan backend raises on dropout; the
pad-vs-triton dropout path is already covered by
``test_*_triton_extended_forward_matches_pad``.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@vmoens
Collaborator Author

vmoens commented May 11, 2026

Follow-up PR:

  • Native Triton bidirectional / proj_size
  • torch.library.custom_op compile integration
  • vmap
  • in-kernel H masking
  • deterministic backward mode
  • full docs/tutorial/SOTA coverage

vmoens and others added 4 commits May 11, 2026 18:58
…+training

* Older Triton versions read ``prune_configs_by`` as a plain dict and access
  ``perf_model`` / ``top_k`` keys unconditionally. Add a ``_PRUNE_BY_TEMPLATE``
  with those keys set to ``None`` and merge it into every autotune
  declaration. Fixes ``KeyError: 'perf_model'`` on ``tests-olddeps``.

* The extended forward tests parametrized ``dropout=0.3`` with
  ``training=True``, then asserted bit-exact equivalence between pad and
  triton outputs. That cannot hold: cuDNN's ``nn.LSTM`` / ``nn.GRU`` stores
  its dropout mask state in a cuDNN dropout descriptor that advances
  independently of torch's global RNG, while the triton wrapper applies
  ``F.dropout`` directly. Even with ``torch.manual_seed`` before each call
  the masks diverge. Under that combination, fall back to shape/dtype/finite
  checks instead. Fixes intermittent assertion failures on ``tests-optdeps``.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
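A sketch of the autotune template described in the first bullet above (config values and the kernel are illustrative; only the always-present ``perf_model`` / ``top_k`` keys reflect the fix):

```python
import triton
import triton.language as tl

# Older Triton indexes prune_configs_by["perf_model"] / ["top_k"] unconditionally,
# so every autotune declaration merges in a template that carries both keys
# with None (i.e. no pruning).
_PRUNE_BY_TEMPLATE = {"perf_model": None, "top_k": None}

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 64}, num_warps=8),
    ],
    key=["H"],
    prune_configs_by=dict(_PRUNE_BY_TEMPLATE),
)
@triton.jit
def _example_kernel(x_ptr, H, BLOCK_K: tl.constexpr):
    pass  # illustrative body; the real kernels are the GRU/LSTM time loops
```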
…ibility

``tl.extra.libdevice`` only exists on Triton >= 2.2; on the
``tests-olddeps (3.10, 11.8)`` runner the import raises
``AttributeError: module 'triton.language.extra' has no attribute
'libdevice'``. Replace ``tl.extra.libdevice.tanh(x)`` with a
``2*sigmoid(2x) - 1`` helper, which lowers to the same hardware
``ex2``/``rcp`` pair via ``tl.sigmoid`` and is available in every
Triton version we care about.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
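The shim amounts to a one-line helper (a sketch of the identity described above, not the exact source):

```python
import triton
import triton.language as tl

@triton.jit
def _tanh(x):
    # tanh(x) = 2*sigmoid(2x) - 1; lowers through tl.sigmoid on every Triton
    # version of interest (later reverted in favour of tl.extra.libdevice.tanh).
    return 2.0 * tl.sigmoid(2.0 * x) - 1.0
```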
Even with the sigmoid-tanh shim added in 9d72835, the
``tests-olddeps (3.10, 11.8)`` runner's Triton (< 2.2) cannot compile the
backward kernel — it rejects ``tl.atomic_add`` with a 2-D mask. Rather
than chase a moving compatibility target, restore the canonical
``tl.extra.libdevice.tanh`` path and instead gate ``_has_triton`` on the
presence of the ``triton.language.extra.libdevice`` submodule (the proxy
for Triton >= 2.2). Older Triton installs return ``_has_triton = False``,
the Triton tests skip cleanly, and users transparently fall back to the
``scan`` / ``pad`` backends.

The probe uses ``importlib.util.find_spec`` rather than
``hasattr(tl.extra, 'libdevice')`` because ``libdevice`` is a lazily
loaded submodule on modern Triton — Triton's JIT resolves the attribute
access at kernel-compile time, but at module-import time the attribute
isn't bound yet.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
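A sketch of the probe, assuming a helper along these lines (name illustrative):

```python
import importlib.util

def _probe_triton() -> bool:
    # libdevice is a lazily loaded submodule on modern Triton, so
    # hasattr(tl.extra, "libdevice") is False at import time even when the
    # backend would work; probing the module spec avoids that pitfall.
    try:
        return importlib.util.find_spec("triton.language.extra.libdevice") is not None
    except ImportError:
        return False

_has_triton = _probe_triton()
```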
Re-format ``_rnn_triton.py``, ``rnn.py`` and ``test_tensordictmodules.py``
with ``ufmt 2.8.0`` (black 22.3.0 + usort 1.0.3) to satisfy the
``python-source-and-configs`` lint job.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@vmoens vmoens merged commit aac5227 into pytorch:main May 12, 2026
106 of 109 checks passed
@vmoens vmoens deleted the rnn-cuda-kernel branch May 12, 2026 07:07

Labels

Benchmarks, CLA Signed, Feature, Integrations/torch_geometric, Modules
