[Performance] Add compile integration for Triton RNN kernels#3740
Open
vmoens wants to merge 8 commits into
Open
[Performance] Add compile integration for Triton RNN kernels#3740vmoens wants to merge 8 commits into
vmoens wants to merge 8 commits into
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3740
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 4 New FailuresAs of commit 5e750e8 with merge base 3df2f4a ( NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
- Extract the validate/flatten/fallback/unflatten scaffolding into ``_vmap_backward_via_flatten``; GRU and LSTM backward vmap rules only need a per-op ``_invoke`` closure that unpacks the flattened args, rebuilds ``shapes`` from them, and calls the impl. - Add a comment in ``_gru_backward_impl`` / ``_lstm_backward_impl`` explaining why the vmap path uses ``bmm`` for per-V weight reductions but keeps ``dgates_x_flat`` flat for the shared-weight ``dx`` matmul. - Mark the ``B % V`` check as a defensive guardrail.
Adds ``test_*_module_scan_vs_triton_under_vmap`` for both GRU and LSTM. The scan backend goes through standard PyTorch op dispatch and has no custom vmap rule, so it serves as a ground-truth reference for our hand-rolled flatten/unflatten path in the triton custom_op. Covers both forward and ``vmap(grad(loss))`` against the same shared-weight inputs.
The ``custom_op`` family (``torch.library.custom_op`` / ``register_fake`` / ``register_autograd``) is the only autograd entry point we ship now; the ``_GRUFn`` / ``_LSTMFn`` ``autograd.Function`` mirrors only ran on PyTorch < 2.4 builds, where the backend never advanced past prototype anyway. ``_check_triton_available`` now also requires the custom_op API so older PyTorch / Triton routes cleanly to scan/pad. Top-level ``gru_triton`` / ``lstm_triton`` raise a descriptive ``RuntimeError`` if called when the backend is unavailable. Net -149 LoC from the PR diff.
# Conflicts: # torchrl/modules/tensordict_module/_rnn_triton.py # torchrl/modules/tensordict_module/rnn.py
PyTorch 2.13 nightlies ship ``torch.library.register_autograd`` in a state
where the auto-generated ``autograd.Function`` lacks ``setup_context``,
breaking ``vmap(grad(custom_op_call(...)))`` with:
RuntimeError: ... must override the setup_context staticmethod ...
The same nightlies also assert ``False != True`` inside
``torch._higher_order_ops.scan`` when called through ``vmap(grad(...))``.
Both failures are upstream, not bugs in this PR.
Probe once at collection by trying a tiny ``vmap(grad(gru_triton(...)))``
call; skip the four affected tests when the probe fails. Forward-only
``vmap`` coverage in the same tests remains unconditional.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
torch.library.custom_opsotorch.compilesees opaque traceable operators.--cudagraphoption and document observed compile/cudagraph behavior.Testing
git diff --cached --checkbefore commitpython -m py_compile benchmarks/bench_gru_reset_backends.py torchrl/modules/tensordict_module/_rnn_triton.py torchrl/modules/tensordict_module/rnn.py test/test_tensordictmodules.pyPYTHONPATH=. pytest test/test_tensordictmodules.py -k "triton and (gru_module or lstm_module or custom_op_compile or custom_op_vmap)"2 passed, 29 skipped, 260 deselected; CUDA/Triton tests skipped on this machine.Notes