Feature: add VLA policy and registry for RL #186
Conversation
Pull request overview
Adds support for integrating a VLA (vision-language-action) model into the existing RL stack by introducing a new VLAPolicy, wiring raw (hierarchical) observations + chunked actions through collection/eval/training, and adding an entry-point based backend registry.
Changes:
- Introduce `VLAPolicy` and register it in the RL policy registry.
- Extend rollout collection/training (collector, buffer, GRPO, trainer eval) to support raw observations and action chunks (`action_chunk`/`chunk_step`).
- Add `vla_registry` to discover VLA backend factories via Python entry points.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| `embodichain/agents/rl/vla_registry.py` | New entry-point based backend registry + factory creation. |
| `embodichain/agents/rl/models/vla_policy.py` | New VLAPolicy wrapper for VLA inference + GRPO-compatible evaluate_actions. |
| `embodichain/agents/rl/models/__init__.py` | Registers vla_policy; extends build_policy to optionally pass env/policy_cfg. |
| `embodichain/agents/rl/collector/sync_collector.py` | Adds raw-observation storage and action-chunk caching + chunk_step. |
| `embodichain/agents/rl/buffer/standard_buffer.py` | Adds use_raw_obs and attaches raw_obs list to shared rollout. |
| `embodichain/agents/rl/buffer/utils.py` | Propagates chunk_step into transition view; adds _indices in minibatches. |
| `embodichain/agents/rl/algo/grpo.py` | Passes rollout + num_envs into evaluate_actions; preserves raw fields across clone. |
| `embodichain/agents/rl/utils/trainer.py` | Adjusts buffer sizing for chunked actions; updates eval loop for raw obs/chunks. |
| `embodichain/agents/rl/models/actor_only.py` | Updates evaluate_actions signature to accept extra kwargs. |
| `embodichain/agents/rl/models/actor_critic.py` | Updates evaluate_actions signature to accept extra kwargs. |
Pull request overview
Adds first-class support for VLA-backed policies in the RL stack by introducing a VLA policy wrapper, an entry-point-based backend registry, and rollout/collector plumbing for raw observations + chunked actions.
Changes:
- Introduces `VLAPolicy` and registers it in the RL policy registry.
- Adds `vla_registry` to discover/load VLA backend factories via Python entry points.
- Extends rollout collection/training/eval utilities to support `raw_obs`, `chunk_step`, and action-chunk caching.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| `embodichain/agents/rl/vla_registry.py` | Entry-point discovery + factory creation for pluggable VLA backends. |
| `embodichain/agents/rl/utils/trainer.py` | Trainer buffer allocation + eval loop updated for raw obs and action chunks. |
| `embodichain/agents/rl/models/vla_policy.py` | New VLA-backed policy wrapper implementing chunked action inference and proxy log-prob evaluation. |
| `embodichain/agents/rl/models/actor_only.py` | Broadens evaluate_actions signature to accept extra kwargs. |
| `embodichain/agents/rl/models/actor_critic.py` | Broadens evaluate_actions signature to accept extra kwargs. |
| `embodichain/agents/rl/models/__init__.py` | Registers vla_policy and adds env-dependent initialization path in build_policy. |
| `embodichain/agents/rl/collector/sync_collector.py` | Adds raw_obs storage + action chunk caching + chunk_step tracking. |
| `embodichain/agents/rl/buffer/utils.py` | Propagates chunk_step into transition view and adds minibatch _indices. |
| `embodichain/agents/rl/buffer/standard_buffer.py` | Adds use_raw_obs handling and allocates rollout.raw_obs. |
| `embodichain/agents/rl/algo/grpo.py` | Passes rollout context into evaluate_actions and preserves rollout attributes across clone. |
```python
raw_obs = getattr(rollout, "raw_obs", None)
chunk_step = getattr(rollout, "chunk_step", None)
rollout = rollout.clone()
if raw_obs is not None:
```
This part seems unnecessary?

Small mistake, I forgot to remove it.
```python
if self._is_full:
    raise RuntimeError("RolloutBuffer already contains a rollout.")
self._clear_dynamic_fields()
if self.use_raw_obs:
```
What is the purpose of adding raw_obs?

VLA requires full images as input instead of a flattened vector. Here I use raw_obs to separate the VLA policy from the other policies.
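As a rough illustration of that separation (class and method names here are hypothetical, not the PR's actual code), the rollout can carry the full per-step observation dicts in a plain Python list alongside the flattened tensor buffer, populated only when the policy opts in:

```python
class RawObsStore:
    """Sketch: keep full per-step observation dicts (e.g. camera images)
    beside the flattened obs tensor, only when the policy requests it."""

    def __init__(self, num_steps: int, use_raw_obs: bool):
        self.use_raw_obs = use_raw_obs
        # One slot per step; each slot holds the raw obs dict for all envs.
        self.raw_obs = [None] * num_steps if use_raw_obs else None

    def add(self, step_idx: int, raw_obs: dict) -> None:
        if self.use_raw_obs:
            self.raw_obs[step_idx] = raw_obs
```

Non-VLA policies leave `use_raw_obs` off and never pay the storage cost.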
```python
if use_raw_obs:
    if raw_obs_list is None:
        raise ValueError(
```
```python
self.action_chunk_size = self.action_horizon
self._env = None

def set_env(self, env) -> None:
```
Why add env to the policy?

After getting the raw obs from the env, the VLA packages them into batches that fit its input requirements. This step needs information from the env.
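A minimal sketch of that flow, with hypothetical names (the real VLAPolicy internals are not shown in this diff): the policy holds an env reference set via `set_env`, and uses env metadata when packaging raw observations into a backend batch.

```python
class VLAPolicySketch:
    """Hypothetical: holds an env reference so raw observations can be
    packaged into the batch layout the VLA backend expects."""

    def __init__(self):
        self._env = None

    def set_env(self, env) -> None:
        self._env = env

    def package_batch(self, raw_obs_list):
        if self._env is None:
            raise RuntimeError("call set_env() before inference")
        # Env metadata (here just num_envs) shapes the backend's input batch.
        return {"images": raw_obs_list, "num_envs": self._env.num_envs}
```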
```python
tensordict["value"] = torch.zeros(b, device=self.device, dtype=torch.float32)
return tensordict

def evaluate_actions(
```
num_envs also seems unused.
Pull request overview
Adds VLA (Vision-Language-Action) integration into the RL stack by introducing a VLAPolicy wrapper, extending rollout collection/training to support raw observations and chunked actions, and adding a registry for VLA backends via entry points.
Changes:
- Introduce `VLAPolicy` and `vla_registry` to load and run VLA backends inside RL policies.
- Extend rollout collection/evaluation to support `use_raw_obs` and chunked actions (`action_chunk` + `chunk_step`).
- Adjust minibatching/GRPO plumbing to pass rollout context (`raw_obs`, indices) into `evaluate_actions`.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| embodichain/agents/rl/vla_registry.py | Adds entry-point-based backend discovery + factory creation for VLA backends. |
| embodichain/agents/rl/models/vla_policy.py | New policy wrapper that runs a VLA backend and exposes RL Policy interface with action chunks + raw obs. |
| embodichain/agents/rl/utils/trainer.py | Updates buffer sizing and evaluation loop to handle raw obs + chunked actions. |
| embodichain/agents/rl/train.py | Passes env into build_policy for VLA policy initialization. |
| embodichain/agents/rl/models/`__init__.py` | Registers vla_policy and extends build_policy to support env/policy_cfg and VLA initialization. |
| embodichain/agents/rl/collector/sync_collector.py | Extends collector to populate raw_obs, generate/consume chunked actions, and track chunk_step. |
| embodichain/agents/rl/buffer/utils.py | Propagates chunk_step into transition view; adds _indices to minibatches for mapping back to rollout. |
| embodichain/agents/rl/buffer/standard_buffer.py | Allocates/clears raw_obs and chunk_step dynamic fields for VLA workflows. |
| embodichain/agents/rl/algo/grpo.py | Passes rollout into evaluate_actions to support VLA log-prob evaluation from raw obs. |
| embodichain/agents/rl/algo/ppo.py | Removes per-update rollout cloning (now relies on shared rollout lifecycle). |
| embodichain/agents/rl/models/actor_only.py | Allows evaluate_actions(..., **kwargs) to accept rollout context without breaking. |
| embodichain/agents/rl/models/actor_critic.py | Allows evaluate_actions(..., **kwargs) to accept rollout context without breaking. |
VLAPolicy._load_vla() hardcodes the backend name to "dexforce_vla" while also excluding a backend key from vla_cfg. If configuration is meant to choose among entry-point backends, this prevents it. Read backend from config (defaulting to dexforce_vla) and pass it into create_vla_backend.

Suggested change:
```diff
-backend = create_vla_backend(
-    "dexforce_vla",
+backend_name = str(self.vla_cfg.get("backend", "dexforce_vla"))
+backend = create_vla_backend(
+    backend_name,
```
```diff
     """Validate rollout layout expected by the collector."""
     obs_dim = rollout["obs"].shape[-1]
     expected_shapes = {
-        "obs": (self.env.num_envs, num_steps + 1, self.policy.obs_dim),
+        "obs": (self.env.num_envs, num_steps + 1, obs_dim),
         "action": (self.env.num_envs, num_steps + 1, self.policy.action_dim),
```
_validate_rollout() uses obs_dim = rollout["obs"].shape[-1] to form the expected shape, so the obs last-dimension check can never fail (making validation ineffective for this field). If the intent is to support policies with obs_dim == 0, consider validating against self.policy.obs_dim when it’s > 0 and skipping only that last-dim check otherwise.
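The reviewer's suggestion could look roughly like this (a hypothetical standalone helper, not the PR's code): validate the last dimension against the policy only when the policy declares a positive obs_dim, so raw-obs policies with obs_dim == 0 skip just that check.

```python
def validate_obs_shape(obs_shape, num_envs, num_steps, policy_obs_dim):
    """Check the (num_envs, num_steps + 1, obs_dim) layout; skip only the
    last-dim check for policies declaring obs_dim == 0 (e.g. raw-obs VLA)."""
    if tuple(obs_shape[:2]) != (num_envs, num_steps + 1):
        raise ValueError(f"bad obs layout: {tuple(obs_shape)}")
    if policy_obs_dim > 0 and obs_shape[-1] != policy_obs_dim:
        raise ValueError(
            f"obs last dim {obs_shape[-1]} != policy.obs_dim {policy_obs_dim}"
        )
```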
```diff
     indices = torch.randperm(total)
     for start in range(0, total, batch_size):
-        yield rollout[indices[start : start + batch_size]]
+        batch_indices = indices[start : start + batch_size]
+        batch = rollout[batch_indices].clone()
+        batch["_indices"] = batch_indices
```
iterate_minibatches() now builds indices on the default (CPU) device. If rollout is on CUDA, indexing rollout[batch_indices] will fail because the index tensor must be on the same device. Create indices on rollout.device (or the passed device) and keep _indices consistent with that choice.
The broad except Exception fallback sets obs_dim = max(1, obs_dim), which can allocate an obs buffer with the wrong last dimension (e.g., VLAPolicy uses obs_dim=0 so this becomes 1). This tends to surface later as a confusing tensor assignment shape error. Prefer failing fast (or at least logging the exception) so buffer sizing errors are explicit.

Suggested change:
```diff
-except Exception:
-    obs_dim = max(1, obs_dim)
+except Exception as exc:
+    raise RuntimeError(
+        "Failed to infer obs_dim from raw observations during Trainer "
+        "initialization. Check env.reset(), dict_to_tensordict, and "
+        "flatten_dict_observation."
+    ) from exc
```
```python
    factory = ep.load()
    name = str(ep.name).lower()
    if name not in _VLA_BACKENDS:
        _VLA_BACKENDS[name] = factory
except Exception:
```
Per-entry-point failures from ep.load() are silently swallowed, which makes plugin misconfiguration very hard to diagnose. Consider catching narrower exceptions and logging a warning/debug message with the entry point name/module when a backend fails to load.
```python
if use_action_chunk:
    rollout.chunk_step[:, step_idx] = effective_step_in_chunk
    # Invalidate cached_chunk on any env reset to avoid using old chunk for new episode
    if (terminated | truncated).any():
        cached_chunk = None
```
When using action chunks, the collector invalidates the entire cached_chunk whenever any environment terminates/truncates. In a vectorized env, this forces all still-running envs to throw away their remaining chunk and recompute a new one, which can be very expensive for VLA backends and changes the intended per-env chunk semantics. Consider maintaining per-env cached chunks (and invalidating only the done env indices) to avoid unnecessary recomputation.
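Per-env invalidation could be sketched like this (a hypothetical class; the collector in this PR keeps a single shared cached_chunk): only the envs whose episodes ended lose their cached chunk, while the others keep consuming theirs.

```python
class PerEnvChunkCache:
    """Cache one action chunk per env; drop only the chunks of envs that
    just terminated/truncated, keeping the others' remaining actions."""

    def __init__(self, num_envs: int):
        self.chunks = [None] * num_envs  # per-env cached action chunk
        self.steps = [0] * num_envs      # position inside each env's chunk

    def invalidate_done(self, done_flags) -> None:
        for i, done in enumerate(done_flags):
            if done:
                self.chunks[i] = None
                self.steps[i] = 0
```

This avoids recomputing chunks for still-running envs, which matters when each chunk requires an expensive VLA forward pass.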
```diff
-            {"obs": rollout["obs"][:, step_idx]},
-            batch_size=[rollout.batch_size[0]],
+        use_raw_obs = getattr(self.policy, "use_raw_obs", False)
+        raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None
```

New behaviors (raw-observation rollouts via raw_obs and chunked-action support via chunk_step/action_chunk) are introduced here but aren't covered by existing RL tests. Adding a focused unit test with a dummy policy exercising use_raw_obs=True and use_action_chunk=True would help prevent regressions (e.g., raw_obs population and chunk_step alignment).

Suggested change:
```diff
 raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None
+if use_raw_obs and raw_obs_list is None and isinstance(rollout, TensorDict):
+    # Fallback to key-based access for raw observations when using a TensorDict.
+    # This allows 'raw_obs' to be provided either as an attribute or a field key.
+    if "raw_obs" in rollout.keys():
+        raw_obs_list = rollout.get("raw_obs")
```
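A focused regression test along the lines the reviewer suggests might look like this; DummyChunkPolicy and chunk_step_sequence are hypothetical stand-ins, since the collector's real test fixtures are not part of this diff.

```python
class DummyChunkPolicy:
    """Hypothetical stand-in policy enabling the new collector paths."""
    use_raw_obs = True
    use_action_chunk = True
    action_chunk_size = 4

def chunk_step_sequence(num_steps: int, chunk_size: int):
    """Expected chunk_step values when a fresh chunk starts at step 0
    and no env resets occur mid-rollout."""
    return [t % chunk_size for t in range(num_steps)]

def test_chunk_step_alignment():
    seq = chunk_step_sequence(6, DummyChunkPolicy.action_chunk_size)
    # The chunk index wraps after action_chunk_size steps.
    assert seq == [0, 1, 2, 3, 0, 1]
```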
```python
except Exception:
    pass
```
The outer except Exception: pass suppresses all discovery errors (including programming errors) and leaves the registry empty with no signal to the caller. At minimum, log the exception; ideally catch only expected errors (e.g., missing entry point group) and let unexpected ones surface.
Calling the private policy._load_vla() from build_policy() forces heavyweight backend/model loading at construction time and couples the factory to a private method. Prefer lazy-loading in VLAPolicy (or exposing a public init/load hook) so config parsing and policy construction stay lightweight and copying/cloning policies doesn't accidentally duplicate backend state.

Suggested change:
```diff
-    "so that set_env and _load_vla can be called before use."
-)
-policy.set_env(env)
-policy._load_vla()
+    "so that set_env can be called before use."
+)
+policy.set_env(env)
```
```python
_ENTRY_POINTS_DISCOVERED = True
try:
    eps = entry_points(group="embodichain.vla_backends")
```
_ENTRY_POINTS_DISCOVERED is set to True before attempting discovery. If entry_points() fails (or the first discovery attempt is partial), subsequent calls will never retry and the registry can remain empty. Consider only marking discovery complete after a successful pass, or allowing retries on failure.
Description
- `vla_policy` wrapper to integrate VLA model into RL policies.
- `vla_registry` to discover VLA-related factories via entry points.

Type of change
Checklist
- Ran the `black .` command to format the code base.