Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions test/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4085,6 +4085,163 @@ def make_env():
del collector


class TestPolicyVersion:
"""End-to-end checks for ``track_policy_version`` on data collectors.

The contract: when a collector is constructed with
``track_policy_version=True``, every collected frame must carry a
``policy_version`` key, and that value must bump exactly once per real
weight update (``update_policy_weights_()``), regardless of how many
iterations are pulled from the collector.
"""

class _Env(EnvBase):
def __init__(self, device="cpu"):
super().__init__(batch_size=(), device=device)
self.observation_spec = Composite(
observation=Unbounded(shape=(2,), device=device)
)
self.action_spec = Unbounded(shape=(2,), device=device)
self.reward_spec = Unbounded(shape=(1,), device=device)

def _step(self, td):
return TensorDict(
{
"observation": torch.zeros(2, device=self.device),
"reward": torch.zeros(1, device=self.device),
**self.full_done_spec.zero(),
},
(),
device=self.device,
)

def _reset(self, td=None):
return TensorDict(
{"observation": torch.zeros(2, device=self.device)},
(),
device=self.device,
)

def _set_seed(self, seed):
...

@staticmethod
def _make_policy():
return TensorDictModule(
nn.Linear(2, 2), in_keys=["observation"], out_keys=["action"]
)

def test_single_collector_bumps_on_update(self):
"""``SyncDataCollector`` bumps policy_version on each weight update."""
policy = self._make_policy()
collector = SyncDataCollector(
self._Env,
policy=policy,
total_frames=60,
frames_per_batch=10,
track_policy_version=True,
)
try:
it = iter(collector)
batch0 = next(it)
v0 = batch0["next", "policy_version"]
assert v0.dtype == torch.int64
# Version is constant within a batch (no update happened mid-batch).
assert (v0 == v0[0]).all()

# No update yet -> next batch keeps the same version.
batch1 = next(it)
assert (batch1["next", "policy_version"] == v0[0]).all()

collector.update_policy_weights_()
batch2 = next(it)
assert (batch2["next", "policy_version"] == v0[0] + 1).all()

# A second update bumps again, but a continue-without-update doesn't.
collector.update_policy_weights_()
batch3 = next(it)
assert (batch3["next", "policy_version"] == v0[0] + 2).all()
batch4 = next(it)
assert (batch4["next", "policy_version"] == v0[0] + 2).all()
finally:
collector.shutdown()

@pytest.mark.parametrize(
"collector_cls",
[
functools.partial(MultiSyncCollector, cat_results="stack"),
MultiAsyncCollector,
],
ids=["multi_sync", "multi_async"],
)
@pytest.mark.parametrize(
"weight_sync_scheme_cls",
[MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
ids=["mp", "shared_mem"],
)
def test_multi_collector_bumps_on_update(
self, collector_cls, weight_sync_scheme_cls
):
"""Worker-side policy_version follows real weight updates.

Regression for the case where the worker's ``PolicyVersion`` transform
was never incremented from the parent's ``update_policy_weights_()``,
leaving all worker batches tagged with version 0.
"""
policy = self._make_policy()
collector = collector_cls(
[self._Env, self._Env],
policy=policy,
frames_per_batch=20,
total_frames=200,
track_policy_version=True,
weight_sync_schemes={"policy": weight_sync_scheme_cls()},
)
try:
it = iter(collector)
batch0 = next(it)
v0 = batch0["next", "policy_version"]
# All workers start at the same initial version (0 by default).
v0_val = int(v0.flatten()[0].item())
assert (v0 == v0_val).all()

# Iterations without weight updates must not bump the version.
for _ in range(2):
batch = next(it)
assert (batch["next", "policy_version"] == v0_val).all(), (
f"Worker version drifted without weight update: "
f"{batch['next', 'policy_version']}"
)

collector.update_policy_weights_()
# The worker bumps once it has actually applied the new weights.
# In async mode, a batch already in flight at the time of the
# update may straddle the bump (some frames pre-, some post-).
# Drain until we see a batch fully at the bumped version, with
# a sane safety cap so we don't loop forever on a regression.
target = v0_val + 1
for _ in range(10):
batch = next(it)
if (batch["next", "policy_version"] == target).all():
break
else:
raise AssertionError(
f"Worker version did not reach {target} within 10 batches "
f"after a single update_policy_weights_(); last batch: "
f"{batch['next', 'policy_version']}"
)

# And no further bumps should occur on subsequent continues that
# are not preceded by an update.
batch_no_update = next(it)
assert (batch_no_update["next", "policy_version"] == target).all(), (
f"Worker version drifted past {target} without an update: "
f"{batch_no_update['next', 'policy_version']}"
)
finally:
collector.shutdown()


class TestAggregateReset:
def test_aggregate_reset_to_root(self):
# simple
Expand Down
7 changes: 7 additions & 0 deletions torchrl/collectors/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,13 @@ def register_scheme_receiver(
context=self,
worker_idx=self.worker_idx,
)
elif scheme.context is None:
# The scheme was already initialized on the receiver (e.g. early,
# by _make_policy_factory which has no access to the inner
# collector yet). Now that we *do* have the collector, set it as
# the context so receiver-side bookkeeping (policy version,
# cascading sub-collector updates) can reach it.
scheme.context = self

# Store the scheme for later use in receive_weights()
self._receiver_schemes[model_id] = scheme
Expand Down
26 changes: 22 additions & 4 deletions torchrl/collectors/_multi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,27 @@ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta):
Received weights are automatically propagated to sub-collectors if matching model_ids exist.
Defaults to ``None``.
track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
the policy version.
Defaults to `False`.
A :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform is
installed on each worker's environment, tagging every collected frame with the
current version under the ``"policy_version"`` key. Each worker's transform is
bumped after the new weights have actually been applied in that worker, so
per-frame tagging tracks real weight updates rather than rollout iterations.

