
Feature: add VLA policy and registry for RL #186

Open
yangchen73 wants to merge 11 commits into main from yc/vla_rl

Conversation

@yangchen73
Collaborator

Description

  • Add vla_policy wrapper to integrate VLA model into RL policies.
  • Extend RL policy and training loop to support raw observations and chunked actions for VLA.
  • Introduce vla_registry to discover VLA-related factories via entry points.

Type of change

  • Enhancement (non-breaking change which improves an existing functionality)
  • New feature (non-breaking change which adds functionality)

Checklist

  • I have run the black . command to format the code base.
  • I have made corresponding changes to the documentation.
  • I have added tests that prove my fix is effective or that my feature works.
  • Dependencies have been updated, if applicable.

Copilot AI review requested due to automatic review settings March 16, 2026 09:29
Contributor

Copilot AI left a comment


Pull request overview

Adds support for integrating a VLA (vision-language-action) model into the existing RL stack by introducing a new VLAPolicy, wiring raw (hierarchical) observations + chunked actions through collection/eval/training, and adding an entry-point based backend registry.

Changes:

  • Introduce VLAPolicy and register it in the RL policy registry.
  • Extend rollout collection/training (collector, buffer, GRPO, trainer eval) to support raw observations and action chunks (action_chunk / chunk_step).
  • Add vla_registry to discover VLA backend factories via Python entry points.
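
The action-chunk mechanics referenced throughout this review can be sketched as follows. The class and method names are illustrative, not the collector's actual API: a VLA policy predicts a horizon of actions at once, and the collector replays them one env step at a time, refilling only when the chunk is exhausted:

```python
import numpy as np

class ChunkCache:
    """Replay a predicted action chunk one step at a time."""

    def __init__(self, horizon):
        self.horizon = horizon
        self.chunk = None      # (horizon, action_dim) once populated
        self.step = 0          # index of the next action to replay

    def next_action(self, predict_chunk):
        # Refill the cache when it is empty or fully consumed.
        if self.chunk is None or self.step >= self.horizon:
            self.chunk = predict_chunk()   # expensive VLA forward pass
            self.step = 0
        action = self.chunk[self.step]
        self.step += 1
        return action

    def invalidate(self):
        # Called on episode reset so a stale chunk never leaks into a new episode.
        self.chunk = None
        self.step = 0
```

The `chunk_step` field stored in the rollout corresponds to `self.step` here: it records which position within the chunk produced each transition.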

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 10 comments.

Show a summary per file
File Description
embodichain/agents/rl/vla_registry.py New entry-point based backend registry + factory creation.
embodichain/agents/rl/models/vla_policy.py New VLAPolicy wrapper for VLA inference + GRPO-compatible evaluate_actions.
embodichain/agents/rl/models/__init__.py Registers vla_policy; extends build_policy to optionally pass env/policy_cfg.
embodichain/agents/rl/collector/sync_collector.py Adds raw-observation storage and action-chunk caching + chunk_step.
embodichain/agents/rl/buffer/standard_buffer.py Adds use_raw_obs and attaches raw_obs list to shared rollout.
embodichain/agents/rl/buffer/utils.py Propagates chunk_step into transition view; adds _indices in minibatches.
embodichain/agents/rl/algo/grpo.py Passes rollout + num_envs into evaluate_actions; preserves raw fields across clone.
embodichain/agents/rl/utils/trainer.py Adjusts buffer sizing for chunked actions; updates eval loop for raw obs/chunks.
embodichain/agents/rl/models/actor_only.py Updates evaluate_actions signature to accept extra kwargs.
embodichain/agents/rl/models/actor_critic.py Updates evaluate_actions signature to accept extra kwargs.


@yuecideng yuecideng self-requested a review March 16, 2026 09:46
Contributor

Copilot AI left a comment


Pull request overview

Adds first-class support for VLA-backed policies in the RL stack by introducing a VLA policy wrapper, an entry-point-based backend registry, and rollout/collector plumbing for raw observations + chunked actions.

Changes:

  • Introduces VLAPolicy and registers it in the RL policy registry.
  • Adds vla_registry to discover/load VLA backend factories via Python entry points.
  • Extends rollout collection/training/eval utilities to support raw_obs, chunk_step, and action-chunk caching.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 14 comments.

Show a summary per file
File Description
embodichain/agents/rl/vla_registry.py Entry-point discovery + factory creation for pluggable VLA backends.
embodichain/agents/rl/utils/trainer.py Trainer buffer allocation + eval loop updated for raw obs and action chunks.
embodichain/agents/rl/models/vla_policy.py New VLA-backed policy wrapper implementing chunked action inference and proxy log-prob evaluation.
embodichain/agents/rl/models/actor_only.py Broadens evaluate_actions signature to accept extra kwargs.
embodichain/agents/rl/models/actor_critic.py Broadens evaluate_actions signature to accept extra kwargs.
embodichain/agents/rl/models/__init__.py Registers vla_policy and adds env-dependent initialization path in build_policy.
embodichain/agents/rl/collector/sync_collector.py Adds raw_obs storage + action chunk caching + chunk_step tracking.
embodichain/agents/rl/buffer/utils.py Propagates chunk_step into transition view and adds minibatch _indices.
embodichain/agents/rl/buffer/standard_buffer.py Adds use_raw_obs handling and allocates rollout.raw_obs.
embodichain/agents/rl/algo/grpo.py Passes rollout context into evaluate_actions and preserves rollout attributes across clone.


raw_obs = getattr(rollout, "raw_obs", None)
chunk_step = getattr(rollout, "chunk_step", None)
rollout = rollout.clone()
if raw_obs is not None:
Contributor


This part seems useless?

Collaborator Author


My mistake, I forgot to remove it.

if self._is_full:
raise RuntimeError("RolloutBuffer already contains a rollout.")
self._clear_dynamic_fields()
if self.use_raw_obs:
Contributor


What is the purpose of adding raw_obs?

Collaborator Author


VLA requires full images as input rather than a flattened vector. I use raw_obs to separate the VLA policy path from the other policies.
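
For context, the shape difference can be sketched as follows (all shapes and keys here are illustrative, not the actual observation spec):

```python
import numpy as np

# What standard MLP policies consume: one flat feature vector per env.
flat_obs = np.zeros((4, 128), dtype=np.float32)         # (num_envs, obs_dim)

# What a VLA backend consumes: the raw, structured observation dict,
# including full camera images and a language instruction.
raw_obs = {
    "rgb": np.zeros((4, 224, 224, 3), dtype=np.uint8),  # full images, not flattened
    "proprio": np.zeros((4, 7), dtype=np.float32),
    "instruction": ["pick up the red cube"] * 4,
}
```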


if use_raw_obs:
if raw_obs_list is None:
raise ValueError(
Contributor


use logger.error

Collaborator Author


Okay

self.action_chunk_size = self.action_horizon
self._env = None

def set_env(self, env) -> None:
Contributor


Why add env to the policy?

Collaborator Author


After getting the raw observations from the env, the VLA packages them into batches that fit its input requirements. This step needs information from the env.
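
A minimal sketch of why the policy needs the env handle (attribute and method names here are hypothetical, not the actual VLAPolicy API):

```python
class VLAPolicyLike:
    """Sketch: the policy needs env metadata to batch raw observations."""

    def __init__(self):
        self._env = None

    def set_env(self, env):
        self._env = env

    def pack_batch(self, raw_obs_list):
        # Camera names (and similar metadata) come from the env, not from
        # the policy config, so batching raw observations per camera
        # requires the env handle.
        cams = self._env.camera_names
        return {cam: [obs[cam] for obs in raw_obs_list] for cam in cams}
```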

tensordict["value"] = torch.zeros(b, device=self.device, dtype=torch.float32)
return tensordict

def evaluate_actions(
Contributor


num_envs also seems unused.

Collaborator Author


Removed

Copilot AI review requested due to automatic review settings March 30, 2026 03:26
Contributor

Copilot AI left a comment


Pull request overview

Adds VLA (Vision-Language-Action) integration into the RL stack by introducing a VLAPolicy wrapper, extending rollout collection/training to support raw observations and chunked actions, and adding a registry for VLA backends via entry points.

Changes:

  • Introduce VLAPolicy and vla_registry to load and run VLA backends inside RL policies.
  • Extend rollout collection/evaluation to support use_raw_obs and chunked actions (action_chunk + chunk_step).
  • Adjust minibatching/GRPO plumbing to pass rollout context (raw_obs, indices) into evaluate_actions.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 11 comments.

Show a summary per file
File Description
embodichain/agents/rl/vla_registry.py Adds entry-point-based backend discovery + factory creation for VLA backends.
embodichain/agents/rl/models/vla_policy.py New policy wrapper that runs a VLA backend and exposes RL Policy interface with action chunks + raw obs.
embodichain/agents/rl/utils/trainer.py Updates buffer sizing and evaluation loop to handle raw obs + chunked actions.
embodichain/agents/rl/train.py Passes env into build_policy for VLA policy initialization.
embodichain/agents/rl/models/__init__.py Registers vla_policy and extends build_policy to support env/policy_cfg and VLA initialization.
embodichain/agents/rl/collector/sync_collector.py Extends collector to populate raw_obs, generate/consume chunked actions, and track chunk_step.
embodichain/agents/rl/buffer/utils.py Propagates chunk_step into transition view; adds _indices to minibatches for mapping back to rollout.
embodichain/agents/rl/buffer/standard_buffer.py Allocates/clears raw_obs and chunk_step dynamic fields for VLA workflows.
embodichain/agents/rl/algo/grpo.py Passes rollout into evaluate_actions to support VLA log-prob evaluation from raw obs.
embodichain/agents/rl/algo/ppo.py Removes per-update rollout cloning (now relies on shared rollout lifecycle).
embodichain/agents/rl/models/actor_only.py Allows evaluate_actions(..., **kwargs) to accept rollout context without breaking.
embodichain/agents/rl/models/actor_critic.py Allows evaluate_actions(..., **kwargs) to accept rollout context without breaking.


Comment on lines +77 to +78
backend = create_vla_backend(
"dexforce_vla",

Copilot AI Mar 30, 2026


VLAPolicy._load_vla() hardcodes the backend name to "dexforce_vla" while also excluding a backend key from vla_cfg. If configuration is meant to choose among entry-point backends, this prevents it. Read backend from config (defaulting to dexforce_vla) and pass it into create_vla_backend.

Suggested change
- backend = create_vla_backend(
-     "dexforce_vla",
+ backend_name = str(self.vla_cfg.get("backend", "dexforce_vla"))
+ backend = create_vla_backend(
+     backend_name,

Comment on lines 271 to 275
  """Validate rollout layout expected by the collector."""
+ obs_dim = rollout["obs"].shape[-1]
  expected_shapes = {
-     "obs": (self.env.num_envs, num_steps + 1, self.policy.obs_dim),
+     "obs": (self.env.num_envs, num_steps + 1, obs_dim),
      "action": (self.env.num_envs, num_steps + 1, self.policy.action_dim),

Copilot AI Mar 30, 2026


_validate_rollout() uses obs_dim = rollout["obs"].shape[-1] to form the expected shape, so the obs last-dimension check can never fail (making validation ineffective for this field). If the intent is to support policies with obs_dim == 0, consider validating against self.policy.obs_dim when it’s > 0 and skipping only that last-dim check otherwise.
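
The suggested guard can be sketched as a hypothetical helper (not the actual patch; names mirror the snippet above):

```python
def expected_obs_shape(num_envs, num_steps, policy_obs_dim, rollout_obs_dim):
    # Validate against the policy's declared obs_dim when it is meaningful;
    # fall back to the rollout's own last dim (effectively skipping that
    # check) only when the policy intentionally reports obs_dim == 0,
    # e.g. raw-obs VLA policies.
    last_dim = policy_obs_dim if policy_obs_dim > 0 else rollout_obs_dim
    return (num_envs, num_steps + 1, last_dim)
```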

Comment on lines +78 to +82
  indices = torch.randperm(total)
  for start in range(0, total, batch_size):
-     yield rollout[indices[start : start + batch_size]]
+     batch_indices = indices[start : start + batch_size]
+     batch = rollout[batch_indices].clone()
+     batch["_indices"] = batch_indices

Copilot AI Mar 30, 2026


iterate_minibatches() now builds indices on the default (CPU) device. If rollout is on CUDA, indexing rollout[batch_indices] will fail because the index tensor must be on the same device. Create indices on rollout.device (or the passed device) and keep _indices consistent with that choice.
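
A device-aware version can be sketched as follows, assuming rollout is a TensorDict-like object exposing `.device` and integer-tensor indexing (a sketch, not the repository's actual function):

```python
import torch

def iterate_minibatches(rollout, batch_size):
    total = rollout["obs"].shape[0]
    # Build the permutation on the rollout's device so that indexing a
    # CUDA-resident rollout does not mix CPU and CUDA tensors.
    indices = torch.randperm(total, device=rollout.device)
    for start in range(0, total, batch_size):
        batch_indices = indices[start : start + batch_size]
        batch = rollout[batch_indices].clone()
        batch["_indices"] = batch_indices  # same device as the batch itself
        yield batch
```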

Comment on lines +111 to +112
except Exception:
obs_dim = max(1, obs_dim)

Copilot AI Mar 30, 2026


The broad except Exception fallback sets obs_dim = max(1, obs_dim), which can allocate an obs buffer with the wrong last dimension (e.g., VLAPolicy uses obs_dim=0 so this becomes 1). This tends to surface later as a confusing tensor assignment shape error. Prefer failing fast (or at least logging the exception) so buffer sizing errors are explicit.

Suggested change
- except Exception:
-     obs_dim = max(1, obs_dim)
+ except Exception as exc:
+     raise RuntimeError(
+         "Failed to infer obs_dim from raw observations during Trainer "
+         "initialization. Check env.reset(), dict_to_tensordict, and "
+         "flatten_dict_observation."
+     ) from exc

Comment on lines +43 to +47
factory = ep.load()
name = str(ep.name).lower()
if name not in _VLA_BACKENDS:
_VLA_BACKENDS[name] = factory
except Exception:

Copilot AI Mar 30, 2026


Per-entry-point failures from ep.load() are silently swallowed, which makes plugin misconfiguration very hard to diagnose. Consider catching narrower exceptions and logging a warning/debug message with the entry point name/module when a backend fails to load.

Comment on lines +185 to +189
if use_action_chunk:
rollout.chunk_step[:, step_idx] = effective_step_in_chunk
# Invalidate cached_chunk on any env reset to avoid using old chunk for new episode
if (terminated | truncated).any():
cached_chunk = None

Copilot AI Mar 30, 2026


When using action chunks, the collector invalidates the entire cached_chunk whenever any environment terminates/truncates. In a vectorized env, this forces all still-running envs to throw away their remaining chunk and recompute a new one, which can be very expensive for VLA backends and changes the intended per-env chunk semantics. Consider maintaining per-env cached chunks (and invalidating only the done env indices) to avoid unnecessary recomputation.
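
Per-env chunk caching can be sketched as follows (a hypothetical helper; the real collector's state layout may differ, and predict_chunk here is assumed to return a full (num_envs, horizon, action_dim) array):

```python
import numpy as np

def step_chunk(cached_chunk, chunk_step, done, predict_chunk):
    """Advance per-env action chunks, re-predicting only where needed.

    cached_chunk: (num_envs, horizon, action_dim) array, or None initially
    chunk_step:   (num_envs,) int, next index into each env's chunk
    done:         (num_envs,) bool, terminated|truncated from the last step
    """
    num_envs = done.shape[0]
    # An env needs a fresh chunk if it just reset or exhausted its chunk.
    if cached_chunk is None:
        stale = np.ones(num_envs, dtype=bool)
    else:
        stale = done | (chunk_step >= cached_chunk.shape[1])
    if stale.any():
        new_chunk = predict_chunk(stale)  # backend may compute only stale rows
        if cached_chunk is None:
            cached_chunk = new_chunk
        else:
            cached_chunk[stale] = new_chunk[stale]
        chunk_step[stale] = 0  # only stale envs restart their chunk
    actions = cached_chunk[np.arange(num_envs), chunk_step]
    chunk_step += 1
    return actions, cached_chunk, chunk_step
```

Still-running envs keep consuming their cached chunk, so one env terminating no longer forces a VLA forward pass for the whole batch.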

{"obs": rollout["obs"][:, step_idx]},
batch_size=[rollout.batch_size[0]],
use_raw_obs = getattr(self.policy, "use_raw_obs", False)
raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None

Copilot AI Mar 30, 2026


New behaviors (raw-observation rollouts via raw_obs and chunked-action support via chunk_step/action_chunk) are introduced here but aren’t covered by existing RL tests. Adding a focused unit test with a dummy policy exercising use_raw_obs=True and use_action_chunk=True would help prevent regressions (e.g., raw_obs population and chunk_step alignment).

Suggested change
  raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None
+ if use_raw_obs and raw_obs_list is None and isinstance(rollout, TensorDict):
+     # Fallback to key-based access for raw observations when using a TensorDict.
+     # This allows 'raw_obs' to be provided either as an attribute or a field key.
+     if "raw_obs" in rollout.keys():
+         raw_obs_list = rollout.get("raw_obs")

Comment on lines +49 to +50
except Exception:
pass

Copilot AI Mar 30, 2026


The outer except Exception: pass suppresses all discovery errors (including programming errors) and leaves the registry empty with no signal to the caller. At minimum, log the exception; ideally catch only expected errors (e.g., missing entry point group) and let unexpected ones surface.

Comment on lines +133 to +136
"so that set_env and _load_vla can be called before use."
)
policy.set_env(env)
policy._load_vla()

Copilot AI Mar 30, 2026


Calling the private policy._load_vla() from build_policy() forces heavyweight backend/model loading at construction time and couples the factory to a private method. Prefer lazy-loading in VLAPolicy (or exposing a public init/load hook) so config parsing and policy construction stay lightweight and copying/cloning policies doesn’t accidentally duplicate backend state.

Suggested change
-             "so that set_env and _load_vla can be called before use."
-         )
-     policy.set_env(env)
-     policy._load_vla()
+             "so that set_env can be called before use."
+         )
+     policy.set_env(env)
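
A lazy-loading variant can be sketched as follows (class, attribute, and method names are hypothetical, not the actual VLAPolicy implementation):

```python
class VLAPolicySketch:
    """Defer expensive backend loading until the first access."""

    def __init__(self, vla_cfg):
        self.vla_cfg = vla_cfg
        self._backend = None

    @property
    def backend(self):
        # Load on first access instead of inside build_policy(), keeping
        # policy construction cheap and avoiding accidental duplication of
        # backend state when a policy object is copied before first use.
        if self._backend is None:
            self._backend = self._load_backend()
        return self._backend

    def _load_backend(self):
        # Stand-in for the real create_vla_backend(...) call.
        return object()
```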

Comment on lines +38 to +40
_ENTRY_POINTS_DISCOVERED = True
try:
eps = entry_points(group="embodichain.vla_backends")

Copilot AI Mar 30, 2026


_ENTRY_POINTS_DISCOVERED is set to True before attempting discovery. If entry_points() fails (or the first discovery attempt is partial), subsequent calls will never retry and the registry can remain empty. Consider only marking discovery complete after a successful pass, or allowing retries on failure.
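
The retry-safe ordering can be sketched as follows (names are illustrative; discover stands in for the actual entry-point scan):

```python
_DISCOVERED = False
_BACKENDS = {}

def ensure_discovered(discover):
    """Mark discovery complete only after a successful pass, so a transient
    failure does not permanently leave the registry empty."""
    global _DISCOVERED
    if _DISCOVERED:
        return _BACKENDS
    found = discover()      # may raise; the flag stays False on failure
    _BACKENDS.update(found)
    _DISCOVERED = True      # set only after the pass succeeds
    return _BACKENDS
```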
