Skip to content

Fix pre-trained policy action grad leak#5923

Open
AntoineRichard wants to merge 1 commit into
isaac-sim:developfrom
AntoineRichard:antoine/fix-pretrained-policy-action-grad
Open

Fix pre-trained policy action grad leak#5923
AntoineRichard wants to merge 1 commit into
isaac-sim:developfrom
AntoineRichard:antoine/fix-pretrained-policy-action-grad

Conversation

@AntoineRichard
Copy link
Copy Markdown
Collaborator

Description

Fixes a latent autograd leak in PreTrainedPolicyAction.apply_actions. When environments are stepped outside torch.inference_mode(), the low-level policy output can require gradients; copying that output into self.low_level_actions makes the persistent action buffer require gradients as well. Warp later rejects that tensor when accessing __cuda_array_interface__.

This runs the low-level policy under torch.no_grad() and detaches its output before copying it into the action buffer. No new dependencies are required.

@pascal-roth FYI for upstream review.

Fixes: N/A

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Screenshots

N/A.

Checklist

  • I have read and understood the contribution guidelines
  • I have run the pre-commit checks with ./isaaclab.sh --format
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • I have added a changelog fragment under source/<pkg>/changelog.d/ for every touched package (do not edit CHANGELOG.rst or bump extension.toml — CI handles that)
  • I have added my name to the CONTRIBUTORS.md or my name already exists there

Verification

  • ./isaaclab.sh -f was run first. All non-LFS hooks passed, but check-git-lfs-pointers failed because git-lfs is not installed in this environment.
  • SKIP=check-git-lfs-pointers ./isaaclab.sh -f passed.
  • Focused inline check with normal grad mode enabled confirmed the policy output requires grad before the action term runs, while low_level_actions.requires_grad, low_level_actions.grad_fn, and the downstream action tensor remain detached after apply_actions().
  • No automated test was added for this change.

Run low-level policy inference without autograd and detach the output before copying it into the persistent action buffer. This keeps Warp-facing action tensors usable when environments are stepped outside inference_mode.
@github-actions github-actions Bot added bug Something isn't working isaac-lab Related to Isaac Lab team labels Jun 2, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 2, 2026

Greptile Summary

This PR fixes a latent autograd leak in PreTrainedPolicyAction.apply_actions where the TorchScript policy output could carry gradient tracking into the low_level_actions buffer when the environment is stepped outside torch.inference_mode(), causing Warp to reject the tensor.

  • The policy forward pass is now wrapped in torch.no_grad() and .detach() is called on the result before the in-place copy, preventing any gradient graph from being attached to the persistent action buffer.
  • The observation computation (compute_group) is still performed outside the no_grad() block; moving it inside would extend the protection to intermediate observation tensors as well (see inline comment).

Confidence Score: 4/5

Safe to merge; the fix is targeted and correct, addressing a real Warp compatibility breakage when environments are stepped outside inference mode.

The two-line change is well-scoped and the approach is sound — wrapping the policy call in no_grad() and detaching before the in-place copy cleanly breaks the autograd link. The only gap is that compute_group runs just outside the context, leaving observation tensors briefly in grad mode, but this does not affect the buffer that was actually causing Warp to fail.

source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py — the boundary of the no_grad() block could be widened to also cover the observation computation.

Important Files Changed

Filename Overview
source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py Wraps the low-level policy forward pass in torch.no_grad() and adds .detach() on the output to prevent gradient contamination of the action buffer; observation computation before the block remains in grad mode
source/isaaclab_tasks/changelog.d/antoine-fix-pretrained-policy-action-grad.rst New changelog fragment accurately describing the autograd detach fix for PreTrainedPolicyAction

Sequence Diagram

sequenceDiagram
    participant Env as ManagerBasedRLEnv
    participant PAT as PreTrainedPolicyAction
    participant ObsMgr as ObservationManager
    participant Policy as TorchScript Policy
    participant LLA as LowLevelActionTerm

    Env->>PAT: apply_actions()
    alt "counter % low_level_decimation == 0"
        PAT->>ObsMgr: compute_group("ll_policy")
        ObsMgr-->>PAT: low_level_obs (may require_grad)
        Note over PAT: torch.no_grad() context
        PAT->>Policy: forward(low_level_obs)
        Policy-->>PAT: output (no grad in no_grad ctx)
        PAT->>PAT: .detach() then copy into low_level_actions[:]
        Note over PAT: low_level_actions.requires_grad == False
        PAT->>LLA: process_actions(low_level_actions)
    end
    PAT->>LLA: apply_actions()
