fix: mask token-mean loss aggregation #444

Open
haoyang9804 wants to merge 1 commit into alibaba:main from haoyang9804:fix/agg-loss-token-mean-mask

@haoyang9804

Summary

roll.utils.functionals.agg_loss(..., loss_agg_mode="token-mean") currently divides by the number of valid tokens but sums the entire loss_mat numerator. When padding, filtered rollout tokens, or tool/observation tokens have loss_mask=0, their per-token losses can still change actor loss, entropy loss, clip metrics, and gradients without crashing.

This patch applies loss_mask to the token-mean numerator before summing, using torch.where(loss_mask.bool(), loss_mat, torch.zeros_like(loss_mat)), so masked tokens contribute neither values nor gradients. It also adds a focused regression test covering the scalar loss and the per-token gradients.

Concrete triggering example

The bug is visible with two valid tokens and two masked tokens:

loss_mat = torch.tensor([[1.0, 100.0], [1.0, -50.0]], requires_grad=True)
loss_mask = torch.tensor([[1, 0], [1, 0]])
batch_num_tokens = 2
loss = agg_loss(loss_mat, loss_mask, "token-mean", batch_num_tokens=batch_num_tokens)

Current behavior computes (1 + 100 + 1 - 50) / 2 = 26.0 and gives every position gradient 0.5, including masked positions. Mask-respecting behavior computes (1 + 1) / 2 = 1.0 and gives gradients [[0.5, 0.0], [0.5, 0.0]].

Wrong intermediate value: token_mean_loss=26.0 with masked_token_gradient_abs_sum=1.0.
Fixed value: token_mean_loss=1.0 with masked_token_gradient_abs_sum=0.0.
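
The numbers above can be checked without importing ROLL. The following is a minimal sketch, not the ROLL implementation: token_mean_unmasked, token_mean_masked, and loss_and_grad are hypothetical stand-ins written for this illustration, and sample weights are assumed to be 1 and therefore omitted.

```python
import torch


def token_mean_unmasked(loss_mat, batch_num_tokens):
    # Stand-in for the pre-patch behavior: every position enters the numerator.
    return loss_mat.sum() / batch_num_tokens


def token_mean_masked(loss_mat, loss_mask, batch_num_tokens):
    # Stand-in for the post-patch behavior: masked positions become zeros first.
    masked = torch.where(loss_mask.bool(), loss_mat, torch.zeros_like(loss_mat))
    return masked.sum() / batch_num_tokens


def loss_and_grad(fn, *args):
    # Use a fresh leaf tensor per call so gradients do not accumulate across runs.
    leaf = torch.tensor([[1.0, 100.0], [1.0, -50.0]], requires_grad=True)
    loss = fn(leaf, *args)
    loss.backward()
    return loss.item(), leaf.grad.tolist()


loss_mask = torch.tensor([[1, 0], [1, 0]])
bad_loss, bad_grad = loss_and_grad(token_mean_unmasked, 2)
good_loss, good_grad = loss_and_grad(token_mean_masked, loss_mask, 2)
print(bad_loss, bad_grad)    # unmasked numerator
print(good_loss, good_grad)  # masked numerator
```

Only the numerator differs between the two stand-ins. The denominator is the same valid-token count in both, which is why the unmasked variant is biased rather than merely rescaled.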

Reproduction recipe

{
  "kind": "rl_sentinel_validation_recipe",
  "bug_id": "ROLL-AGGLOSS-TOKENMEAN-MASK",
  "target": "roll",
  "validation_mode": "real_roll_boundary_tensor_hook",
  "hooked_boundary": "roll.utils.functionals.agg_loss",
  "requirements": {
    "target_repo": "${TARGET_REPO}",
    "output_dir": "${OUTPUT_DIR}",
    "python": "python3",
    "required_modules": ["roll.utils.functionals", "torch"],
    "backend_required": false
  },
  "constructed_scenario": {
    "loss_agg_mode": "token-mean",
    "loss_mat": [[1.0, 100.0], [1.0, -50.0]],
    "loss_mask": [[1, 0], [1, 0]],
    "batch_num_tokens": 2
  },
  "expected_unpatched": {
    "token_mean_loss": 26.0,
    "masked_token_gradient_abs_sum": 1.0,
    "parameter_update_delta": -2.6
  },
  "expected_fixed": {
    "token_mean_loss": 1.0,
    "masked_token_gradient_abs_sum": 0.0,
    "parameter_update_delta": -0.1
  }
}
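
One way to use the recipe's expected_unpatched and expected_fixed blocks is to classify a measured result against them. The sketch below is illustrative only: classify_result and the EXPECTED table are hypothetical names, the tolerance is an arbitrary choice, and the "reproduced" and "fixed" labels are borrowed from the status fields reported elsewhere in this description.

```python
import math

# Expected values copied from the recipe above.
EXPECTED = {
    "reproduced": {
        "token_mean_loss": 26.0,
        "masked_token_gradient_abs_sum": 1.0,
        "parameter_update_delta": -2.6,
    },
    "fixed": {
        "token_mean_loss": 1.0,
        "masked_token_gradient_abs_sum": 0.0,
        "parameter_update_delta": -0.1,
    },
}


def classify_result(result, abs_tol=1e-5):
    """Return 'reproduced', 'fixed', or 'unknown' for a measured result dict."""
    for status, expected in EXPECTED.items():
        if all(
            key in result and math.isclose(result[key], value, abs_tol=abs_tol)
            for key, value in expected.items()
        ):
            return status
    return "unknown"


measured = {
    "token_mean_loss": 26.0,
    "masked_token_gradient_abs_sum": 1.0,
    "parameter_update_delta": -2.6,
}
print(classify_result(measured))
```

Returning "unknown" for anything that matches neither block keeps a partially fixed or differently broken build from being mislabeled.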

Validation runner

#!/usr/bin/env python3
import contextlib
import json
import math
import os
import subprocess
import sys
from pathlib import Path

import torch

BUG_ID = os.environ.get("BUG_ID", "ROLL-AGGLOSS-TOKENMEAN-MASK")
TARGET_REPO = Path(os.environ["TARGET_REPO"]).resolve()
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", ".")).resolve()
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

sys.path.insert(0, str(TARGET_REPO))
with contextlib.redirect_stdout(sys.stderr):
    from roll.utils.functionals import agg_loss
import roll.utils.functionals as functionals

module_path = Path(functionals.__file__).resolve()
if TARGET_REPO not in module_path.parents:  # path-aware check; a string prefix would also accept sibling dirs such as "<repo>-old"
    raise RuntimeError(f"roll.utils.functionals imported from {module_path}, expected under {TARGET_REPO}")

loss_values = torch.tensor([[1.0, 100.0], [1.0, -50.0]], dtype=torch.float32)
loss_mask = torch.tensor([[1, 0], [1, 0]], dtype=torch.long)
batch_num_tokens = int(loss_mask.sum().item())
learning_rate = 0.1

theta = torch.nn.Parameter(torch.tensor(1.0))
loss = agg_loss(theta * loss_values, loss_mask, "token-mean", batch_num_tokens=batch_num_tokens)
loss.backward()
theta_grad = float(theta.grad.detach().item())
update_delta = -learning_rate * theta_grad

leaf = loss_values.clone().detach().requires_grad_(True)
leaf_loss = agg_loss(leaf, loss_mask, "token-mean", batch_num_tokens=batch_num_tokens)
leaf_loss.backward()
leaf_grad = leaf.grad.detach()

result = {
    "bug_id": BUG_ID,
    "module_file": str(module_path),
    "commit": subprocess.check_output(["git", "-C", str(TARGET_REPO), "rev-parse", "HEAD"], text=True).strip(),
    "token_mean_loss": float(loss.detach().item()),
    "theta_gradient": theta_grad,
    "parameter_update_delta": update_delta,
    "per_token_gradient": leaf_grad.tolist(),
    "masked_token_gradient_abs_sum": float((leaf_grad * (1 - loss_mask)).abs().sum().item()),
}
(OUTPUT_DIR / "agg_loss_token_mean_validation.json").write_text(json.dumps(result, indent=2, sort_keys=True) + "\n")
print(json.dumps(result, indent=2, sort_keys=True))

if not math.isfinite(result["token_mean_loss"]):
    raise SystemExit(2)
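
The runner's theta probe relies on token-mean being linear in loss_mat. The derivation below is a sketch under the assumption of unit sample weights, which the runner does not override. The symbols are introduced here for illustration: N is batch_num_tokens, m_ij is loss_mask, l_ij is loss_values, eta is learning_rate, and c_ij selects which positions enter the numerator.

```latex
\begin{aligned}
L(\theta) &= \frac{1}{N}\sum_{i,j} c_{ij}\,\theta\,\ell_{ij},
\qquad
c_{ij} =
\begin{cases}
1 & \text{unpatched} \\
m_{ij} & \text{patched}
\end{cases} \\
\frac{\partial L}{\partial \theta} &= \frac{1}{N}\sum_{i,j} c_{ij}\,\ell_{ij} = L(1) \\
\Delta\theta &= -\eta\,\frac{\partial L}{\partial \theta}
\end{aligned}
```

With theta initialized to 1, theta_gradient therefore equals token_mean_loss. For eta = 0.1 this gives an update of -0.1 x 26 = -2.6 when every position is summed and -0.1 x 1 = -0.1 when only unmasked positions are summed.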

Observed output

Unpatched ROLL produced:

{
  "status": "reproduced",
  "observed_bad_behavior": {
    "token_mean_loss": 26.0,
    "theta_gradient": 26.0,
    "parameter_update_delta": -2.6,
    "per_token_gradient": [[0.5, 0.5], [0.5, 0.5]],
    "masked_token_gradient_abs_sum": 1.0
  },
  "expected_safe_behavior": {
    "token_mean_loss": 1.0,
    "theta_gradient": 1.0,
    "parameter_update_delta": -0.1,
    "per_token_gradient": [[0.5, 0.0], [0.5, 0.0]],
    "masked_token_gradient_abs_sum": 0.0
  },
  "attack_effect": {
    "silent": true,
    "crashed": false,
    "loss_delta": 25.0,
    "parameter_update_delta_difference": -2.5
  }
}
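
The attack_effect figures follow from the masked entries alone: the unmasked numerator exceeds the masked one by exactly the sum of the losses at loss_mask=0 positions. A small plain-Python check of that arithmetic, using the scenario values and the runner's learning rate (masked_leak is a name made up for this sketch):

```python
import math

loss_mat = [[1.0, 100.0], [1.0, -50.0]]
loss_mask = [[1, 0], [1, 0]]
learning_rate = 0.1  # same value as the validation runner

batch_num_tokens = sum(sum(row) for row in loss_mask)

# Sum of the loss values that the mask was supposed to exclude.
masked_leak = sum(
    value
    for loss_row, mask_row in zip(loss_mat, loss_mask)
    for value, keep in zip(loss_row, mask_row)
    if not keep
)

loss_delta = masked_leak / batch_num_tokens
update_delta_difference = -learning_rate * loss_delta
print(loss_delta, update_delta_difference)
```

Because the leak is a signed sum, masked values of opposite sign partly cancel here (100 and -50). The size of the error therefore depends on the masked content and is not bounded by the number of masked tokens.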

After the patch, the same boundary returned:

{
  "status": "fixed",
  "not_triggered": true,
  "observed_fixed_behavior": {
    "token_mean_loss": 1.0,
    "theta_gradient": 1.0,
    "parameter_update_delta": -0.1,
    "per_token_gradient": [[0.5, 0.0], [0.5, 0.0]],
    "masked_token_gradient_abs_sum": 0.0
  }
}

Root cause

agg_loss used the mask to derive batch_num_tokens, then ignored loss_mask in the token-mean numerator:

loss = (loss_mat * weights.unsqueeze(-1)).sum() / batch_num_tokens

That makes token-mean inconsistent with the mask contract used by the other aggregation modes and by ROLL actor workers that pass response_mask or final_response_mask into agg_loss.

Fix

The patch masks the numerator before applying sample weights and summing:

masked_loss = torch.where(loss_mask.bool(), loss_mat, torch.zeros_like(loss_mat))
loss = (masked_loss * weights.unsqueeze(-1)).sum() / batch_num_tokens

The regression test checks the scalar result and verifies that masked token positions receive zero gradient.
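
The bug in this PR involves finite masked values, and the description does not say why torch.where was chosen over multiplying by the mask. One plausible benefit, shown here as an assumption rather than the author's stated reasoning, is that selection also keeps non-finite masked values out of the forward sum. The tensors are made up for this sketch.

```python
import torch

# Masked positions (mask == 0) hold non-finite values in this made-up input.
loss_mat = torch.tensor([[1.0, float("nan")], [1.0, float("inf")]])
loss_mask = torch.tensor([[1, 0], [1, 0]])
batch_num_tokens = int(loss_mask.sum().item())

# Multiplicative masking: nan * 0 and inf * 0 are both nan, so the sum is nan.
by_multiply = (loss_mat * loss_mask).sum() / batch_num_tokens

# Selection with torch.where: masked entries are replaced outright, never multiplied.
by_where = (
    torch.where(loss_mask.bool(), loss_mat, torch.zeros_like(loss_mat)).sum()
    / batch_num_tokens
)

print(by_multiply.item(), by_where.item())
```

This covers only the forward value and the gradient with respect to loss_mat itself. If an upstream operation already produced the non-finite value, its own backward pass may still emit non-finite gradients.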

Tests and checks

python3 -m pytest -q tests/utils/test_functionals.py

git diff --check HEAD^ HEAD -- roll/utils/functionals.py tests/utils/test_functionals.py

python3 -m pre_commit install
python3 -m pre_commit run --files roll/utils/functionals.py tests/utils/test_functionals.py

All listed checks passed on the local fix branch. The fixed-boundary validation also passed with status="fixed" and not_triggered=true.

Contribution and duplicate checks

Target upstream repo: alibaba/ROLL.

Contribution evidence read from the local checkout: README.md, .pre-commit-config.yaml, and pyproject.toml. No root CONTRIBUTING.md was present in the refreshed checkout.

Duplicate checks performed before creating the fix branch:

  • BUG_FINDINGS.md search for ROLL, myROLL, alibaba/ROLL, agg_loss, token-mean, loss_mask, and batch_num_tokens found no exact ROLL duplicate. It found same-family non-ROLL mask/aggregation bugs only.
  • Local and remote myROLL branch search found no existing fix branch for this boundary before fix/agg-loss-token-mean-mask was created.
  • Local myROLL/pr_drafts was empty before this draft.
  • Historical RL-Sentinel loop artifacts and the bug DB had no exact ROLL agg_loss token-mean duplicate.
  • Upstream alibaba/ROLL GitHub issue/PR search returned alibaba/ROLL#420, an unrelated LVLM LoRA question, not this aggregation bug.
  • Refreshed upstream main still had the token-mean numerator summing loss_mat without applying loss_mask.

Related PRs or fixes

No exact upstream ROLL PR or issue was found. Same-family RL-Sentinel findings exist in other projects for mask-before-reduction invariants, including veRL agg_loss sequence-mode masked invalid values, but those are not duplicates: this bug is ROLL-specific and affects token-mean by including finite masked-token loss values in the numerator.

@CLAassistant

CLAassistant commented May 18, 2026

CLA assistant check
All committers have signed the CLA.
