fix: mask token-mean loss aggregation#444
Open
haoyang9804 wants to merge 1 commit into
Open
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
roll.utils.functionals.agg_loss(..., loss_agg_mode="token-mean")currently divides by the number of valid tokens but sums the entireloss_matnumerator. When padding, filtered rollout tokens, or tool/observation tokens haveloss_mask=0, their per-token losses can still change actor loss, entropy loss, clip metrics, and gradients without crashing.This patch applies
loss_maskto the token-mean numerator before summing, usingtorch.where(loss_mask.bool(), loss_mat, 0)so masked tokens do not contribute values or gradients. It also adds a focused regression test for the scalar loss and per-token gradients.Concrete triggering example
The bug is visible with two valid tokens and two masked tokens:
Current behavior computes
(1 + 100 + 1 - 50) / 2 = 26.0and gives every position gradient0.5, including masked positions. Mask-respecting behavior computes(1 + 1) / 2 = 1.0and gives gradients[[0.5, 0.0], [0.5, 0.0]].Wrong intermediate value:
token_mean_loss=26.0withmasked_token_gradient_abs_sum=1.0.Fixed value:
token_mean_loss=1.0withmasked_token_gradient_abs_sum=0.0.Reproduction recipe
{ "kind": "rl_sentinel_validation_recipe", "bug_id": "ROLL-AGGLOSS-TOKENMEAN-MASK", "target": "roll", "validation_mode": "real_roll_boundary_tensor_hook", "hooked_boundary": "roll.utils.functionals.agg_loss", "requirements": { "target_repo": "${TARGET_REPO}", "output_dir": "${OUTPUT_DIR}", "python": "python3", "required_modules": ["roll.utils.functionals", "torch"], "backend_required": false }, "constructed_scenario": { "loss_agg_mode": "token-mean", "loss_mat": [[1.0, 100.0], [1.0, -50.0]], "loss_mask": [[1, 0], [1, 0]], "batch_num_tokens": 2 }, "expected_unpatched": { "token_mean_loss": 26.0, "masked_token_gradient_abs_sum": 1.0, "parameter_update_delta": -2.6 }, "expected_fixed": { "token_mean_loss": 1.0, "masked_token_gradient_abs_sum": 0.0, "parameter_update_delta": -0.1 } }Validation runner
Observed output
Unpatched ROLL produced:
{ "status": "reproduced", "observed_bad_behavior": { "token_mean_loss": 26.0, "theta_gradient": 26.0, "parameter_update_delta": -2.6, "per_token_gradient": [[0.5, 0.5], [0.5, 0.5]], "masked_token_gradient_abs_sum": 1.0 }, "expected_safe_behavior": { "token_mean_loss": 1.0, "theta_gradient": 1.0, "parameter_update_delta": -0.1, "per_token_gradient": [[0.5, 0.0], [0.5, 0.0]], "masked_token_gradient_abs_sum": 0.0 }, "attack_effect": { "silent": true, "crashed": false, "loss_delta": 25.0, "parameter_update_delta_difference": -2.5 } }After the patch, the same boundary returned:
{ "status": "fixed", "not_triggered": true, "observed_fixed_behavior": { "token_mean_loss": 1.0, "theta_gradient": 1.0, "parameter_update_delta": -0.1, "per_token_gradient": [[0.5, 0.0], [0.5, 0.0]], "masked_token_gradient_abs_sum": 0.0 } }Root cause
agg_lossused the mask to derivebatch_num_tokens, then ignoredloss_maskin thetoken-meannumerator:That makes token-mean inconsistent with the mask contract used by the other aggregation modes and by ROLL actor workers that pass
response_maskorfinal_response_maskintoagg_loss.Fix
The patch masks the numerator before applying sample weights and summing:
The regression test checks the scalar result and verifies that masked token positions receive zero gradient.
Tests and checks
All listed checks passed on the local fix branch. The fixed-boundary validation also passed with
status="fixed"andnot_triggered=true.Contribution and duplicate checks
Target upstream repo:
alibaba/ROLL.Contribution evidence read from the local checkout:
README.md,.pre-commit-config.yaml, andpyproject.toml. No rootCONTRIBUTING.mdwas present in the refreshed checkout.Duplicate checks performed before creating the fix branch:
BUG_FINDINGS.mdsearch forROLL,myROLL,alibaba/ROLL,agg_loss,token-mean,loss_mask, andbatch_num_tokensfound no exact ROLL duplicate. It found same-family non-ROLL mask/aggregation bugs only.myROLLbranch search found no existing fix branch for this boundary beforefix/agg-loss-token-mean-maskwas created.myROLL/pr_draftswas empty before this draft.agg_losstoken-mean duplicate.alibaba/ROLLGitHub issue/PR search returnedalibaba/ROLL#420, an unrelated LVLM LoRA question, not this aggregation bug.loss_matwithout applyingloss_mask.Related PRs or fixes
No exact upstream ROLL PR or issue was found. Same-family RL-Sentinel findings exist in other projects for mask-before-reduction invariants, including veRL
agg_losssequence-mode masked invalid values, but those are not duplicates: this bug is ROLL-specific and affectstoken-meanby including finite masked-token loss values in the numerator.