Skip to content
13 changes: 10 additions & 3 deletions embodichain/agents/rl/algo/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ def _compute_step_group_advantages(
return advantages.view(n_envs, t_steps) * seq_mask

def update(self, rollout: TensorDict) -> Dict[str, float]:
rollout = rollout.clone()
raw_obs = getattr(rollout, "raw_obs", None)
chunk_step = getattr(rollout, "chunk_step", None)
if raw_obs is not None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part seems useless?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My mistake — I forgot to remove it.

rollout.raw_obs = raw_obs
if chunk_step is not None:
rollout.chunk_step = chunk_step
num_envs = rollout.batch_size[0]
if num_envs % self.cfg.group_size != 0:
raise ValueError(
Expand Down Expand Up @@ -147,7 +152,7 @@ def update(self, rollout: TensorDict) -> Dict[str, float]:
advantages = batch["advantage"].detach()
seq_mask_batch = batch["seq_mask"].float()

eval_batch = self.policy.evaluate_actions(batch)
eval_batch = self.policy.evaluate_actions(batch, rollout=rollout)
logprobs = eval_batch["sample_log_prob"]
entropy = eval_batch["entropy"]
ratio = (logprobs - old_logprobs).exp()
Expand All @@ -166,7 +171,9 @@ def update(self, rollout: TensorDict) -> Dict[str, float]:

if self.ref_policy is not None:
with torch.no_grad():
ref_batch = self.ref_policy.evaluate_actions(batch)
ref_batch = self.ref_policy.evaluate_actions(
batch, rollout=rollout
)
ref_logprobs = ref_batch["sample_log_prob"]
log_ref_over_pi = ref_logprobs - logprobs
kl_per = torch.exp(log_ref_over_pi) - log_ref_over_pi - 1.0
Expand Down
1 change: 0 additions & 1 deletion embodichain/agents/rl/algo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __init__(self, cfg: PPOCfg, policy):

def update(self, rollout: TensorDict) -> Dict[str, float]:
"""Update the policy using a collected rollout."""
rollout = rollout.clone()
compute_gae(rollout, gamma=self.cfg.gamma, gae_lambda=self.cfg.gae_lambda)
flat_rollout = transition_view(rollout, flatten=True)

Expand Down
11 changes: 10 additions & 1 deletion embodichain/agents/rl/buffer/standard_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ def __init__(
obs_dim: int,
action_dim: int,
device: torch.device,
use_raw_obs: bool = False,
) -> None:
self.num_envs = num_envs
self.rollout_len = rollout_len
self.obs_dim = obs_dim
self.action_dim = action_dim
self.device = device
self.use_raw_obs = use_raw_obs
self._rollout = self._allocate_rollout()
self._is_full = False

Expand All @@ -58,6 +60,8 @@ def start_rollout(self) -> TensorDict:
if self._is_full:
raise RuntimeError("RolloutBuffer already contains a rollout.")
self._clear_dynamic_fields()
if self.use_raw_obs:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of adding raw_obs?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The VLA policy requires full-image inputs rather than a flattened vector. Here I use raw_obs to distinguish the VLA policy from the other policies.

self._rollout.raw_obs = [None] * (self.rollout_len + 1)
return self._rollout

def add(self, rollout: TensorDict) -> None:
Expand Down Expand Up @@ -97,7 +101,7 @@ def is_full(self) -> bool:

def _allocate_rollout(self) -> TensorDict:
"""Preallocate rollout storage with uniform `[num_envs, time + 1]` shape."""
return TensorDict(
td = TensorDict(
{
"obs": torch.empty(
self.num_envs,
Expand Down Expand Up @@ -153,12 +157,17 @@ def _allocate_rollout(self) -> TensorDict:
batch_size=[self.num_envs, self.rollout_len + 1],
device=self.device,
)
return td

def _clear_dynamic_fields(self) -> None:
    """Drop algorithm-added fields before reusing the shared rollout.

    The buffer hands out one preallocated rollout TensorDict; learners
    attach extra fields to it during `update()`. Clearing them here
    prevents a stale value from a previous rollout leaking into the next
    collection pass.
    """
    # Tensor keys written by the learner (advantage/return estimates,
    # sequence masks, entropy) — removed via TensorDict key deletion.
    for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"):
        if key in self._rollout.keys():
            del self._rollout[key]
    # `raw_obs` and `chunk_step` are stored as plain Python attributes on
    # the TensorDict (not tensor keys), so they are removed with delattr
    # rather than `del self._rollout[key]`.
    # NOTE(review): assumes tensordict allows setting/deleting arbitrary
    # non-key attributes on instances — confirm against the tensordict
    # version in use.
    if self.use_raw_obs and hasattr(self._rollout, "raw_obs"):
        delattr(self._rollout, "raw_obs")
    if hasattr(self._rollout, "chunk_step"):
        delattr(self._rollout, "chunk_step")
    # Re-zero the trailing bootstrap slot (time index rollout_len).
    self._reset_padding_slot()

def _reset_padding_slot(self) -> None:
Expand Down
10 changes: 8 additions & 2 deletions embodichain/agents/rl/buffer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict:
if key in rollout.keys():
td[key] = rollout[key][:, :-1]

if hasattr(rollout, "chunk_step") and rollout.chunk_step is not None:
td["chunk_step"] = rollout.chunk_step

if flatten:
return td.reshape(num_envs * time_dim)
return td
Expand All @@ -72,6 +75,9 @@ def iterate_minibatches(
) -> Iterator[TensorDict]:
"""Yield shuffled minibatches from a flattened rollout."""
total = rollout.batch_size[0]
indices = torch.randperm(total, device=device)
indices = torch.randperm(total)
for start in range(0, total, batch_size):
yield rollout[indices[start : start + batch_size]]
batch_indices = indices[start : start + batch_size]
batch = rollout[batch_indices].clone()
batch["_indices"] = batch_indices
Comment on lines +78 to +82
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iterate_minibatches() now builds indices on the default (CPU) device. If rollout is on CUDA, indexing rollout[batch_indices] will fail because the index tensor must be on the same device. Create indices on rollout.device (or the passed device) and keep _indices consistent with that choice.

Copilot uses AI. Check for mistakes.
yield batch
150 changes: 133 additions & 17 deletions embodichain/agents/rl/collector/sync_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tensordict import TensorDict

from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation
from embodichain.utils import logger
from .base import BaseCollector

__all__ = ["SyncCollector"]
Expand Down Expand Up @@ -56,32 +57,136 @@ def collect(
self.obs_td = self._reset_env()

if rollout is None:
raise ValueError(
"SyncCollector.collect() requires a preallocated rollout TensorDict."
logger.log_error(
"SyncCollector.collect() requires a preallocated rollout TensorDict.",
ValueError,
)
if tuple(rollout.batch_size) != (self.env.num_envs, num_steps + 1):
raise ValueError(
logger.log_error(
"Preallocated rollout batch size mismatch: "
f"expected ({self.env.num_envs}, {num_steps + 1}), got {tuple(rollout.batch_size)}."
f"expected ({self.env.num_envs}, {num_steps + 1}), got {tuple(rollout.batch_size)}.",
ValueError,
)
self._validate_rollout(rollout, num_steps)
if self._supports_shared_rollout:
self.env.set_rollout_buffer(rollout)

initial_obs = flatten_dict_observation(self.obs_td)
rollout["obs"][:, 0] = initial_obs
for step_idx in range(num_steps):
step_td = TensorDict(
{"obs": rollout["obs"][:, step_idx]},
batch_size=[rollout.batch_size[0]],
use_raw_obs = getattr(self.policy, "use_raw_obs", False)
raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New behaviors (raw-observation rollouts via raw_obs and chunked-action support via chunk_step/action_chunk) are introduced here but aren’t covered by existing RL tests. Adding a focused unit test with a dummy policy exercising use_raw_obs=True and use_action_chunk=True would help prevent regressions (e.g., raw_obs population and chunk_step alignment).

Suggested change
raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None
raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None
if use_raw_obs and raw_obs_list is None and isinstance(rollout, TensorDict):
# Fallback to key-based access for raw observations when using a TensorDict.
# This allows 'raw_obs' to be provided either as an attribute or a field key.
if "raw_obs" in rollout.keys():
raw_obs_list = rollout.get("raw_obs")

Copilot uses AI. Check for mistakes.

if use_raw_obs:
if raw_obs_list is None:
logger.log_error(
"Policy requires raw observations, "
"but the provided rollout TensorDict has no 'raw_obs' buffer. "
"Create the rollout via RolloutBuffer or "
"start_rollout so that 'raw_obs' is allocated.",
ValueError,
)
try:
raw_obs_len = len(raw_obs_list)
except TypeError:
logger.log_error(
"Rollout field 'raw_obs' must be an indexable sequence of length "
f"{num_steps + 1} when policy.use_raw_obs=True.",
ValueError,
)
expected_len = num_steps + 1
if raw_obs_len != expected_len:
logger.log_error(
"Rollout 'raw_obs' length mismatch: "
f"expected {expected_len} (num_steps + 1), got {raw_obs_len}. "
"Ensure the rollout was created with use_raw_obs=True and "
"its time dimension matches the requested num_steps.",
ValueError,
)

action_chunk_size = getattr(self.policy, "action_chunk_size", 0)
use_action_chunk = (
getattr(self.policy, "use_action_chunk", False) and action_chunk_size > 0
)
cached_chunk = None

if use_action_chunk:
rollout.chunk_step = torch.zeros(
self.env.num_envs,
num_steps,
dtype=torch.long,
device=self.device,
)
step_td = self.policy.get_action(step_td)

if use_raw_obs and raw_obs_list is not None:
raw_obs_list[0] = self.obs_td
rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td)
else:
rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td)

for step_idx in range(num_steps):
step_in_chunk = step_idx % action_chunk_size if use_action_chunk else 0

# At chunk boundary, or cached invalidated by env reset, we need a new chunk
need_new_chunk = use_action_chunk and (
step_in_chunk == 0 or cached_chunk is None
)

if need_new_chunk:
if use_raw_obs and raw_obs_list is not None:
step_td = TensorDict(
{"obs": raw_obs_list[step_idx]},
batch_size=[rollout.batch_size[0]],
device=self.device,
)
else:
step_td = TensorDict(
{"obs": rollout["obs"][:, step_idx]},
batch_size=[rollout.batch_size[0]],
device=self.device,
)
step_td = self.policy.get_action(step_td)
cached_chunk = step_td["action_chunk"]
action = step_td["action"]
effective_step_in_chunk = 0
elif use_action_chunk and cached_chunk is not None:
action = cached_chunk[:, step_in_chunk]
effective_step_in_chunk = step_in_chunk
step_td = TensorDict(
{
"action": action,
"sample_log_prob": torch.zeros(
action.shape[0], device=self.device, dtype=torch.float32
),
"value": torch.zeros(
action.shape[0], device=self.device, dtype=torch.float32
),
},
batch_size=[rollout.batch_size[0]],
device=self.device,
)
else:
if use_raw_obs and raw_obs_list is not None:
step_td = TensorDict(
{"obs": raw_obs_list[step_idx]},
batch_size=[rollout.batch_size[0]],
device=self.device,
)
else:
step_td = TensorDict(
{"obs": rollout["obs"][:, step_idx]},
batch_size=[rollout.batch_size[0]],
device=self.device,
)
step_td = self.policy.get_action(step_td)
action = step_td["action"]

next_obs, reward, terminated, truncated, env_info = self.env.step(
self._to_action_dict(step_td["action"])
self._to_action_dict(action)
)
next_obs_td = dict_to_tensordict(next_obs, self.device)
if use_action_chunk:
rollout.chunk_step[:, step_idx] = effective_step_in_chunk
# Invalidate cached_chunk on any env reset to avoid using old chunk for new episode
if (terminated | truncated).any():
cached_chunk = None
Comment on lines +185 to +189
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When using action chunks, the collector invalidates the entire cached_chunk whenever any environment terminates/truncates. In a vectorized env, this forces all still-running envs to throw away their remaining chunk and recompute a new one, which can be very expensive for VLA backends and changes the intended per-env chunk semantics. Consider maintaining per-env cached chunks (and invalidating only the done env indices) to avoid unnecessary recomputation.

Copilot uses AI. Check for mistakes.
self._write_step(
rollout=rollout,
step_idx=step_idx,
Expand All @@ -95,7 +200,11 @@ def collect(
terminated=terminated,
truncated=truncated,
)
rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td)
if use_raw_obs and raw_obs_list is not None:
raw_obs_list[step_idx + 1] = next_obs_td
rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td)
else:
rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td)

if on_step_callback is not None:
on_step_callback(rollout[:, step_idx], env_info)
Expand All @@ -107,7 +216,12 @@ def collect(

def _attach_final_value(self, rollout: TensorDict) -> None:
"""Populate the bootstrap value for the final observed state."""
final_obs = rollout["obs"][:, -1]
use_raw_obs = getattr(self.policy, "use_raw_obs", False)
raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None
if use_raw_obs and raw_obs_list is not None:
final_obs = raw_obs_list[-1]
else:
final_obs = rollout["obs"][:, -1]
last_next_td = TensorDict(
{"obs": final_obs},
batch_size=[rollout.batch_size[0]],
Expand Down Expand Up @@ -155,8 +269,9 @@ def _write_env_step(

def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None:
"""Validate rollout layout expected by the collector."""
obs_dim = rollout["obs"].shape[-1]
expected_shapes = {
"obs": (self.env.num_envs, num_steps + 1, self.policy.obs_dim),
"obs": (self.env.num_envs, num_steps + 1, obs_dim),
"action": (self.env.num_envs, num_steps + 1, self.policy.action_dim),
Comment on lines 271 to 275
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_validate_rollout() uses obs_dim = rollout["obs"].shape[-1] to form the expected shape, so the obs last-dimension check can never fail (making validation ineffective for this field). If the intent is to support policies with obs_dim == 0, consider validating against self.policy.obs_dim when it’s > 0 and skipping only that last-dim check otherwise.

Copilot uses AI. Check for mistakes.
"sample_log_prob": (self.env.num_envs, num_steps + 1),
"value": (self.env.num_envs, num_steps + 1),
Expand All @@ -168,7 +283,8 @@ def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None:
for key, expected_shape in expected_shapes.items():
actual_shape = tuple(rollout[key].shape)
if actual_shape != expected_shape:
raise ValueError(
logger.log_error(
f"Preallocated rollout field '{key}' shape mismatch: "
f"expected {expected_shape}, got {actual_shape}."
f"expected {expected_shape}, got {actual_shape}.",
ValueError,
)
21 changes: 19 additions & 2 deletions embodichain/agents/rl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import inspect
from typing import Dict, Type
from typing import Any, Dict, Optional, Type

from gymnasium import spaces
import torch
Expand All @@ -26,6 +26,7 @@
from .actor_only import ActorOnly
from .policy import Policy
from .mlp import MLP
from .vla_policy import VLAPolicy

# In-module policy registry
_POLICY_REGISTRY: Dict[str, Type[Policy]] = {}
Expand Down Expand Up @@ -63,13 +64,16 @@ def build_policy(
device: torch.device,
actor: torch.nn.Module | None = None,
critic: torch.nn.Module | None = None,
env: Optional[Any] = None,
) -> Policy:
"""Build a policy from config using spaces for extensibility.

Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`, while
custom policies may accept richer `obs_space` / `action_space` inputs.
For vla_policy, pass env to enable set_env and _load_vla initialization.
"""
name = policy_block["name"].lower()

if name not in _POLICY_REGISTRY:
available = ", ".join(get_registered_policy_names())
raise ValueError(
Expand Down Expand Up @@ -119,7 +123,18 @@ def build_policy(
build_kwargs["actor"] = actor
if "critic" in init_params and critic is not None:
build_kwargs["critic"] = critic
return policy_cls(**build_kwargs)
if "policy_cfg" in init_params:
build_kwargs["policy_cfg"] = policy_block
policy = policy_cls(**build_kwargs)
if name == "vla_policy":
if env is None:
raise ValueError(
"VLAPolicy requires an 'env' argument to be passed to build_policy "
"so that set_env and _load_vla can be called before use."
)
policy.set_env(env)
policy._load_vla()
Comment on lines +133 to +136
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling the private policy._load_vla() from build_policy() forces heavyweight backend/model loading at construction time and couples the factory to a private method. Prefer lazy-loading in VLAPolicy (or exposing a public init/load hook) so config parsing and policy construction stay lightweight and copying/cloning policies doesn’t accidentally duplicate backend state.

Suggested change
"so that set_env and _load_vla can be called before use."
)
policy.set_env(env)
policy._load_vla()
"so that set_env can be called before use."
)
policy.set_env(env)

Copilot uses AI. Check for mistakes.
return policy


def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP:
Expand All @@ -143,10 +158,12 @@ def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP:
# default registrations
register_policy("actor_critic", ActorCritic)
register_policy("actor_only", ActorOnly)
register_policy("vla_policy", VLAPolicy)

__all__ = [
"ActorCritic",
"ActorOnly",
"VLAPolicy",
"register_policy",
"get_registered_policy_names",
"build_policy",
Expand Down
2 changes: 1 addition & 1 deletion embodichain/agents/rl/models/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict:
tensordict["value"] = self.critic(tensordict["obs"]).squeeze(-1)
return tensordict

def evaluate_actions(self, tensordict: TensorDict) -> TensorDict:
def evaluate_actions(self, tensordict: TensorDict, **kwargs) -> TensorDict:
obs = tensordict["obs"]
action = tensordict["action"]
dist = self._distribution(obs)
Expand Down
Loading
Loading