Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions openadapt_ml/training/grpo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,12 @@ class GRPOConfig:
save_every_steps: int = 50
output_dir: str = "checkpoints/grpo"

# Generation
max_new_tokens: int = 2048 # Token budget per step. Reasoning models need
# 1000+ tokens (thought + action); a budget of 100 truncates the output mid-reasoning, leaving it unparseable.

# Task configs
task_dir: str | None = None # Directory of TaskConfig YAMLs for milestone rewards

# Stuck detection
stuck_window: int = 3
57 changes: 53 additions & 4 deletions openadapt_ml/training/grpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,16 @@ def policy_gradient_loss(


def _build_agent_messages(
instruction: str, *, include_image: bool = False
instruction: str,
*,
include_image: bool = False,
action_history: str = "",
) -> list[dict]:
"""Build chat messages for the GRPO agent.

Uses the same SYSTEM_PROMPT as SFT training so GRPO operates in
the same prompt distribution the model was warm-started on.
Uses the same SYSTEM_PROMPT and prompt format as SFT training
(``next_action.py``) so GRPO operates in the same prompt
distribution the model was warm-started on.

This is the **single source of truth** for prompt construction
during both rollout collection and loss computation.
Expand All @@ -113,10 +117,15 @@ def _build_agent_messages(
include_image: If True, include an image placeholder in the user
message so ``apply_chat_template`` inserts ``<|image_pad|>``
tokens required by Qwen2.5-VL and similar VLMs.
action_history: Formatted action history from previous steps
(e.g. "Step 1: CLICK(x=0.5, y=0.3)\\nStep 2: TYPE(...)").
"""
history_text = f"{action_history}\n" if action_history else ""
text_content = (
f"Goal: {instruction}\n\n"
f"{history_text}"
"Look at the screenshot and determine the NEXT action.\n\n"
"Thought: [what element to interact with and why]\n"
'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]'
)
if include_image:
Expand Down Expand Up @@ -161,6 +170,37 @@ class BenchmarkAction: # type: ignore[no-redef]
text = text.strip()
width, height = screen_size

# Log raw output for debugging zero-reward issues
logger.debug("Parsing VLM output (%d chars): %.200s", len(text), text)

# Extract action from "Thought: ...\nAction: ..." format (SFT output)
action_match = re.search(r"Action:\s*(.+)", text, re.IGNORECASE)
if action_match:
text = action_match.group(1).strip()

# Try JSON format: {"action_type": "click", "coordinate": [x, y]}
json_match = re.search(r'\{[^}]*"action_type"[^}]*\}', text)
if json_match:
try:
import json as _json
action_data = _json.loads(json_match.group())
atype = action_data.get("action_type", "").lower()
coord = action_data.get("coordinate", action_data.get("coords", []))
if atype == "click" and len(coord) >= 2:
x_val, y_val = float(coord[0]), float(coord[1])
# Handle both normalized (0-1) and pixel coordinates
if x_val <= 1.0 and y_val <= 1.0:
x_val, y_val = x_val * width, y_val * height
return BenchmarkAction(type="click", x=int(x_val), y=int(y_val))
if atype == "type":
return BenchmarkAction(
type="type", text=action_data.get("text", "")
)
if atype in ("done", "wait"):
return BenchmarkAction(type=atype)
except Exception:
pass # Fall through to regex parsing

# CLICK(x=..., y=...)
m = re.search(r"CLICK\(x=(-?[\d.]+),\s*y=(-?[\d.]+)\)", text, re.IGNORECASE)
if m:
Expand Down Expand Up @@ -348,7 +388,7 @@ def agent_fn(obs: Any) -> BenchmarkAction:
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=100,
max_new_tokens=self._config.max_new_tokens,
temperature=temperature,
do_sample=True,
)
Expand All @@ -365,6 +405,15 @@ def agent_fn(obs: Any) -> BenchmarkAction:
except Exception:
pass

# Warn if output was likely truncated (hit max_new_tokens)
gen_len = outputs[0].shape[0] - inputs["input_ids"].shape[1]
if gen_len >= self._config.max_new_tokens - 1:
logger.warning(
"Generation hit max_new_tokens=%d — output may be truncated. "
"Increase config.max_new_tokens if actions aren't parsed.",
self._config.max_new_tokens,
)

action = _parse_vlm_output_to_action(decoded, screen_size=screen_size)

# Store raw VLM output for accurate loss computation
Expand Down
Loading