Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions openadapt_ml/training/grpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,35 @@ def policy_gradient_loss(
# ---------------------------------------------------------------------------


def _build_agent_messages(instruction: str) -> list[dict[str, str]]:
def _build_agent_messages(
instruction: str, *, include_image: bool = False
) -> list[dict]:
"""Build chat messages for the GRPO agent.

Uses the same SYSTEM_PROMPT as SFT training so GRPO operates in
the same prompt distribution the model was warm-started on.

This is the **single source of truth** for prompt construction
during both rollout collection and loss computation.

Args:
instruction: Task instruction text.
include_image: If True, include an image placeholder in the user
message so ``apply_chat_template`` inserts ``<|image_pad|>``
tokens required by Qwen2.5-VL and similar VLMs.
"""
user_content = (
text_content = (
f"Goal: {instruction}\n\n"
"Look at the screenshot and determine the NEXT action.\n\n"
'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]'
)
if include_image:
user_content = [
{"type": "image"},
{"type": "text", "text": text_content},
]
else:
user_content = text_content
return [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
Expand Down Expand Up @@ -318,7 +333,7 @@ def agent_fn(obs: Any) -> BenchmarkAction:
if raw_obs and isinstance(raw_obs, dict):
instruction = raw_obs.get("instruction", "")

messages = _build_agent_messages(instruction)
messages = _build_agent_messages(instruction, include_image=True)

if hasattr(processor, "apply_chat_template"):
text_input = processor.apply_chat_template(
Expand Down Expand Up @@ -505,7 +520,7 @@ def _compute_rollout_loss(
except Exception:
continue

messages = _build_agent_messages(instruction)
messages = _build_agent_messages(instruction, include_image=True)

# Raw text from rollout or reconstruct from DSL
raw_text = getattr(action, "_grpo_raw_text", None)
Expand Down
Loading