Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions openadapt_ml/training/grpo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from openadapt_ml.training.grpo.reward import (
binary_task_success,
compute_group_advantages,
evaluate_milestones_screenshot,
)
from openadapt_ml.training.grpo.rollout_collector import (
GRPORolloutCollector,
Expand Down Expand Up @@ -86,6 +87,7 @@ def __getattr__(name: str):
"Rollout",
"binary_task_success",
"compute_group_advantages",
"evaluate_milestones_screenshot",
"policy_gradient_loss",
"grpo_loss",
"parse_vlm_output_to_action",
Expand Down
11 changes: 11 additions & 0 deletions openadapt_ml/training/grpo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class GRPOConfig:
server_url: URL of the WAA server for live environment interaction.
evaluate_url: URL of the evaluate server. If None, defaults to server_url.
task_ids: List of WAA task IDs to train on.
task_dir: Path to a directory of YAML task config files. When set,
the trainer loads TaskConfig objects and uses milestone-based
reward evaluation locally (no /evaluate endpoint needed).
If task_ids is empty, task IDs are auto-populated from the
loaded configs.
learning_rate: Optimizer learning rate for LoRA parameter updates.
num_training_steps: Total number of GRPO training steps (outer loop).
save_every_steps: Checkpoint frequency.
Expand Down Expand Up @@ -69,6 +74,12 @@ class GRPOConfig:
task_ids: list[str] = field(default_factory=list)
screen_size: tuple[int, int] = (1920, 1080) # (width, height)

# Task configuration directory (YAML files with milestones for dense rewards).
# When set, the trainer loads TaskConfig objects from this directory and
# uses milestone-based reward evaluation locally, without needing the
# WAA /evaluate endpoint. Requires openadapt-evals to be installed.
task_dir: str | None = None

# Training
learning_rate: float = 5e-6
num_training_steps: int = 1000
Expand Down
96 changes: 96 additions & 0 deletions openadapt_ml/training/grpo/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,19 @@
GRPO computes advantages relative to the group mean rather than using
a learned value function, which is simpler and works well for sparse
binary rewards (task success/failure).

Also provides ``evaluate_milestones_screenshot``, a standalone utility
that evaluates milestone-based rewards from a screenshot without needing
the WAA /evaluate endpoint. This is the local-evaluation path used by
the standalone GRPO trainer when ``--task-dir`` is set.
"""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)


def binary_task_success(score: float, threshold: float = 0.5) -> float:
"""Convert evaluator score to binary reward.
Expand Down Expand Up @@ -54,3 +63,90 @@ def compute_group_advantages(rewards: list[float]) -> list[float]:
return [0.0] * n

return [(r - mean) / (std + eps) for r in rewards]


def evaluate_milestones_screenshot(
    task_config: object,
    screenshot_bytes: bytes,
    vlm_model: str = "gpt-4.1-mini",
    vlm_provider: str = "openai",
) -> float:
    """Score a screenshot against a task's milestones without a WAA server.

    Every ``screenshot``-type milestone on *task_config* is judged by a VLM
    against *screenshot_bytes*. Milestones with any other check type require
    a live server and are ignored here.

    Usable standalone, independently of the trainer, e.g.::

        from openadapt_ml.training.grpo.reward import evaluate_milestones_screenshot
        reward = evaluate_milestones_screenshot(task_config, screenshot_bytes)

    Args:
        task_config: A ``TaskConfig`` instance (from ``openadapt_evals.task_config``)
            exposing a ``milestones`` list of ``Milestone`` objects.
        screenshot_bytes: PNG screenshot bytes to judge against.
        vlm_model: Name of the VLM model used as the judge.
        vlm_provider: VLM provider (``"openai"`` or ``"anthropic"``).

    Returns:
        Fraction of screenshot milestones that passed, in [0.0, 1.0].
        0.0 when there are no milestones, no screenshot-type milestones,
        or openadapt-evals is not installed.
    """
    all_milestones = getattr(task_config, "milestones", None)
    if not all_milestones:
        return 0.0

    # Keep only the milestones we can judge locally from pixels; the rest
    # need a live environment and are excluded from the denominator.
    candidates = [
        m for m in all_milestones
        if getattr(m.check, "check", None) == "screenshot"
    ]
    if not candidates:
        return 0.0

    # openadapt-evals is an optional dependency; degrade to a zero reward
    # (with a warning) rather than crash the caller when it is absent.
    try:
        from openadapt_evals.vlm_evaluator import vlm_judge
    except ImportError:
        logger.warning(
            "openadapt-evals is not installed; cannot evaluate screenshot "
            "milestones. Install with: pip install openadapt-evals"
        )
        return 0.0

    num_passed = 0
    for milestone in candidates:
        prompt = getattr(milestone.check, "description", None) or ""
        if not prompt:
            # No description to judge against: the milestone stays in the
            # denominator but can never pass.
            continue
        try:
            ok, _confidence = vlm_judge(
                screenshot_bytes,
                prompt,
                model=vlm_model,
                provider=vlm_provider,
            )
            if ok:
                num_passed += 1
            logger.debug(
                "Milestone '%s': %s",
                getattr(milestone, "name", "?"),
                "PASS" if ok else "FAIL",
            )
        except Exception as exc:
            # A single flaky judge call should not abort the whole rollout's
            # reward; log it and treat the milestone as failed.
            logger.warning(
                "Milestone '%s' evaluation failed: %s",
                getattr(milestone, "name", "?"),
                exc,
            )

    fraction = num_passed / len(candidates)
    logger.info(
        "Milestone evaluation: %d/%d screenshot milestones passed (%.2f)",
        num_passed,
        len(candidates),
        fraction,
    )
    return fraction
15 changes: 14 additions & 1 deletion openadapt_ml/training/grpo/rollout_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,27 @@ class GRPORolloutCollector:

Args:
config: GRPO training configuration.
task_configs: Optional dict mapping task_id -> TaskConfig. When
provided, task configs are loaded into the RLEnvironment for
milestone-based dense reward evaluation.

Raises:
ImportError: If openadapt-evals is not installed.
"""

def __init__(self, config: GRPOConfig) -> None:
def __init__(
self,
config: GRPOConfig,
task_configs: dict[str, Any] | None = None,
) -> None:
if RLEnvironment is None:
raise ImportError(
"openadapt-evals is required for rollout collection. "
"Install it with: uv add openadapt-evals"
)

self._config = config
self._task_configs = task_configs or {}
self._adapter = WAALiveAdapter(
WAALiveConfig(
server_url=config.server_url,
Expand Down Expand Up @@ -123,6 +131,11 @@ def collect_group(

rollouts: list[Rollout] = []

# Load task config into the environment for dense milestone rewards
if task_id in self._task_configs:
tc = self._task_configs[task_id]
self._env.load_task_config(tc)

for i in range(self._config.num_rollouts_per_step):
logger.info(
"Collecting rollout %d/%d for task %s",
Expand Down
Loading
Loading