Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,11 @@ def test_a2c_speed(

loss = A2CLoss(actor_network=actor, critic_network=critic)
advantage = GAE(
value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device
value_network=critic,
gamma=0.99,
lmbda=0.95,
shifted="legacy",
device=device,
)
advantage(td)
loss(td)
Expand Down Expand Up @@ -949,7 +953,11 @@ def test_ppo_speed(

loss = ClipPPOLoss(actor_network=actor, critic_network=critic)
advantage = GAE(
value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device
value_network=critic,
gamma=0.99,
lmbda=0.95,
shifted="legacy",
device=device,
)
advantage(td)
loss(td)
Expand Down Expand Up @@ -1054,7 +1062,11 @@ def test_reinforce_speed(

loss = ReinforceLoss(actor_network=actor, critic_network=critic)
advantage = GAE(
value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device
value_network=critic,
gamma=0.99,
lmbda=0.95,
shifted="legacy",
device=device,
)
advantage(td)
loss(td)
Expand Down
3 changes: 2 additions & 1 deletion docs/source/reference/data_layout.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ Two main patterns coexist in TorchRL:
trajectory structure (recurrent modules under
:func:`~torchrl.modules.set_recurrent_mode`,
:class:`~torchrl.data.SliceSampler`, value estimators in
``single_call=True`` mode) consumes this layout natively.
explicit shifted modes such as ``shifted="compact"`` or
``shifted="legacy"``) consumes this layout natively.

The rest of this page walks through the building blocks.

Expand Down
4 changes: 1 addition & 3 deletions examples/collectors/isaaclab_rnn_ppo_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,7 @@ def main() -> None:
raise ValueError("--num-envs must be divisible by --num-collectors.")
if args.compile_update or args.cudagraph_update:
torch._dynamo.config.capture_scalar_outputs = True
gae_shifted: bool | str = (
False if args.gae_shifted == "false" else args.gae_shifted
)
gae_shifted: bool | str = False if args.gae_shifted == "false" else args.gae_shifted

torch.manual_seed(args.seed)
torch.set_float32_matmul_precision("high")
Expand Down
6 changes: 5 additions & 1 deletion examples/rlhf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,11 @@ def make_reward_model(reward_model_cfg, sys_cfg):

def make_loss(actor, critic, critic_head):
advantage = GAE(
value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True, shifted=True
value_network=critic,
gamma=0.99,
lmbda=0.95,
average_gae=True,
shifted="legacy",
)
loss_fn = ClipPPOLoss(actor, critic_head)
return loss_fn, advantage
Expand Down
2 changes: 1 addition & 1 deletion knowledge_base/ISAACLAB.md
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ from tensordict import TensorDict
collector.update_policy_weights_(weights=TensorDict.from_module(actor).data)
```

With compact rollout data, prefer `shifted=True` value estimation so the PPO
With compact rollout data, prefer `shifted="compact"` value estimation so the PPO
batch does not need `("next", "policy")` rehydration. If a backend requires
canonical strides, `td.contiguous()` and `td.clone()` may not be enough for
size-1 dimensions; `torch.empty_like(td).update_(td)` is the stronger
Expand Down
21 changes: 18 additions & 3 deletions torchrl/collectors/_multi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,12 @@ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta):
the version under which each individual frame was produced, not as a batch-level
label.

For multi-process collectors, the ``"policy_version"`` entries in the
collected tensordict are produced by worker-local transforms and are the
source of truth for data provenance. The parent collector's
:attr:`policy_version` property exposes only the parent-side tracker state
and should not be used as a label for a returned batch.

The recommended path is ``track_policy_version=True``: let the collector own
the transform. Passing a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion`
instance directly is reserved for advanced use cases that wire up a
Expand Down Expand Up @@ -1986,19 +1992,28 @@ def increment_version(self):

@property
def policy_version(self) -> str | int | None:
"""The current policy version."""
"""The parent-side policy version.

For multi-process collectors, worker-local
:class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion`
transforms write the per-frame ``"policy_version"`` values in returned
batches. Those tensor entries are the source of truth for collected
data; this property is only the parent-side tracker state.
"""
if not hasattr(self.policy_version_tracker, "version"):
return None
return self.policy_version_tracker.version

def get_policy_version(self) -> str | int | None:
"""Get the current policy version.
"""Get the parent-side policy version.

This method exists to support remote calls in Ray actors, since properties
cannot be accessed directly through Ray's RPC mechanism.

Returns:
The current version number (int) or UUID (str), or None if version tracking is disabled.
The parent-side version number (int) or UUID (str), or ``None`` if
version tracking is disabled. For collected data, prefer the
per-frame ``"policy_version"`` tensor in returned batches.
"""
return self.policy_version

Expand Down
10 changes: 9 additions & 1 deletion torchrl/collectors/_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,15 @@ class Collector(BaseCollector):
keys can be re-hydrated at sampling time with
:class:`~torchrl.envs.transforms.rb_transforms.NextStateReconstructor`
when consuming a :class:`~torchrl.data.SliceSampler`-backed replay
buffer. Defaults to ``False``.
buffer.

``compact_obs=True`` composes cleanly with
:class:`~torchrl.objectives.value.advantages.GAE` configured with
``shifted="compact"``: the compact shifted path can run the
on-policy advantage pass without rehydrating every per-step
``("next", "observation")`` mirror. For vectorized environments
with large observations this is typically a sizeable GPU-memory
win at near-zero CPU cost. Defaults to ``False``.

Examples:
>>> from torchrl.envs.libs.gym import GymEnv
Expand Down
18 changes: 10 additions & 8 deletions torchrl/objectives/value/advantages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,14 +1690,16 @@ class GAE(ValueEstimatorBase):

.. note:: GAE can be used with value networks that rely on recurrent neural networks, provided that the
init markers (`"is_init"`) and terminated / truncated markers are properly set.
If `shifted=True`, the trajectory batch will be flattened and the last step of each trajectory will
be placed within the flat tensordict after the last step from the root, such that each trajectory has
`T+1` elements. If `shifted=False`, the root and `"next"` trajectories will be stacked and the value
network will be called with `vmap` over the stack of trajectories. Because RNNs require a fair amount of
control flow, they are currently not compatible with `torch.vmap` and, as such, the `deactivate_vmap` option
must be turned on in these cases.
Similarly, if `shifted=False`, the `"is_init"` entry of the root tensordict will be copied onto the
`"is_init"` of the `"next"` entry, such that trajectories are well separated both for root and `"next"` data.
With ``shifted="legacy"``, the trajectory batch is flattened and the next state of each done step is
interleaved after its root state, giving exact ``V(next_obs)`` values at the cost of a data-dependent
shape. With ``shifted="compact"``, root and next streams are concatenated into a constant-shape
batch, which is friendlier to ``torch.compile`` and scan-style recurrent backends. If ``shifted=False``,
the root and ``"next"`` trajectories are stacked and the value network is called with ``vmap`` over the
stack of trajectories. Because RNNs require a fair amount of control flow, they are currently not
compatible with ``torch.vmap`` and, as such, the ``deactivate_vmap`` option must be turned on in these
cases. Similarly, if ``shifted=False``, the ``"is_init"`` entry of the root tensordict will be copied
onto the ``"is_init"`` of the ``"next"`` entry, such that trajectories are well separated both for root
and ``"next"`` data.
"""

value_network: TensorDictModule | None
Expand Down
Loading