-
Notifications
You must be signed in to change notification settings - Fork 8
Feature: add VLA policy and registry for RL #186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8abab90
02d2c02
d5a0684
662a53d
7fe96d0
7b2287f
af373f2
1b28105
79e5840
f2cd81a
10b2b7b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,12 +40,14 @@ def __init__( | |
| obs_dim: int, | ||
| action_dim: int, | ||
| device: torch.device, | ||
| use_raw_obs: bool = False, | ||
| ) -> None: | ||
| self.num_envs = num_envs | ||
| self.rollout_len = rollout_len | ||
| self.obs_dim = obs_dim | ||
| self.action_dim = action_dim | ||
| self.device = device | ||
| self.use_raw_obs = use_raw_obs | ||
| self._rollout = self._allocate_rollout() | ||
| self._is_full = False | ||
|
|
||
|
|
@@ -58,6 +60,8 @@ def start_rollout(self) -> TensorDict: | |
| if self._is_full: | ||
| raise RuntimeError("RolloutBuffer already contains a rollout.") | ||
| self._clear_dynamic_fields() | ||
| if self.use_raw_obs: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is the purpose of adding
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. VLA requires inputs of full images instead of flattened vectors. Here I use raw_obs to distinguish the VLA policy from other policies. |
||
| self._rollout.raw_obs = [None] * (self.rollout_len + 1) | ||
| return self._rollout | ||
|
|
||
| def add(self, rollout: TensorDict) -> None: | ||
|
|
@@ -97,7 +101,7 @@ def is_full(self) -> bool: | |
|
|
||
| def _allocate_rollout(self) -> TensorDict: | ||
| """Preallocate rollout storage with uniform `[num_envs, time + 1]` shape.""" | ||
| return TensorDict( | ||
| td = TensorDict( | ||
| { | ||
| "obs": torch.empty( | ||
| self.num_envs, | ||
|
|
@@ -153,12 +157,17 @@ def _allocate_rollout(self) -> TensorDict: | |
| batch_size=[self.num_envs, self.rollout_len + 1], | ||
| device=self.device, | ||
| ) | ||
| return td | ||
|
|
||
| def _clear_dynamic_fields(self) -> None: | ||
| """Drop algorithm-added fields before reusing the shared rollout.""" | ||
| for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"): | ||
| if key in self._rollout.keys(): | ||
| del self._rollout[key] | ||
| if self.use_raw_obs and hasattr(self._rollout, "raw_obs"): | ||
| delattr(self._rollout, "raw_obs") | ||
yangchen73 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if hasattr(self._rollout, "chunk_step"): | ||
| delattr(self._rollout, "chunk_step") | ||
| self._reset_padding_slot() | ||
|
|
||
| def _reset_padding_slot(self) -> None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,9 @@ def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict: | |
| if key in rollout.keys(): | ||
| td[key] = rollout[key][:, :-1] | ||
|
|
||
| if hasattr(rollout, "chunk_step") and rollout.chunk_step is not None: | ||
| td["chunk_step"] = rollout.chunk_step | ||
|
|
||
| if flatten: | ||
| return td.reshape(num_envs * time_dim) | ||
| return td | ||
|
|
@@ -72,6 +75,9 @@ def iterate_minibatches( | |
| ) -> Iterator[TensorDict]: | ||
| """Yield shuffled minibatches from a flattened rollout.""" | ||
| total = rollout.batch_size[0] | ||
| indices = torch.randperm(total, device=device) | ||
| indices = torch.randperm(total) | ||
| for start in range(0, total, batch_size): | ||
| yield rollout[indices[start : start + batch_size]] | ||
| batch_indices = indices[start : start + batch_size] | ||
| batch = rollout[batch_indices].clone() | ||
| batch["_indices"] = batch_indices | ||
|
Comment on lines
+78
to
+82
|
||
| yield batch | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |||||||||||||||
| from tensordict import TensorDict | ||||||||||||||||
|
|
||||||||||||||||
| from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation | ||||||||||||||||
| from embodichain.utils import logger | ||||||||||||||||
| from .base import BaseCollector | ||||||||||||||||
|
|
||||||||||||||||
| __all__ = ["SyncCollector"] | ||||||||||||||||
|
|
@@ -56,32 +57,136 @@ def collect( | |||||||||||||||
| self.obs_td = self._reset_env() | ||||||||||||||||
|
|
||||||||||||||||
| if rollout is None: | ||||||||||||||||
| raise ValueError( | ||||||||||||||||
| "SyncCollector.collect() requires a preallocated rollout TensorDict." | ||||||||||||||||
| logger.log_error( | ||||||||||||||||
| "SyncCollector.collect() requires a preallocated rollout TensorDict.", | ||||||||||||||||
| ValueError, | ||||||||||||||||
| ) | ||||||||||||||||
| if tuple(rollout.batch_size) != (self.env.num_envs, num_steps + 1): | ||||||||||||||||
| raise ValueError( | ||||||||||||||||
| logger.log_error( | ||||||||||||||||
| "Preallocated rollout batch size mismatch: " | ||||||||||||||||
| f"expected ({self.env.num_envs}, {num_steps + 1}), got {tuple(rollout.batch_size)}." | ||||||||||||||||
| f"expected ({self.env.num_envs}, {num_steps + 1}), got {tuple(rollout.batch_size)}.", | ||||||||||||||||
| ValueError, | ||||||||||||||||
| ) | ||||||||||||||||
| self._validate_rollout(rollout, num_steps) | ||||||||||||||||
| if self._supports_shared_rollout: | ||||||||||||||||
| self.env.set_rollout_buffer(rollout) | ||||||||||||||||
|
|
||||||||||||||||
| initial_obs = flatten_dict_observation(self.obs_td) | ||||||||||||||||
| rollout["obs"][:, 0] = initial_obs | ||||||||||||||||
| for step_idx in range(num_steps): | ||||||||||||||||
| step_td = TensorDict( | ||||||||||||||||
| {"obs": rollout["obs"][:, step_idx]}, | ||||||||||||||||
| batch_size=[rollout.batch_size[0]], | ||||||||||||||||
| use_raw_obs = getattr(self.policy, "use_raw_obs", False) | ||||||||||||||||
| raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None | ||||||||||||||||
|
||||||||||||||||
| raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None | |
| raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None | |
| if use_raw_obs and raw_obs_list is None and isinstance(rollout, TensorDict): | |
| # Fallback to key-based access for raw observations when using a TensorDict. | |
| # This allows 'raw_obs' to be provided either as an attribute or a field key. | |
| if "raw_obs" in rollout.keys(): | |
| raw_obs_list = rollout.get("raw_obs") |
yangchen73 marked this conversation as resolved.
Show resolved
Hide resolved
yangchen73 marked this conversation as resolved.
Show resolved
Hide resolved
yangchen73 marked this conversation as resolved.
Show resolved
Hide resolved
Copilot
AI
Mar 30, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When using action chunks, the collector invalidates the entire cached_chunk whenever any environment terminates/truncates. In a vectorized env, this forces all still-running envs to throw away their remaining chunk and recompute a new one, which can be very expensive for VLA backends and changes the intended per-env chunk semantics. Consider maintaining per-env cached chunks (and invalidating only the done env indices) to avoid unnecessary recomputation.
yangchen73 marked this conversation as resolved.
Show resolved
Hide resolved
Copilot
AI
Mar 30, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_validate_rollout() uses obs_dim = rollout["obs"].shape[-1] to form the expected shape, so the obs last-dimension check can never fail (making validation ineffective for this field). If the intent is to support policies with obs_dim == 0, consider validating against self.policy.obs_dim when it’s > 0 and skipping only that last-dim check otherwise.
yangchen73 marked this conversation as resolved.
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |||||||||||||||
| from __future__ import annotations | ||||||||||||||||
|
|
||||||||||||||||
| import inspect | ||||||||||||||||
| from typing import Dict, Type | ||||||||||||||||
| from typing import Any, Dict, Optional, Type | ||||||||||||||||
|
|
||||||||||||||||
| from gymnasium import spaces | ||||||||||||||||
| import torch | ||||||||||||||||
|
|
@@ -26,6 +26,7 @@ | |||||||||||||||
| from .actor_only import ActorOnly | ||||||||||||||||
| from .policy import Policy | ||||||||||||||||
| from .mlp import MLP | ||||||||||||||||
| from .vla_policy import VLAPolicy | ||||||||||||||||
|
|
||||||||||||||||
| # In-module policy registry | ||||||||||||||||
| _POLICY_REGISTRY: Dict[str, Type[Policy]] = {} | ||||||||||||||||
|
|
@@ -63,13 +64,16 @@ def build_policy( | |||||||||||||||
| device: torch.device, | ||||||||||||||||
| actor: torch.nn.Module | None = None, | ||||||||||||||||
| critic: torch.nn.Module | None = None, | ||||||||||||||||
| env: Optional[Any] = None, | ||||||||||||||||
| ) -> Policy: | ||||||||||||||||
| """Build a policy from config using spaces for extensibility. | ||||||||||||||||
|
|
||||||||||||||||
| Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`, while | ||||||||||||||||
| custom policies may accept richer `obs_space` / `action_space` inputs. | ||||||||||||||||
| For vla_policy, pass env to enable set_env and _load_vla initialization. | ||||||||||||||||
| """ | ||||||||||||||||
| name = policy_block["name"].lower() | ||||||||||||||||
|
|
||||||||||||||||
| if name not in _POLICY_REGISTRY: | ||||||||||||||||
| available = ", ".join(get_registered_policy_names()) | ||||||||||||||||
| raise ValueError( | ||||||||||||||||
|
|
@@ -119,7 +123,18 @@ def build_policy( | |||||||||||||||
| build_kwargs["actor"] = actor | ||||||||||||||||
| if "critic" in init_params and critic is not None: | ||||||||||||||||
| build_kwargs["critic"] = critic | ||||||||||||||||
| return policy_cls(**build_kwargs) | ||||||||||||||||
| if "policy_cfg" in init_params: | ||||||||||||||||
| build_kwargs["policy_cfg"] = policy_block | ||||||||||||||||
| policy = policy_cls(**build_kwargs) | ||||||||||||||||
| if name == "vla_policy": | ||||||||||||||||
| if env is None: | ||||||||||||||||
| raise ValueError( | ||||||||||||||||
| "VLAPolicy requires an 'env' argument to be passed to build_policy " | ||||||||||||||||
| "so that set_env and _load_vla can be called before use." | ||||||||||||||||
| ) | ||||||||||||||||
yangchen73 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||
| policy.set_env(env) | ||||||||||||||||
| policy._load_vla() | ||||||||||||||||
yangchen73 marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+133
to
+136
|
||||||||||||||||
| "so that set_env and _load_vla can be called before use." | |
| ) | |
| policy.set_env(env) | |
| policy._load_vla() | |
| "so that set_env can be called before use." | |
| ) | |
| policy.set_env(env) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part seems useless?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small mistake — I forgot to remove it.