Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,52 @@ async def _pad_single_output(
extra_fields=output.extra_fields,
)

def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
"""Postprocess outputs and place terminal reward on trainable tokens.

Upstream verl places reward on the last non-padding response token, which can
be an environment-observation token (response_mask == 0). For PPO/GAE this may
weaken or drop learning signal because masked positions are skipped in advantage
propagation. We relocate reward to the last trainable response token
(response_mask == 1), and fall back to the original behavior only when no
trainable response token exists.
"""
output = super()._postprocess(inputs)

scores = [item.reward_score for item in inputs]
if not all(score is not None for score in scores):
return output

response_mask = output.batch["response_mask"]
attention_mask = output.batch["attention_mask"]
prompt_ids = output.batch["prompts"]
rm_scores = torch.zeros_like(response_mask, dtype=torch.float32)

token_positions = torch.arange(
response_mask.size(1), device=response_mask.device
).unsqueeze(0).expand_as(response_mask)
trainable_mask = response_mask > 0
has_trainable = trainable_mask.any(dim=1)
last_trainable_idx = torch.where(
trainable_mask,
token_positions,
torch.full_like(token_positions, -1),
).max(dim=1).values

prompt_length = prompt_ids.size(1)
last_non_pad_idx = attention_mask[:, prompt_length:].sum(dim=1).long() - 1
last_non_pad_idx = last_non_pad_idx.clamp(min=0)
reward_positions = torch.where(
has_trainable, last_trainable_idx, last_non_pad_idx
)

rm_scores[
torch.arange(response_mask.size(0), device=response_mask.device),
reward_positions,
] = torch.tensor(scores, dtype=torch.float32, device=response_mask.device)
output.batch["rm_scores"] = rm_scores
return output

async def _run_agent_loop(
self,
sampling_params: dict[str, Any],
Expand Down
Loading
Loading