Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions test/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,61 @@ def make_env():
root_shifted = ref_data.get(k)[..., 1:, :]
torch.testing.assert_close(ref_next[mask], root_shifted[mask])

@pytest.mark.parametrize("use_buffers", [True, False])
def test_fake_tensordict_single_matches_iter(self, use_buffers):
    """Compare ``Collector.fake_tensordict()`` against one real collected batch.

    The template returned by ``fake_tensordict()`` must agree with the first
    batch yielded by the collector on ``batch_size``, on ``names`` and on the
    sorted string form of the keys returned by ``keys(True, True)``. Every
    floating-point entry of the template must additionally be all zeros.
    The collector is always shut down, even if an assertion fails.
    """

    def make_env():
        return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())

    collector = Collector(
        create_env_fn=make_env,
        policy=RandomPolicy(make_env().action_spec),
        frames_per_batch=20,
        total_frames=20,
        use_buffers=use_buffers,
    )
    try:
        # The template is requested before any batch has been collected.
        template = collector.fake_tensordict()
        torch.manual_seed(0)
        batch = next(iter(collector))

        assert template.batch_size == batch.batch_size, (
            template.batch_size,
            batch.batch_size,
        )
        assert template.names == batch.names

        template_keys = sorted(str(k) for k in template.keys(True, True))
        batch_keys = sorted(str(k) for k in batch.keys(True, True))
        # On mismatch, report the keys present on only one side.
        assert template_keys == batch_keys, set(batch_keys) ^ set(template_keys)

        for key, val in template.items(True, True):
            # Only floating-point entries are checked for zeros; bool, uint8
            # and every other non-floating dtype are skipped.
            if (
                val.dtype not in (torch.bool, torch.uint8)
                and val.is_floating_point()
            ):
                assert not val.any(), f"{key} is not zeroed"
    finally:
        collector.shutdown()

def test_fake_tensordict_multi_raises(self):
    """Check that ``MultiCollector.fake_tensordict()`` raises ``NotImplementedError``.

    A two-worker synchronous ``MultiCollector`` is built and the call is
    expected to fail with an error message that mentions ``fake_tensordict``.
    The collector is shut down regardless of the outcome.
    """

    def make_env():
        return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())

    # Two workers, both built from the same env factory.
    env_fns = [make_env, make_env]
    collector = MultiCollector(
        create_env_fn=env_fns,
        policy=RandomPolicy(make_env().action_spec),
        frames_per_batch=20,
        total_frames=20,
        sync=True,
    )
    try:
        with pytest.raises(NotImplementedError, match="fake_tensordict"):
            collector.fake_tensordict()
    finally:
        collector.shutdown()

@pytest.mark.parametrize("env_class", [CountingEnv, CountingBatchedEnv])
def test_initial_obs_consistency(self, env_class, seed=1):
# non regression test on #938
Expand Down
19 changes: 19 additions & 0 deletions torchrl/collectors/_multi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,25 @@ def _add_policy_outputs_to_fake_td(self, fake_td):
fake_td.set(key, policy_output.get(key))
return fake_td

def fake_tensordict(self) -> TensorDictBase:
    """Always raise: multi-process collectors expose no fake tensordict.

    Supporting this on a multi-collector would mean one of two things:

    - building an env in the main process, which defeats the purpose of a
      multi-process collector (Isaac Lab / mujoco-mjx etc. can only run in
      workers); or
    - forwarding a request to a worker over the pipe, which requires the
      workers to be alive and adds protocol surface.

    Neither path is implemented. Call
    :meth:`~torchrl.collectors.Collector.fake_tensordict` on a single
    :class:`~torchrl.collectors.Collector` instead, or build the template
    directly from the env spec.

    Raises:
        NotImplementedError: unconditionally. The message is prefixed with
            the concrete class name of the instance.
    """
    cls_name = type(self).__name__
    message = (
        f"{cls_name}.fake_tensordict() is not implemented. "
        "Use Collector.fake_tensordict() on a single-process collector "
        "for storage / cudagraph warmup, or build the template from the "
        "env spec directly."
    )
    raise NotImplementedError(message)

@classmethod
def _total_workers_from_env(cls, env_creators):
if isinstance(env_creators, (tuple, list)):
Expand Down
27 changes: 27 additions & 0 deletions torchrl/collectors/_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,6 +1891,33 @@ def _maybe_set_truncated(self, final_rollout):
)
return final_rollout

@torch.no_grad()
def fake_tensordict(self) -> TensorDictBase:
    """Build a zeroed template with the layout of one batch from this collector.

    The returned tensordict is meant to look like ``next(iter(collector))``:

    - batch shape ``(*env.batch_size, frames_per_batch)``, the last dim
      being named ``"time"``;
    - env keys (observation / reward / done / terminated / truncated /
      ``is_init`` when an :class:`~torchrl.envs.InitTracker` is on the
      env), policy out-keys, and ``("collector", "traj_ids")`` when
      trajectory tracking is enabled;
    - ``compact_obs=True`` exclusions applied;
    - ``set_truncated=True`` last-step ``truncated``/``done`` masking
      applied;
    - ``postproc`` / ``split_trajs`` / private-key exclusion applied,
      mirroring :meth:`_postproc`.

    Typical uses are storage initialization and ``torch.compile`` /
    cudagraph warmup, without stepping the environment beforehand.

    Returns:
        The zero-filled copy of the cached final rollout after it has gone
        through :meth:`_maybe_set_truncated` and :meth:`_postproc`.
    """
    cached_rollout = getattr(self, "_final_rollout", None)
    if cached_rollout is None:
        # No rollout buffer exists yet: ask for one to be built.
        self._maybe_make_final_rollout(make_rollout=True)

    # Zero a clone so the collector's own buffer is left untouched.
    zeroed = self._final_rollout.clone().zero_()
    masked = self._maybe_set_truncated(zeroed)
    return self._postproc(masked)

@torch.no_grad()
def reset(self, index=None, **kwargs) -> None:
"""Resets the environments to a new initial state."""
Expand Down
Loading