Note that in asynchronous mode a batch that was already in flight when
:meth:`update_policy_weights_` is called may straddle the bump (some frames
tagged with the old version, the remainder with the new). Treat the value as
the version under which each individual frame was produced, not as a batch-level
label.

The recommended path is ``track_policy_version=True``: let the collector own
the transform. Passing a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion`
instance directly is reserved for advanced use cases that wire up a
``PolicyVersion`` **without** going through a collector. With multi-process
collectors that pre-built tracker lives in the *parent* and is not propagated
into workers, so per-frame tagging will still be driven by per-worker
transforms — favor ``True``.

Defaults to ``False``.
compact_obs (bool, optional): if ``True``, each worker drops the
observation and state keys from the ``("next", ...)`` sub-tensordict
before stacking. See
Expand Down Expand Up @@ -1320,6 +1337,7 @@ def _run_processes(self) -> None:
"trajs_per_write": self.trajs_per_write,
"init_fn": self._worker_init_fn,
"auto_register_policy_transforms": self._auto_register_policy_transforms,
"track_policy_version": self.policy_version_tracker is not None,
"pre_collect_hook": self._worker_pre_collect_hook,
"post_collect_hook": self._worker_post_collect_hook,
"compact_obs": self.compact_obs,
Expand Down
2 changes: 2 additions & 0 deletions torchrl/collectors/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def _main_async_collector(
trajs_per_write: int | None = None,
init_fn: Callable[[], None] | None = None,
auto_register_policy_transforms: bool | None = None,
track_policy_version: bool = False,
pre_collect_hook: Callable[[], None] | None = None,
post_collect_hook: Callable[[TensorDictBase], None] | None = None,
compact_obs: bool = False,
Expand Down Expand Up @@ -139,6 +140,7 @@ def _main_async_collector(
trajs_per_batch=trajs_per_batch,
trajs_per_write=trajs_per_write,
auto_register_policy_transforms=auto_register_policy_transforms,
track_policy_version=track_policy_version,
pre_collect_hook=pre_collect_hook,
post_collect_hook=post_collect_hook,
compact_obs=compact_obs,
Expand Down
30 changes: 26 additions & 4 deletions torchrl/collectors/_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,22 @@ class Collector(BaseCollector):
RPCDataCollector -> MultiSyncCollector -> Collector.
Defaults to ``None``.
track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
the policy version.
Defaults to `False`.
A :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform is
installed on the environment, tagging every collected frame with the current version
under the ``"policy_version"`` key. The transform's version is bumped exactly once
per :meth:`update_policy_weights_` call — for multi-process collectors this happens
in each worker after the new weights have actually been applied, so per-frame
tagging tracks real weight updates rather than rollout iterations.

The recommended path is ``track_policy_version=True``: let the collector own the
transform. Passing a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion`
instance directly is reserved for advanced use cases that wire up a ``PolicyVersion``
**without** going through a collector (e.g. a hand-rolled rollout loop). Pre-creating
a transform and passing it to a collector is supported but discouraged because it
invites a divergence between the transform on the env and the one the collector
increments.

