
[Performance] Add compile integration for Triton RNN kernels #3740

Open
vmoens wants to merge 8 commits into pytorch:main from vmoens:rnn-cuda-kernel-2

Conversation

@vmoens (Collaborator) commented May 12, 2026

Summary

  • Wrap the low-level GRU/LSTM Triton forward and backward launches in torch.library.custom_op so that torch.compile sees opaque, traceable operators (a minimal sketch follows this list).
  • Register fake/meta kernels and conservative vmap rules for the custom ops.
  • Extend the RNN tests with CUDA coverage for compiled forward/backward and for vmap-vs-loop parity.
  • Extend the reset-backend benchmark with a --cudagraph option and document observed compile/cudagraph behavior.
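
A minimal sketch of the custom_op/fake-kernel pattern from the first two bullets is below. The operator name, signature, and body are illustrative stand-ins, not the actual code: the real operators wrap the GRU/LSTM Triton launches in torchrl/modules/tensordict_module/_rnn_triton.py.

```python
import torch


# Stand-in operator: the real ops wrap the Triton forward/backward launches.
# Assumed toy shapes: x is (T, B, H), h0 is (B, H).
@torch.library.custom_op("torchrl_demo::gru_fwd", mutates_args=())
def gru_fwd_demo(x: torch.Tensor, h0: torch.Tensor) -> torch.Tensor:
    # In the real operator this body launches the Triton forward kernel;
    # torch.compile records the call as a single opaque node rather than
    # tracing into the kernel-launch machinery.
    return x + h0.unsqueeze(0)  # placeholder computation, not the RNN math


@gru_fwd_demo.register_fake
def _(x, h0):
    # Fake/meta kernel: describe output shape/dtype only, with no device work,
    # so FakeTensor tracing (and hence torch.compile) can propagate metadata
    # through the op without a CUDA/Triton build.
    return torch.empty_like(x)
```

With a fake kernel registered, code that calls the op can typically be captured by ``torch.compile(..., fullgraph=True)`` without graph breaks, which is the compatibility win described under Notes below.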

Testing

  • git diff --cached --check before commit
  • python -m py_compile benchmarks/bench_gru_reset_backends.py torchrl/modules/tensordict_module/_rnn_triton.py torchrl/modules/tensordict_module/rnn.py test/test_tensordictmodules.py
  • PYTHONPATH=. pytest test/test_tensordictmodules.py -k "triton and (gru_module or lstm_module or custom_op_compile or custom_op_vmap)"
    • Local result: 2 passed, 29 skipped, 260 deselected; CUDA/Triton tests skipped on this machine.

Notes

  • The vmap registration intentionally uses map semantics and launches one Triton call per mapped slice (see the sketch after this list).
  • The Triton custom op remains opaque to compile; the main win is compatibility/fullgraph capture rather than fusing through the kernel body.
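
A rough sketch of such a map-semantics vmap rule, continuing the hypothetical ``gru_fwd_demo`` operator from the Summary above and assuming a PyTorch recent enough to provide ``torch.library.register_vmap`` (the real registrations in ``_rnn_triton.py`` do the equivalent over their own argument lists):

```python
import torch


@torch.library.register_vmap("torchrl_demo::gru_fwd")
def _(info, in_dims, x, h0):
    # Map semantics: peel the mapped dimension off, call the opaque op once per
    # slice (one Triton launch each in the real kernels), then restack.
    x_dim, h0_dim = in_dims
    xs = x.movedim(x_dim, 0) if x_dim is not None else [x] * info.batch_size
    hs = h0.movedim(h0_dim, 0) if h0_dim is not None else [h0] * info.batch_size
    outs = [gru_fwd_demo(xi, hi) for xi, hi in zip(xs, hs)]
    return torch.stack(outs), 0  # output batched along dim 0
```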

pytorch-bot Bot commented May 12, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3740

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 4 New Failures

As of commit 5e750e8 with merge base 3df2f4a:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label on May 12, 2026
github-actions bot added the Performance, Benchmarks, Modules, and Integrations/torch_geometric labels on May 12, 2026
vmoens added 7 commits May 12, 2026 14:26
- Extract the validate/flatten/fallback/unflatten scaffolding into
  ``_vmap_backward_via_flatten``; GRU and LSTM backward vmap rules only
  need a per-op ``_invoke`` closure that unpacks the flattened args,
  rebuilds ``shapes`` from them, and calls the impl (a rough sketch of this pattern follows below).
- Add a comment in ``_gru_backward_impl`` / ``_lstm_backward_impl``
  explaining why the vmap path uses ``bmm`` for per-V weight reductions
  but keeps ``dgates_x_flat`` flat for the shared-weight ``dx`` matmul.
- Mark the ``B % V`` check as a defensive guardrail.
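
A loose sketch of the shared scaffolding and per-op closure split described above; the argument handling and shape bookkeeping are simplified stand-ins for what ``_rnn_triton.py`` actually does:

```python
import torch


def _vmap_backward_via_flatten_sketch(info, in_dims, args, invoke):
    # Shared scaffolding: normalise batch dims, flatten the vmap dim into the
    # batch dim, hand off to the per-op closure, then unflatten the result.
    flat_args = []
    for arg, dim in zip(args, in_dims):
        if dim is None:
            flat_args.append(arg)  # shared (unbatched) tensor, e.g. weights
        else:
            moved = arg.movedim(dim, 0)            # (V, B, ...)
            flat_args.append(moved.flatten(0, 1))  # (V * B, ...)
    # The per-op ``_invoke`` closure rebuilds ``shapes`` from the flattened
    # args and calls the GRU/LSTM backward impl.
    flat_out = invoke(*flat_args)
    return flat_out.unflatten(0, (info.batch_size, -1)), 0
```
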
Adds ``test_*_module_scan_vs_triton_under_vmap`` for both GRU and LSTM.
The scan backend goes through standard PyTorch op dispatch and has no
custom vmap rule, so it serves as a ground-truth reference for our
hand-rolled flatten/unflatten path in the triton custom_op. Covers both
forward and ``vmap(grad(loss))`` against the same shared-weight inputs.
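
The shape of that parity check, heavily simplified (``run_rnn`` is a hypothetical helper standing in for the GRU/LSTM module calls the real tests make):

```python
import torch
from torch.func import grad, vmap


def check_scan_vs_triton_under_vmap(run_rnn, x, h0):
    # Forward parity under vmap: scan (plain op dispatch) is the reference.
    y_scan = vmap(lambda xi: run_rnn("scan", xi, h0))(x)
    y_triton = vmap(lambda xi: run_rnn("triton", xi, h0))(x)
    torch.testing.assert_close(y_triton, y_scan)

    # vmap(grad(loss)) parity over the same shared-weight inputs.
    g_scan = vmap(grad(lambda xi: run_rnn("scan", xi, h0).sum()))(x)
    g_triton = vmap(grad(lambda xi: run_rnn("triton", xi, h0).sum()))(x)
    torch.testing.assert_close(g_triton, g_scan)
```
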
The ``custom_op`` family (``torch.library.custom_op`` / ``register_fake`` /
``register_autograd``) is the only autograd entry point we ship now; the
``_GRUFn`` / ``_LSTMFn`` ``autograd.Function`` mirrors only ran on
PyTorch < 2.4 builds, where the backend never advanced past prototype
anyway. ``_check_triton_available`` now also requires the custom_op API
so older PyTorch / Triton routes cleanly to scan/pad. Top-level
``gru_triton`` / ``lstm_triton`` raise a descriptive ``RuntimeError`` if
called when the backend is unavailable.

Net -149 LoC from the PR diff.
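
An illustrative gate in the spirit of the stricter ``_check_triton_available`` (the actual checks and error messages live in ``_rnn_triton.py``):

```python
import torch


def _triton_backend_available_sketch() -> bool:
    # CUDA, Triton, and the custom_op API all have to be present; otherwise
    # callers fall back to the scan/pad backends, and ``gru_triton`` /
    # ``lstm_triton`` raise a descriptive RuntimeError when forced.
    if not torch.cuda.is_available():
        return False
    try:
        import triton  # noqa: F401
    except ImportError:
        return False
    # ``custom_op`` / ``register_fake`` / ``register_autograd`` are the
    # PyTorch >= 2.4 entry points the backend now requires.
    return all(
        hasattr(torch.library, attr)
        for attr in ("custom_op", "register_fake", "register_autograd")
    )
```
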
# Conflicts:
#	torchrl/modules/tensordict_module/_rnn_triton.py
#	torchrl/modules/tensordict_module/rnn.py
PyTorch 2.13 nightlies ship ``torch.library.register_autograd`` in a state
where the auto-generated ``autograd.Function`` lacks ``setup_context``,
breaking ``vmap(grad(custom_op_call(...)))`` with:

    RuntimeError: ... must override the setup_context staticmethod ...

The same nightlies also assert ``False != True`` inside
``torch._higher_order_ops.scan`` when called through ``vmap(grad(...))``.
Both failures are upstream, not bugs in this PR.

Probe once at collection by trying a tiny ``vmap(grad(gru_triton(...)))``
call; skip the four affected tests when the probe fails. Forward-only
``vmap`` coverage in the same tests remains unconditional.
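
A simplified version of that probe-and-skip pattern; ``tiny_loss`` and ``example_input`` are placeholders for the minimal ``gru_triton`` loss and CUDA input the real probe builds:

```python
import torch
from torch.func import grad, vmap


def _vmap_grad_probe_passes(tiny_loss, example_input) -> bool:
    # Run a tiny vmap(grad(...)) once; both upstream failures described above
    # (missing setup_context -> RuntimeError, scan assert -> AssertionError)
    # surface here, and the affected tests get skipped.
    try:
        vmap(grad(tiny_loss))(example_input)
        return True
    except (RuntimeError, AssertionError):
        return False


# At collection time the test file runs the probe once and applies something
# like ``pytest.mark.skipif(not probe_ok, reason=...)`` to the four affected
# tests; forward-only vmap coverage stays unconditional.
```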