Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 91 additions & 2 deletions test/objectives/test_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,89 @@ def test_nan_next_obs_at_done_is_safe(self, estimator_cls, kwargs, shifted):
torch.testing.assert_close(vt[..., -1, :], reward[..., -1, :])
torch.testing.assert_close(adv[..., -1, :], reward[..., -1, :] - v_s)

@pytest.mark.parametrize(
"estimator_cls,kwargs",
[
(TD0Estimator, {"gamma": 0.9}),
(TD1Estimator, {"gamma": 0.9}),
(TDLambdaEstimator, {"gamma": 0.9, "lmbda": 0.95}),
(GAE, {"gamma": 0.9, "lmbda": 0.95}),
],
)
def test_missing_next_obs_compact_shifted_is_safe(self, estimator_cls, kwargs):
torch.manual_seed(0)
value_net = TensorDictModule(
nn.Linear(3, 1, bias=False),
in_keys=["obs"],
out_keys=["state_value"],
)
B, T, F = 2, 5, 3
obs = torch.randn(B, T, F)
done = torch.zeros(B, T, 1, dtype=torch.bool)
done[:, 2] = True
done[:, -1] = True
reward = torch.ones(B, T, 1)
td_compact = TensorDict(
{
"obs": obs,
"next": {
"reward": reward,
"done": done.clone(),
"terminated": done.clone(),
},
},
[B, T],
)

next_obs = torch.empty_like(obs)
next_obs[:, :-1] = obs[:, 1:]
next_obs[:, -1] = float("nan")
next_obs[done.expand_as(next_obs)] = float("nan")
td_reference = td_compact.clone()
td_reference["next", "obs"] = next_obs

est = estimator_cls(**kwargs, value_network=value_net, shifted=True)
td_actual = td_compact.clone()
actual = est(td_actual)
expected = est(td_reference.clone())

assert td_actual.get(("next", "obs"), default=None) is None
torch.testing.assert_close(actual["advantage"], expected["advantage"])
torch.testing.assert_close(actual["value_target"], expected["value_target"])
assert torch.isfinite(actual["advantage"]).all()
assert torch.isfinite(actual["value_target"]).all()

def test_shifted_gae_accepts_noncanonical_strides(self):
torch.manual_seed(0)
value_net = TensorDictModule(
nn.Linear(3, 1, bias=False),
in_keys=["obs"],
out_keys=["state_value"],
)
B, T, F = 2, 5, 3
obs = torch.randn(B, T, F)
done = torch.zeros(B, T, 1, dtype=torch.bool)
done[:, -1] = True
reward = torch.ones(B, T, 1)
td = TensorDict(
{
"obs": obs,
"next": {
"obs": torch.randn(B, T, F),
"reward": reward,
"done": done.clone(),
"terminated": done.clone(),
},
},
[B, T],
).transpose(0, 1)

assert not td["obs"].is_contiguous()
est = GAE(gamma=0.9, lmbda=0.95, value_network=value_net, shifted=True)
out = est(td)
assert torch.isfinite(out["advantage"]).all()
assert torch.isfinite(out["value_target"]).all()

@pytest.mark.skipif(not _has_gym, reason="requires gym")
def test_gae_multi_done(self):

Expand Down Expand Up @@ -214,8 +297,12 @@ def test_gae_multi_done(self):

