Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 156 additions & 2 deletions test/objectives/test_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import pytest
import torch
from packaging import version

from tensordict import assert_allclose_td, TensorDict
from tensordict.nn import (
Expand Down Expand Up @@ -53,6 +54,8 @@
PENDULUM_VERSIONED,
)

_TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


class TestValues:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -411,7 +414,8 @@ def _build_shifted_test_td(self, *, with_internal_done: bool):
return td, obs_dim

@pytest.mark.parametrize("with_internal_done", [False, True])
def test_gae_shifted_compact_and_legacy(self, with_internal_done):
@pytest.mark.parametrize("compact_cat_dim", ["batch", "time"])
def test_gae_shifted_compact_and_legacy(self, with_internal_done, compact_cat_dim):
# Both shifted='compact' and shifted='legacy' must produce a valid
# advantage. 'legacy' must match shifted=False exactly. 'compact'
# is allowed a small boundary bias from copying V(obs[T-1]) at
Expand All @@ -424,7 +428,11 @@ def test_gae_shifted_compact_and_legacy(self, with_internal_done):
out_keys=["state_value"],
)
gae_compact = GAE(
gamma=0.9, lmbda=0.95, value_network=value_net, shifted="compact"
gamma=0.9,
lmbda=0.95,
value_network=value_net,
shifted="compact",
compact_cat_dim=compact_cat_dim,
)
gae_legacy = GAE(
gamma=0.9, lmbda=0.95, value_network=value_net, shifted="legacy"
Expand Down Expand Up @@ -455,6 +463,152 @@ def test_gae_shifted_true_deprecation_aliases_legacy(self):
adv_legacy = gae_legacy(td.copy())["advantage"]
torch.testing.assert_close(adv_true, adv_legacy)

@pytest.mark.skipif(
_TORCH_VERSION < version.parse("2.7"),
reason="GAE compact recurrent path uses torch.vmap chunked semantics that fall "
"back to _pseudo_vmap on torch<2.7 (NotImplementedError).",
)
@pytest.mark.parametrize("module", ["lstm", "gru"])
@pytest.mark.parametrize("compact_cat_dim", ["batch", "time"])
def test_gae_recurrent_shifted_compact_matches_unshifted_isaac_shape(
self, module, compact_cat_dim
):
# Isaac-shaped regression test: recurrent value network, multi-trajectory
# rollout with truncations every `episode_len` steps (never terminations),
# and ``compact_obs=False`` semantics — ``("next", obs)`` is populated
# everywhere, in particular at internal-done positions where it carries
# the true pre-reset terminal observation (not the post-reset first obs
# of the new episode).
#
# Under these conditions shifted="compact" must match shifted=False
# to within a small tolerance. The compact path currently builds
# ``data_in = [root_obs[0:T], boundary_obs]`` and reads
# ``value_[t] = V(data_in[t+1])``; for ``t < T-1`` that is
# ``V(root_obs[t+1])``, which at internal-done positions is the
# **post-reset** obs rather than ``("next", obs)[t]``. The
# boundary-override mechanism in ``_call_value_net_compact`` only fills
# the rollout-edge slot, leaving internal-done positions corrupted.
# GAE then bootstraps with ``(1 - terminated)`` (truncations are
# *not* masked), so the wrong ``next_state_value`` propagates straight
# into the value target / advantage.
#
# See ``examples/collectors/isaaclab_rnn_ppo_memory.py`` and
# ``torchrl/objectives/value/advantages.py:_call_value_net_compact``.
torch.manual_seed(0)
B, T, obs_dim, hidden = 4, 16, 6, 8
episode_len = 4 # internal truncation every 4 steps
g = torch.Generator(device="cpu").manual_seed(0)
all_obs = torch.randn(B, T + 1, obs_dim, generator=g)
obs = all_obs[:, :T].clone()
next_obs = all_obs[:, 1:].clone()
done = torch.zeros(B, T, 1, dtype=torch.bool)
for t in range(episode_len - 1, T, episode_len):
done[:, t, 0] = True
if t < T - 1:
# Decouple next_obs[t] from obs[t+1]: env returned the true
# truncation obs, then auto-reset gave a fresh obs[t+1].
next_obs[:, t] = torch.randn(B, obs_dim, generator=g)
# Isaac-Ant only ever truncates (max_episode_steps); never terminates.
terminated = torch.zeros_like(done)
truncated = done.clone()
is_init = torch.zeros(B, T, 1, dtype=torch.bool)
is_init[:, 0, 0] = True
is_init[:, 1:][done[:, :-1]] = True
next_is_init = done.clone()
reward = torch.randn(B, T, 1, generator=g) * 0.1
td = TensorDict(
{
"observation": obs,
"is_init": is_init,
"next": TensorDict(
{
"observation": next_obs,
"reward": reward,
"done": done,
"terminated": terminated,
"truncated": truncated,
"is_init": next_is_init,
},
[B, T],
),
},
[B, T],
)

if module == "lstm":
recurrent_module = LSTMModule(
input_size=obs_dim,
hidden_size=hidden,
num_layers=1,
in_keys=["observation", "rs_h", "rs_c"],
out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
python_based=True,
recurrent_backend="pad",
dropout=0,
)
else:
recurrent_module = GRUModule(
input_size=obs_dim,
hidden_size=hidden,
num_layers=1,
in_keys=["observation", "rs_h"],
out_keys=["intermediate", ("next", "rs_h")],
python_based=True,
recurrent_backend="pad",
dropout=0,
)
recurrent_module.eval()
value_net = Seq(
recurrent_module,
Mod(
nn.Linear(hidden, 1), in_keys=["intermediate"], out_keys=["state_value"]
),
)

gae_unshifted = GAE(
gamma=0.99,
lmbda=0.95,
value_network=value_net,
shifted=False,
deactivate_vmap=True,
average_gae=False,
)
gae_compact = GAE(
gamma=0.99,
lmbda=0.95,
value_network=value_net,
shifted="compact",
compact_cat_dim=compact_cat_dim,
deactivate_vmap=False,
average_gae=False,
)
with set_recurrent_mode(True), torch.no_grad():
adv_unshifted = gae_unshifted(td.clone())["advantage"]
adv_compact = gae_compact(td.clone())["advantage"]
# Tolerance is generous because the recurrent value net has its own
# set of mild approximations (legacy/False stack-and-vmap; compact
# single-call with boundary overrides). The bound here is the level
# at which we have empirically observed the Isaac PPO run diverge
# from the shifted=False baseline; values above ~5% mean-rel-err
# corresponded to a ~20% relative reward shortfall at iter 1000 on
# Isaac-Ant. See the wandb runs cited above.
mean_abs_diff = (adv_compact - adv_unshifted).abs().mean()
mean_unshifted_mag = adv_unshifted.abs().mean().clamp_min(1e-6)
rel = mean_abs_diff / mean_unshifted_mag
assert rel < 0.05, (
f"shifted='compact' advantage diverges from shifted=False by "
f"mean rel-err={float(rel):.4f} on the Isaac-shaped fixture. "
"This indicates the compact path's _call_value_net_compact is "
"not overriding internal-done positions of `data_in` with the "
"env-returned `('next', obs)` even when it is populated, so the "
"bootstrap value at every truncation step is computed against "
"the post-reset observation instead of the true truncation "
"observation. Bootstraps for truncations are not masked by "
"GAE's (1 - terminated) factor on Isaac-Ant (where every "
"episode boundary is a truncation), so the bias propagates "
"into the value target."
)

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99])
@pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99])
Expand Down
Loading
Loading