Fix pre-trained policy action grad leak#5923
Conversation
Run low-level policy inference without autograd and detach the output before copying it into the persistent action buffer. This keeps Warp-facing action tensors usable when environments are stepped outside inference_mode.
Greptile SummaryThis PR fixes a latent autograd leak in
Confidence Score: 4/5Safe to merge; the fix is targeted and correct, addressing a real Warp compatibility breakage when environments are stepped outside inference mode. The two-line change is well-scoped and the approach is sound — wrapping the policy call in source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py — the boundary of the Important Files Changed
Sequence DiagramsequenceDiagram
participant Env as ManagerBasedRLEnv
participant PAT as PreTrainedPolicyAction
participant ObsMgr as ObservationManager
participant Policy as TorchScript Policy
participant LLA as LowLevelActionTerm
Env->>PAT: apply_actions()
alt "counter % low_level_decimation == 0"
PAT->>ObsMgr: compute_group("ll_policy")
ObsMgr-->>PAT: low_level_obs (may require_grad)
Note over PAT: torch.no_grad() context
PAT->>Policy: forward(low_level_obs)
Policy-->>PAT: output (no grad in no_grad ctx)
PAT->>PAT: .detach() then copy into low_level_actions[:]
Note over PAT: low_level_actions.requires_grad == False
PAT->>LLA: process_actions(low_level_actions)
end
PAT->>LLA: apply_actions()
Reviews (1): Last reviewed commit: "Fix pre-trained policy action grad leak" | Re-trigger Greptile |
There was a problem hiding this comment.
Code Review Summary
PR: Fix pre-trained policy action grad leak
Risk: Low | Impact: Bug fix — prevents Warp __cuda_array_interface__ rejection
Overall Assessment
Clean, minimal fix that correctly addresses a real gradient-leak bug. The change is well-scoped and the failure mode is clearly explained in the PR description. The two-line change is easy to reason about.
Findings
1. Redundant .detach() under torch.no_grad() — Nitpick (Intentional Belt-and-Suspenders)
File: source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py:99
Under a torch.no_grad() context, the policy output will already have requires_grad=False and no grad_fn. The .detach() call is therefore a no-op in normal execution. However, keeping it as a defensive measure is a reasonable choice — it guards against future refactors that might move the call outside the no_grad() block.
Verdict: Acceptable as-is. No action needed, just noting the redundancy for documentation purposes.
2. Observation tensor may retain computation graph — Low Severity
File: source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py:98
low_level_obs = self._low_level_obs_manager.compute_group("ll_policy")If any observation term in the low-level observation group performs differentiable operations (e.g., accesses tensor data that participates in the outer training loop's graph), low_level_obs could carry a computation graph. While this doesn't cause the Warp failure (since the result is detached), it means the graph stays alive until low_level_obs goes out of scope at the end of apply_actions().
This is unlikely to cause issues in practice (the observations are typically from sensor data or simple computations), but for completeness, wrapping the observation computation inside the no_grad() block would eliminate any graph retention:
with torch.no_grad():
low_level_obs = self._low_level_obs_manager.compute_group("ll_policy")
self.low_level_actions[:] = self.policy(low_level_obs).detach()Severity: Low — no functional impact in current usage, but a minor memory/performance consideration.
3. No automated regression test — Medium Severity
The PR checklist indicates no test was added. While the manual verification described in the PR body is thorough, this class of bug (gradient leaks that only manifest under specific execution contexts) is prone to re-introduction. A minimal regression test that:
- Instantiates the action term with a simple JIT policy
- Calls
apply_actions()outsidetorch.inference_mode() - Asserts
self.low_level_actions.requires_grad == False
…would provide lasting protection. That said, I recognize this module (contrib/navigation) may lack test infrastructure, making this a follow-up rather than a blocker.
4. Changelog fragment is well-formed ✓
The changelog fragment follows the project convention correctly with proper RST formatting and cross-reference.
Summary
| Category | Status |
|---|---|
| Correctness | ✅ Fix is correct and minimal |
| Safety | ✅ No risk of breaking existing behavior |
| Performance | ✅ no_grad() actually improves perf by avoiding graph construction |
| Testing | |
| Documentation | ✅ Changelog fragment included |
Recommendation: This PR is ready to merge. The fix is correct, well-motivated, and low-risk. The only suggestion (moving compute_group inside the no_grad() block) is a minor optimization, not a blocker.
| low_level_obs = self._low_level_obs_manager.compute_group("ll_policy") | ||
| self.low_level_actions[:] = self.policy(low_level_obs) | ||
| with torch.no_grad(): | ||
| self.low_level_actions[:] = self.policy(low_level_obs).detach() |
There was a problem hiding this comment.
The observation computation on line 98 happens outside the
no_grad() block, meaning low_level_obs may itself be a gradient-tracked tensor. While the output of the policy forward pass is already protected by the context manager, moving the observation computation inside the block ensures no intermediate observation tensors accumulate in the autograd graph unnecessarily and avoids any overhead from gradient bookkeeping for these inputs.
| low_level_obs = self._low_level_obs_manager.compute_group("ll_policy") | |
| self.low_level_actions[:] = self.policy(low_level_obs) | |
| with torch.no_grad(): | |
| self.low_level_actions[:] = self.policy(low_level_obs).detach() | |
| with torch.no_grad(): | |
| low_level_obs = self._low_level_obs_manager.compute_group("ll_policy") | |
| self.low_level_actions[:] = self.policy(low_level_obs).detach() |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| if self._counter % self.cfg.low_level_decimation == 0: | ||
| low_level_obs = self._low_level_obs_manager.compute_group("ll_policy") | ||
| self.low_level_actions[:] = self.policy(low_level_obs) | ||
| with torch.no_grad(): |
There was a problem hiding this comment.
Are we stepping the simulation at any point that is not already wrapped in torch.no_inference() or torch.no_grad() ? IMO this should not be necessary
Description
Fixes a latent autograd leak in
PreTrainedPolicyAction.apply_actions. When environments are stepped outsidetorch.inference_mode(), the low-level policy output can require gradients; copying that output intoself.low_level_actionsmakes the persistent action buffer require gradients as well. Warp later rejects that tensor when accessing__cuda_array_interface__.This runs the low-level policy under
torch.no_grad()and detaches its output before copying it into the action buffer. No new dependencies are required.@pascal-roth FYI for upstream review.
Fixes: N/A
Type of change
Screenshots
N/A.
Checklist
pre-commitchecks with./isaaclab.sh --formatsource/<pkg>/changelog.d/for every touched package (do not editCHANGELOG.rstor bumpextension.toml— CI handles that)CONTRIBUTORS.mdor my name already exists thereVerification
./isaaclab.sh -fwas run first. All non-LFS hooks passed, butcheck-git-lfs-pointersfailed becausegit-lfsis not installed in this environment.SKIP=check-git-lfs-pointers ./isaaclab.sh -fpassed.low_level_actions.requires_grad,low_level_actions.grad_fn, and the downstream action tensor remain detached afterapply_actions().