
[Feature] Collector.fake_tensordict() / MultiCollector.fake_tensordict()#3761

Closed
vmoens wants to merge 1 commit into gh/vmoens/277/base from gh/vmoens/277/head


vmoens (Collaborator) commented May 15, 2026

Stack from ghstack (oldest at bottom):

Public method that returns a zero-filled tensordict shaped exactly like
one batch yielded by the collector, useful for storage initialization
and torch.compile / cudagraph warmup without having to step the env
or spin up the worker processes first.

On Collector (single):

  • Reuses the existing _final_rollout template; builds it lazily via
    _maybe_make_final_rollout(make_rollout=True) even when
    use_buffers=False so the public API is consistent.
  • Mirrors the rollout post-pipeline: _maybe_attach_final_obs,
    _maybe_set_truncated, then _postproc (which runs
    split_trajectories, the user postproc, and private-key
    exclusion).
  • Result: env keys, policy out-keys and ("collector", "traj_ids"),
    with compact_obs exclusions and final_obs UnbatchedTensor
    leaves applied and the last dim named "time".
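The key set described in the last bullet can be illustrated with a small toy model. This is a sketch in plain Python, not torchrl code: `expected_keys` and its parameters are hypothetical names, and treating a key as private when any of its components starts with an underscore is an assumption.

```python
from typing import Iterable, Set, Tuple

Key = Tuple[str, ...]  # nested key, e.g. ("next", "observation")


def expected_keys(
    env_keys: Iterable[Key],
    policy_out_keys: Iterable[Key],
    obs_keys: Iterable[Key],
    compact_obs: bool = False,
    final_obs: bool = False,
) -> Set[Key]:
    """Toy model of the key set a fake batch is described as carrying."""
    keys = set(env_keys) | set(policy_out_keys)
    keys.add(("collector", "traj_ids"))
    for obs in obs_keys:
        if compact_obs:
            # compact_obs drops the ("next", obs) entry.
            keys.discard(("next", *obs))
        if final_obs:
            # final_obs attaches a ("final", obs) leaf.
            keys.add(("final", *obs))
    # Assumption: a key is private if any component starts with "_".
    return {k for k in keys if not any(part.startswith("_") for part in k)}


keys = expected_keys(
    env_keys=[("observation",), ("next", "observation"), ("next", "reward"), ("_hidden",)],
    policy_out_keys=[("action",)],
    obs_keys=[("observation",)],
    compact_obs=True,
    final_obs=True,
)
```

With both flags set, `keys` holds `("final", "observation")` and `("collector", "traj_ids")` but neither `("next", "observation")` nor `("_hidden",)`.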

On MultiCollector:

  • Builds a per-worker fake from create_env_fn[0] (mirroring the
    legacy replay-buffer init path), applies _add_policy_outputs_to_fake_td,
    expands it to (*env.batch_size, frames_per_worker), and refines the
    last dim name to "time".
  • For MultiSyncCollector, stacks num_workers copies along dim 0
    (or concatenates along cat_results when an integer was provided);
    for MultiAsyncCollector, returns a single worker's shape (async
    yields one batch at a time).
  • Applies split_trajs / postproc / private-key exclusion to
    match the iterator pipeline.
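The shape rules in the bullets above can be written down as a small function. This is a sketch in plain Python under stated assumptions: `expected_batch_shape` is a hypothetical helper, `cat_results=None` stands in for the stacking case, and the effect of `split_trajs` on the shape is ignored.

```python
from typing import Optional, Tuple


def expected_batch_shape(
    env_batch_size: Tuple[int, ...],
    frames_per_worker: int,
    num_workers: int,
    sync: bool = True,
    cat_results: Optional[int] = None,
) -> Tuple[int, ...]:
    """Batch shape of the fake, following the description above."""
    per_worker = (*env_batch_size, frames_per_worker)
    if not sync:
        # Async collectors yield one worker's batch at a time.
        return per_worker
    if cat_results is None:
        # Sync, no integer cat_results: stack the workers along a new dim 0.
        return (num_workers, *per_worker)
    # Sync with an integer cat_results: concatenate along that dim.
    shape = list(per_worker)
    shape[cat_results] *= num_workers
    return tuple(shape)


stacked = expected_batch_shape((4,), frames_per_worker=10, num_workers=3)
concatenated = expected_batch_shape((4,), 10, 3, cat_results=-1)
single = expected_batch_shape((4,), 10, 3, sync=False)
```

Here `stacked` is `(3, 4, 10)`, `concatenated` is `(4, 30)` and `single` is `(4, 10)`.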

Tests pin: shape / names / keys / zero-fill parity between
fake_tensordict() and next(iter(collector)) (with and without
buffers); compact_obs drops ("next", obs) and final_obs
attaches ("final", obs) as UnbatchedTensor; multi-sync stacks
along worker dim 0.
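A parity check of this kind can be sketched without torchrl by comparing layout metadata only. The `Layout` dataclass and `layout_mismatches` helper below are hypothetical stand-ins for the shape, names and keys of a batch; they do not cover the zero-fill check, which needs the actual tensors.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

Key = Tuple[str, ...]


@dataclass
class Layout:
    """Hypothetical summary of a batch: batch size, dim names, leaf shapes."""
    batch_size: Tuple[int, ...]
    names: Tuple[Optional[str], ...]
    leaves: Dict[Key, Tuple[int, ...]]


def layout_mismatches(fake: Layout, real: Layout) -> List[str]:
    """List every way the fake layout differs from the real one."""
    problems: List[str] = []
    if fake.batch_size != real.batch_size:
        problems.append(f"batch_size: {fake.batch_size} != {real.batch_size}")
    if fake.names != real.names:
        problems.append(f"names: {fake.names} != {real.names}")
    for key in sorted(real.leaves.keys() - fake.leaves.keys()):
        problems.append(f"missing key: {key}")
    for key in sorted(fake.leaves.keys() - real.leaves.keys()):
        problems.append(f"unexpected key: {key}")
    for key in sorted(fake.leaves.keys() & real.leaves.keys()):
        if fake.leaves[key] != real.leaves[key]:
            problems.append(
                f"shape of {key}: {fake.leaves[key]} != {real.leaves[key]}"
            )
    return problems


real = Layout(
    batch_size=(4, 10),
    names=(None, "time"),
    leaves={("observation",): (4, 10, 3), ("collector", "traj_ids"): (4, 10)},
)
fake = Layout(
    batch_size=(4, 10),
    names=(None, "time"),
    leaves={("observation",): (4, 10, 3)},
)
problems = layout_mismatches(fake, real)
```

In this toy case `problems` contains a single entry reporting the missing `("collector", "traj_ids")` key.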

[ghstack-poisoned]

"""
# Build / borrow one env to read fake_tensordict, compact-obs leaf
# keys, and final-obs leaf shapes from.
env_fn = self.create_env_fn[0]
vmoens (Collaborator, Author) commented:

I'm not a big fan of this implementation.
We are creating an env in the main process, which defeats the purpose of the MultiCollector.
If we cannot easily get the fake data from the inner collector, we should just raise a NotImplementedError. Fetching it from the inner collector is the right approach, and we should not pretend to do that while actually building it with a ton of custom code on the main process.
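The alternative suggested in this comment could look roughly like the sketch below. It is plain Python, not torchrl code: `MultiCollectorSketch` and `fetch_from_worker` are hypothetical names, and how a real multi-process collector would query its inner collectors is left open.

```python
from typing import Any, Callable, Optional


class MultiCollectorSketch:
    """Hypothetical stand-in for a multi-process collector."""

    def __init__(self, fetch_from_worker: Optional[Callable[[], Any]] = None):
        # Assumption: a callable that asks an inner (worker-side) collector
        # for its fake batch; None means no such channel is available.
        self._fetch_from_worker = fetch_from_worker

    def fake_tensordict(self) -> Any:
        if self._fetch_from_worker is None:
            # Fail loudly instead of building an env in the main process.
            raise NotImplementedError(
                "fake_tensordict() requires the inner collector's template."
            )
        return self._fetch_from_worker()


supported = MultiCollectorSketch(lambda: {"observation": [0.0, 0.0]})
unsupported = MultiCollectorSketch()
```

Calling `supported.fake_tensordict()` returns whatever the worker-side callable provides, while `unsupported.fake_tensordict()` raises `NotImplementedError` rather than falling back to main-process env construction.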

@vmoens vmoens closed this May 15, 2026
@vmoens vmoens deleted the gh/vmoens/277/head branch May 15, 2026 08:44

Labels: CLA Signed, Collectors, Integrations/torch_geometric, Integrations