@pytest.mark.skipif(not _has_gym, reason="requires gym")
@pytest.mark.parametrize("module", ["lstm", "gru"])
def test_gae_recurrent(self, module):
# Checks that shifted=True and False provide the same result in GAE when an LSTM is used
@pytest.mark.parametrize("vectorized", [False, True])
def test_gae_recurrent(self, module, vectorized):
# Checks that shifted=True and False produce the same advantages
# when an RNN value net is used, across both vectorized and
# non-vectorized GAE — vectorized and shifted are orthogonal and
# should all agree.
env = SerialEnv(
2,
[
Expand Down Expand Up @@ -268,6 +355,7 @@ def test_gae_recurrent(self, module):
lmbda=0.99,
value_network=value_net,
shifted=True,
vectorized=vectorized,
)
with set_recurrent_mode(True):
r0 = gae_shifted(vals.copy())
Expand All @@ -278,6 +366,7 @@ def test_gae_recurrent(self, module):
lmbda=0.99,
value_network=value_net,
shifted=False,
vectorized=vectorized,
deactivate_vmap=True,
)
with pytest.raises(
Expand Down
163 changes: 105 additions & 58 deletions torchrl/objectives/value/advantages.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,8 @@ def _get_time_dim(self, time_dim: int | None, data: TensorDictBase):

@staticmethod
def _sanitize_next_obs_nan(
data: TensorDictBase, in_keys: list[NestedKey]
data: TensorDictBase,
in_keys: list[NestedKey],
) -> TensorDictBase:
"""Replace ``NaN`` entries in ``("next", k)`` with the corresponding root ``k``.

Expand Down Expand Up @@ -473,6 +474,23 @@ def _sanitize_next_obs_nan(
data.set(next_k, torch.where(nan_mask, root, nxt))
return data

@staticmethod
def _fill_missing_next_inputs(
next_data: TensorDictBase, root_data: TensorDictBase, in_keys: list[NestedKey]
) -> TensorDictBase:
copied = False
for key in in_keys:
if next_data.get(key, default=None) is not None:
continue
value = root_data.get(key, default=None)
if value is None:
continue
if not copied:
next_data = next_data.copy()
copied = True
next_data.set(key, value)
return next_data

def _call_value_nets(
self,
data: TensorDictBase,
Expand All @@ -488,6 +506,18 @@ def _call_value_nets(
if value_net is None:
value_net = self.value_network
in_keys = value_net.in_keys
if single_call:
try:
ndim = list(data.names).index("time") + 1
except ValueError:
if rl_warnings():
logger.warning(
"Got a tensordict without a time-marked dimension, assuming time is along the last dimension. "
"This warning can be turned off by setting the environment variable RL_WARNINGS to False."
)
ndim = data.ndim
else:
ndim = None
data = self._sanitize_next_obs_nan(data, in_keys)

def _call_value_net(data_in: TensorDictBase) -> torch.Tensor:
Expand All @@ -505,15 +535,6 @@ def _call_value_net(data_in: TensorDictBase) -> torch.Tensor:
# have T+1 elements (or, for a batch of N trajectories, we will have \Sum_{t=0}^{T-1} length_t + T
# elements). Then, we can feed that to our RNN which will understand which trajectory is which, pad the data
# accordingly and process each of them independently.
try:
ndim = list(data.names).index("time") + 1
except ValueError:
if rl_warnings():
logger.warning(
"Got a tensordict without a time-marked dimension, assuming time is along the last dimension. "
"This warning can be turned off by setting the environment variable RL_WARNINGS to False."
)
ndim = data.ndim
data_copy = data.copy()
# we are going to modify the done so let's clone it
done = data_copy["next", "done"].clone()
Expand All @@ -525,51 +546,52 @@ def _call_value_net(data_in: TensorDictBase) -> torch.Tensor:
truncated[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
data_copy["next", "done"] = done
data_copy["next", "truncated"] = truncated
# Reshape to -1 because we cannot guarantee that all dims have the same number of done states
with data_copy.view(-1) as data_copy_view:
# Interleave next data when done
data_copy_select = data_copy_view.select(
*in_keys, value_key, strict=False
)
total_elts = (
data_copy_view.shape[0]
+ data_copy_view["next", "done"].sum().item()
)
data_in = data_copy_select.new_zeros((total_elts,))
# we can get the indices of non-done data by adding the shifted done cumsum to an arange
# traj = [0, 0, 0, 1, 1, 2, 2]
# arange = [0, 1, 2, 3, 4, 5, 6]
# done = [0, 0, 1, 0, 1, 0, 1]
# done_cs = [0, 0, 0, 1, 1, 2, 2]
# indices = [0, 1, 2, 4, 5, 7, 8]
done_view = data_copy_view["next", "done"]
if done_view.shape[-1] == 1:
done_view = done_view.squeeze(-1)
else:
done_view = done_view.any(-1)
done_cs = done_view.cumsum(0)
done_cs = torch.cat([done_cs.new_zeros((1,)), done_cs[:-1]], dim=0)
indices = torch.arange(done_cs.shape[0], device=done_cs.device)
indices = indices + done_cs
data_in[indices] = data_copy_select
# To get the indices of the extra data, we can mask indices with done_view and add 1
indices_interleaved = indices[done_view] + 1
# assert not set(indices_interleaved.tolist()).intersection(indices.tolist())
data_in[indices_interleaved] = (
data_copy_view[done_view]
.get("next")
.select(*in_keys, value_key, strict=False)
# Reshape to -1 because we cannot guarantee that all dims have the same number of done states.
# Use reshape, not view: replay-buffer and memmap reads can expose non-canonical strides.
data_copy_view = data_copy.reshape(-1)
# Interleave next data when done
data_copy_select = data_copy_view.select(*in_keys, value_key, strict=False)
total_elts = (
data_copy_view.shape[0] + data_copy_view["next", "done"].sum().item()
)
data_in = data_copy_select.new_zeros((total_elts,))
# we can get the indices of non-done data by adding the shifted done cumsum to an arange
# traj = [0, 0, 0, 1, 1, 2, 2]
# arange = [0, 1, 2, 3, 4, 5, 6]
# done = [0, 0, 1, 0, 1, 0, 1]
# done_cs = [0, 0, 0, 1, 1, 2, 2]
# indices = [0, 1, 2, 4, 5, 7, 8]
done_view = data_copy_view["next", "done"]
if done_view.shape[-1] == 1:
done_view = done_view.squeeze(-1)
else:
done_view = done_view.any(-1)
done_cs = done_view.cumsum(0)
done_cs = torch.cat([done_cs.new_zeros((1,)), done_cs[:-1]], dim=0)
indices = torch.arange(done_cs.shape[0], device=done_cs.device)
indices = indices + done_cs
data_in[indices] = data_copy_select
# To get the indices of the extra data, we can mask indices with done_view and add 1
indices_interleaved = indices[done_view] + 1
# assert not set(indices_interleaved.tolist()).intersection(indices.tolist())
root_done_data = data_copy_view[done_view]
next_done_data = root_done_data.get("next").select(
*in_keys, value_key, strict=False
)
next_done_data = self._fill_missing_next_inputs(
next_done_data, root_done_data, in_keys
)
data_in[indices_interleaved] = next_done_data
if next_params is not None and next_params is not params:
raise ValueError(
"the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed."
)
if next_params is not None and next_params is not params:
raise ValueError(
"the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed."
)
if params is not None:
with params.to_module(value_net):
value_est = _call_value_net(data_in)
else:
if params is not None:
with params.to_module(value_net):
value_est = _call_value_net(data_in)
value, value_ = value_est[indices], value_est[indices + 1]
else:
value_est = _call_value_net(data_in)
value, value_ = value_est[indices], value_est[indices + 1]
value = value.view_as(done)
value_ = value_.view_as(done)
else:
Expand Down Expand Up @@ -634,7 +656,12 @@ class TD0Estimator(ValueEstimatorBase):
only one time step (which is not the case with multi-step value
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
parameters are to be used). For recurrent policies or compact
rollouts, the input should contain long, contiguous trajectory
windows with valid boundary next states; short partial rollouts
that drop the final next observation can bias bootstrapping. In
that case, keep or reconstruct boundary next states, or use
``shifted=False``. Defaults to ``False``.
average_rewards (bool, optional): if ``True``, rewards will be standardized
before the TD is computed.
differentiable (bool, optional): if ``True``, gradients are propagated through
Expand Down Expand Up @@ -873,7 +900,12 @@ class TD1Estimator(ValueEstimatorBase):
only one time step (which is not the case with multi-step value
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
parameters are to be used). For recurrent policies or compact
rollouts, the input should contain long, contiguous trajectory
windows with valid boundary next states; short partial rollouts
that drop the final next observation can bias bootstrapping. In
that case, keep or reconstruct boundary next states, or use
``shifted=False``. Defaults to ``False``.
device (torch.device, optional): the device where the buffers will be instantiated.
Defaults to ``torch.get_default_device()``.
time_dim (int, optional): the dimension corresponding to the time
Expand Down Expand Up @@ -1106,7 +1138,12 @@ class TDLambdaEstimator(ValueEstimatorBase):
only one time step (which is not the case with multi-step value
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
parameters are to be used). For recurrent policies or compact
rollouts, the input should contain long, contiguous trajectory
windows with valid boundary next states; short partial rollouts
that drop the final next observation can bias bootstrapping. In
that case, keep or reconstruct boundary next states, or use
``shifted=False``. Defaults to ``False``.
device (torch.device, optional): the device where the buffers will be instantiated.
Defaults to ``torch.get_default_device()``.
time_dim (int, optional): the dimension corresponding to the time
Expand Down Expand Up @@ -1376,7 +1413,12 @@ class GAE(ValueEstimatorBase):
only one time step (which is not the case with multi-step value
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
parameters are to be used). For recurrent policies or compact
rollouts, the input should contain long, contiguous trajectory
windows with valid boundary next states; short partial rollouts
that drop the final next observation can bias bootstrapping. In
that case, keep or reconstruct boundary next states, or use
``shifted=False``. Defaults to ``False``.
device (torch.device, optional): the device where the buffers will be instantiated.
Defaults to ``torch.get_default_device()``.
time_dim (int, optional): the dimension corresponding to the time
Expand Down Expand Up @@ -1761,7 +1803,12 @@ class VTrace(ValueEstimatorBase):
only one time step (which is not the case with multi-step value
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
parameters are to be used). For recurrent policies or compact
rollouts, the input should contain long, contiguous trajectory
windows with valid boundary next states; short partial rollouts
that drop the final next observation can bias bootstrapping. In
that case, keep or reconstruct boundary next states, or use
``shifted=False``. Defaults to ``False``.
device (torch.device, optional): the device where the buffers will be instantiated.
Defaults to ``torch.get_default_device()``.
time_dim (int, optional): the dimension corresponding to the time
Expand Down
Loading