Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions test/llm/test_llm_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ExecuteToolsInOrder,
IncrementalTokenizer,
JSONCallParser,
PolicyVersion,
ToolCall,
ToolRegistry,
XMLBlockParser,
Expand Down Expand Up @@ -718,3 +719,31 @@ def test_empty_history_handling(self, tokenizer):
assert ("tokens", "prompt") in result.keys(True, True)
tokens = result.get(("tokens", "prompt"), as_list=True)
assert tokens[0].numel() > 0


class TestPolicyVersion:
    """Tests for the ``PolicyVersion`` transform."""

    def test_int_version_dtype_and_device(self):
        """Integer policy version must stay int64 and follow the tensordict device.

        Regression for a bug where ``version_type="int"`` was cast to float
        and dropped the device, producing CPU float tensors that mismatched
        the surrounding tensordict.
        """
        import torch

        transform = PolicyVersion(version_type="int")
        transform.version = 7
        expected = torch.full((4,), 7, dtype=torch.int64)

        # Tensordict built without an explicit device.
        td = TensorDict(batch_size=(4,))
        out = transform._step(td, td.copy())
        version = out.get("policy_version")
        assert version.dtype == torch.int64
        assert version.shape == (4,)
        assert torch.equal(version, expected)

        if torch.cuda.is_available():
            td_cuda = TensorDict(batch_size=(4,), device="cuda")
            out_cuda = transform._step(td_cuda, td_cuda.copy())
            version_cuda = out_cuda.get("policy_version")
            assert version_cuda.dtype == torch.int64
            assert version_cuda.device.type == "cuda"
            # Also pin shape and values on the device path, so a regression
            # that only corrupts the on-device tensor cannot slip through.
            assert version_cuda.shape == (4,)
            # torch.equal needs both operands on the same device.
            assert torch.equal(version_cuda, expected.to(version_cuda.device))
11 changes: 9 additions & 2 deletions torchrl/envs/llm/transforms/policy_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,15 @@ def _step(
if self.version_type in (str, "uuid"):
version = NonTensorData(self.version).expand(next_tensordict.shape)
elif self.version_type in (int, "int"):
# Cast to float for torch.full
version = torch.full(next_tensordict.shape, float(cast(int, self.version)))
device = next_tensordict.device
kwargs = {"dtype": torch.int64}
if device is not None:
kwargs["device"] = device
version = torch.full(
next_tensordict.shape,
cast(int, self.version),
**kwargs,
)
else:
raise ValueError(f"Invalid version type: {self.version_type}")

Expand Down
Loading