Loading

Reviews (1): Last reviewed commit: "Fix pre-trained policy action grad leak" | Re-trigger Greptile

Copy link
Copy Markdown

@isaaclab-review-bot isaaclab-review-bot Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review Summary

PR: Fix pre-trained policy action grad leak
Risk: Low | Impact: Bug fix — prevents Warp __cuda_array_interface__ rejection


Overall Assessment

Clean, minimal fix that correctly addresses a real gradient-leak bug. The change is well-scoped and the failure mode is clearly explained in the PR description. The two-line change is easy to reason about.


Findings

1. Redundant .detach() under torch.no_grad() — Nitpick (Intentional Belt-and-Suspenders)

File: source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py:99

Under a torch.no_grad() context, the policy output will already have requires_grad=False and no grad_fn. The .detach() call is therefore a no-op in normal execution. However, keeping it as a defensive measure is a reasonable choice — it guards against future refactors that might move the call outside the no_grad() block.

Verdict: Acceptable as-is. No action needed, just noting the redundancy for documentation purposes.


2. Observation tensor may retain computation graph — Low Severity

File: source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py:98

low_level_obs = self._low_level_obs_manager.compute_group("ll_policy")

If any observation term in the low-level observation group performs differentiable operations (e.g., accesses tensor data that participates in the outer training loop's graph), low_level_obs could carry a computation graph. While this doesn't cause the Warp failure (since the result is detached), it means the graph stays alive until low_level_obs goes out of scope at the end of apply_actions().

This is unlikely to cause issues in practice (the observations are typically from sensor data or simple computations), but for completeness, wrapping the observation computation inside the no_grad() block would eliminate any graph retention:

with torch.no_grad():
    low_level_obs = self._low_level_obs_manager.compute_group("ll_policy")
    self.low_level_actions[:] = self.policy(low_level_obs).detach()

Severity: Low — no functional impact in current usage, but a minor memory/performance consideration.


3. No automated regression test — Medium Severity

The PR checklist indicates no test was added. While the manual verification described in the PR body is thorough, this class of bug (gradient leaks that only manifest under specific execution contexts) is prone to re-introduction. A minimal regression test that:

  1. Instantiates the action term with a simple JIT policy
  2. Calls apply_actions() outside torch.inference_mode()
  3. Asserts self.low_level_actions.requires_grad == False

…would provide lasting protection. That said, I recognize this module (contrib/navigation) may lack test infrastructure, making this a follow-up rather than a blocker.


4. Changelog fragment is well-formed ✓

The changelog fragment follows the project convention correctly with proper RST formatting and cross-reference.


Summary

Category Status
Correctness ✅ Fix is correct and minimal
Safety ✅ No risk of breaking existing behavior
Performance no_grad() actually improves perf by avoiding graph construction
Testing ⚠️ No regression test (acceptable for contrib, recommended as follow-up)
Documentation ✅ Changelog fragment included

Recommendation: This PR is ready to merge. The fix is correct, well-motivated, and low-risk. The only suggestion (moving compute_group inside the no_grad() block) is a minor optimization, not a blocker.

Comment on lines 98 to +100
low_level_obs = self._low_level_obs_manager.compute_group("ll_policy")
self.low_level_actions[:] = self.policy(low_level_obs)
with torch.no_grad():
self.low_level_actions[:] = self.policy(low_level_obs).detach()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The observation computation on line 98 happens outside the no_grad() block, meaning low_level_obs may itself be a gradient-tracked tensor. While the output of the policy forward pass is already protected by the context manager, moving the observation computation inside the block ensures no intermediate observation tensors accumulate in the autograd graph unnecessarily and avoids any overhead from gradient bookkeeping for these inputs.

Suggested change
low_level_obs = self._low_level_obs_manager.compute_group("ll_policy")
self.low_level_actions[:] = self.policy(low_level_obs)
with torch.no_grad():
self.low_level_actions[:] = self.policy(low_level_obs).detach()
with torch.no_grad():
low_level_obs = self._low_level_obs_manager.compute_group("ll_policy")
self.low_level_actions[:] = self.policy(low_level_obs).detach()

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

if self._counter % self.cfg.low_level_decimation == 0:
low_level_obs = self._low_level_obs_manager.compute_group("ll_policy")
self.low_level_actions[:] = self.policy(low_level_obs)
with torch.no_grad():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we stepping the simulation at any point that is not already wrapped in torch.no_inference() or torch.no_grad() ? IMO this should not be necessary

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working isaac-lab Related to Isaac Lab team

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants