Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions openadapt_ml/training/grpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,16 @@ def policy_gradient_loss(


def _build_agent_messages(
instruction: str, *, include_image: bool = False
instruction: str,
*,
include_image: bool = False,
action_history: str = "",
) -> list[dict]:
"""Build chat messages for the GRPO agent.

Uses the same SYSTEM_PROMPT as SFT training so GRPO operates in
the same prompt distribution the model was warm-started on.
Uses the same SYSTEM_PROMPT and prompt format as SFT training
(``next_action.py``) so GRPO operates in the same prompt
distribution the model was warm-started on.

This is the **single source of truth** for prompt construction
during both rollout collection and loss computation.
Expand All @@ -113,10 +117,15 @@ def _build_agent_messages(
include_image: If True, include an image placeholder in the user
message so ``apply_chat_template`` inserts ``<|image_pad|>``
tokens required by Qwen2.5-VL and similar VLMs.
action_history: Formatted action history from previous steps
(e.g. "Step 1: CLICK(x=0.5, y=0.3)\\nStep 2: TYPE(...)").
"""
history_text = f"{action_history}\n" if action_history else ""
text_content = (
f"Goal: {instruction}\n\n"
f"{history_text}"
"Look at the screenshot and determine the NEXT action.\n\n"
"Thought: [what element to interact with and why]\n"
'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]'
)
if include_image:
Expand Down Expand Up @@ -161,6 +170,37 @@ class BenchmarkAction: # type: ignore[no-redef]
text = text.strip()
width, height = screen_size

# Log raw output for debugging zero-reward issues
logger.debug("Parsing VLM output (%d chars): %.200s", len(text), text)

# Extract action from "Thought: ...\nAction: ..." format (SFT output)
action_match = re.search(r"Action:\s*(.+)", text, re.IGNORECASE)
if action_match:
text = action_match.group(1).strip()

# Try JSON format: {"action_type": "click", "coordinate": [x, y]}
json_match = re.search(r'\{[^}]*"action_type"[^}]*\}', text)
if json_match:
try:
import json as _json
action_data = _json.loads(json_match.group())
atype = action_data.get("action_type", "").lower()
coord = action_data.get("coordinate", action_data.get("coords", []))
if atype == "click" and len(coord) >= 2:
x_val, y_val = float(coord[0]), float(coord[1])
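# Handle both normalized and pixel coordinates: when both values are
# <= 1.0 they are treated as normalized and scaled by the screen size;
# otherwise they are used as pixel values as-is.
# NOTE(review): negative values also pass the <= 1.0 check and get scaled.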
if x_val <= 1.0 and y_val <= 1.0:
x_val, y_val = x_val * width, y_val * height
return BenchmarkAction(type="click", x=int(x_val), y=int(y_val))
if atype == "type":
return BenchmarkAction(
type="type", text=action_data.get("text", "")
)
if atype in ("done", "wait"):
return BenchmarkAction(type=atype)
except Exception:
pass # Fall through to regex parsing

# CLICK(x=..., y=...)
m = re.search(r"CLICK\(x=(-?[\d.]+),\s*y=(-?[\d.]+)\)", text, re.IGNORECASE)
if m:
Expand Down
Loading