Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions test/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3350,6 +3350,42 @@ def test_explicit_false_is_silent_and_does_not_wrap(self):
auto_register_policy_transforms=False,
)

def test_policy_factory_recurrent_auto_register(self):
"""policy_factory + auto_register_policy_transforms=True must:

- Instantiate the policy, walk it for InitTracker / TensorDictPrimer
requirements, and append both transforms to the env.
- Leave the transforms with live parents (not None) so their spec
transforms work — the bug this guards against was Compose
temporarily parenting the children to a throwaway container.
- Keep env.action_spec accessible.
"""
env = GymEnv(CARTPOLE_VERSIONED())
keys_before = set(env.full_observation_spec.keys(True, True))
assert "is_init" not in keys_before
assert "recurrent_state" not in keys_before

collector = SyncDataCollector(
env,
policy_factory=self._make_recurrent_policy,
frames_per_batch=10,
total_frames=10,
auto_register_policy_transforms=True,
)
try:
keys_after = set(collector.env.full_observation_spec.keys(True, True))
assert "is_init" in keys_after
assert "recurrent_state" in keys_after
# action_spec must remain reachable — Compose-parenting bug
# used to break this by leaving InitTracker.parent=None.
assert collector.env.action_spec is not None
for transform in collector.env.transform:
assert (
transform.parent is not None
), f"transform {type(transform).__name__} has parent=None"
finally:
collector.shutdown()


def weight_reset(m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
Expand Down
4 changes: 4 additions & 0 deletions torchrl/envs/transforms/_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,10 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
return observation_spec

def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
if input_spec["full_state_spec"] is None:
input_spec["full_state_spec"] = Composite(
shape=input_spec.shape, device=input_spec.device
)
new_state_spec = self.transform_observation_spec(input_spec["full_state_spec"])
for action_key in list(input_spec["full_action_spec"].keys(True, True)):
if action_key in new_state_spec.keys(True, True):
Expand Down
6 changes: 3 additions & 3 deletions torchrl/modules/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ def _maybe_append_env_transforms_from_module(
transforms = _compute_missing_env_transforms(env, module, init_key)
if not transforms:
return env
from torchrl.envs.transforms import Compose

return env.append_transform(Compose(*transforms))
for transform in transforms:
env = env.append_transform(transform)
return env


def _unpad_tensors(tensors, mask, as_nested: bool = True) -> torch.Tensor:
Expand Down
Loading