Defaults to ``False``.
compact_obs (bool, optional): if ``True``, the collector drops the
observation and state keys from the ``("next", ...)`` sub-tensordict
before stacking per-step data. Those keys are bit-for-bit identical
Expand Down Expand Up @@ -1335,6 +1347,16 @@ def update_policy_weights_(
policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
)

# Bump the local PolicyVersion transform (if track_policy_version is on).
# This is the canonical bump point for the leaf collector — it covers:
# - User calls collector.update_policy_weights_() on a single-process
# SyncDataCollector / Collector.
# - The receiver-side WeightSyncScheme cascade in a multi-process
# worker (which calls inner_collector.update_policy_weights_()
# after applying weights). MultiCollector does not inherit from
# Collector, so its update_policy_weights_ does NOT bump here.
self.increment_version()

def _maybe_fallback_update(
self,
policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
Expand Down
15 changes: 11 additions & 4 deletions torchrl/collectors/_single_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,17 @@ class AsyncCollector(MultiAsyncCollector):
Truncated keys can be set through ``env.add_truncated_keys``.
Defaults to ``False``.
track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
the policy version.
Defaults to `False`.
A :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform is
installed on the environment, tagging every collected frame with the current version
under the ``"policy_version"`` key. The transform's version is bumped exactly once
per :meth:`update_policy_weights_` call.

The recommended path is ``track_policy_version=True``: let the collector own the
transform. Passing a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion`
instance directly is reserved for advanced use cases that wire up a ``PolicyVersion``
**without** going through a collector (e.g. a hand-rolled rollout loop).

Defaults to ``False``.

"""

Expand Down
6 changes: 5 additions & 1 deletion torchrl/weight_update/_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,11 @@ def _background_receive_loop(self):
)

if weights is not None:
# Cascade weight update to sub-collectors if context supports it
# Cascade weight update to sub-collectors if context
# supports it. When the context is a leaf Collector,
# its update_policy_weights_ bumps the local
# PolicyVersion transform — no separate
# increment_version() call needed here.
model_id = self._model_id or "policy"
if self.context is not None and hasattr(
self.context, "update_policy_weights_"
Expand Down
6 changes: 5 additions & 1 deletion torchrl/weight_update/_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,11 @@ def _background_receive_loop(self):
)

if weights is not None:
# Cascade weight update to sub-collectors if context supports it
# Cascade weight update to sub-collectors if context supports it.
# When the context is a leaf Collector, its
# update_policy_weights_ also bumps the local
# PolicyVersion transform — so we don't need a
# separate increment_version() call here.
model_id = self._model_id or "policy"
if self.context is not None and hasattr(
self.context, "update_policy_weights_"
Expand Down
6 changes: 5 additions & 1 deletion torchrl/weight_update/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,11 @@ def _background_receive_loop(self):
self.model, self._receiver_shared_weights, inplace=True
)

# Cascade weight update to sub-collectors if context supports it
# Cascade weight update to sub-collectors if context supports it.
# When the context is a leaf Collector, its
# update_policy_weights_ also bumps the local
# PolicyVersion transform — so we don't need a separate
# increment_version() call here.
model_id = self._model_id or "policy"
if self.context is not None and hasattr(
self.context, "update_policy_weights_"
Expand Down
6 changes: 5 additions & 1 deletion torchrl/weight_update/weight_sync_schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,11 @@ def receive(self, timeout: float | None = None) -> TensorDictBase | None:
weights = result
model_id = self._model_id or "policy"

# Cascade weight update to sub-collectors if context supports it
# Cascade weight update to sub-collectors if context supports it.
# Note on policy_version tracking: the cascade eventually reaches a
# leaf Collector.update_policy_weights_, which bumps the local
# PolicyVersion transform on the worker's env. So there is no
# separate increment_version() call here.
if self.context is not None and hasattr(self.context, "update_policy_weights_"):
self.context.update_policy_weights_(
model_id=model_id, policy_or_weights=weights
Expand Down
Loading