[WIP][PipelineRL] Normalization of new_logprobs and addition of other RL metrics #476
Draft
Add a generic `denominator_batch_field` to `LossDef` so any metric can be normalized by a pre-computed per-micro-batch scalar from the batch context, bypassing TP/SP/PP splitting entirely.

For `grpo_new_logprobs`: compute `num_docs = (labels_per_document > 0).sum()` in `language_model.py` before any parallel splitting, giving a true per-document average regardless of variable document lengths.

Only `sequence_data_rank == 0` contributes to `num_docs`. The runner all-reduces the denominator across the data group (which includes SDP ranks); if every SDP rank reported its own `num_docs`, a single document processed with SDP=2 would be counted twice, halving the metric.

Also clamp `num_labels_in_seq` to avoid `0/0 = nan` for padding segments or fully-masked documents (`loss_mask` is 0 there, so the numerator is 0 too).

Tests verify:
- `num_docs` counts only unmasked documents
- padding segments (`pad_to_size`) are excluded
- with SDP=2, only rank 0 contributes `num_docs`, so the all-reduce SUM across SDP ranks gives the correct denominator
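The counting and clamping logic above can be sketched in plain Python. This is a hypothetical illustration, not the actual tensor code in `language_model.py`: the function names `compute_num_docs` and `per_document_average` and the list-based inputs are assumptions; the real implementation operates on batched tensors.

```python
def compute_num_docs(labels_per_document, sequence_data_rank):
    """Count documents with at least one unmasked label.

    Only sequence-data rank 0 contributes, so the runner's later
    all_reduce(SUM) across the data group (which includes SDP ranks)
    does not double-count a document sharded over several SDP ranks.
    """
    if sequence_data_rank != 0:
        return 0
    # Padding segments and fully-masked documents have 0 labels,
    # so they are excluded from the denominator.
    return sum(1 for n in labels_per_document if n > 0)


def per_document_average(per_doc_losses, labels_per_document):
    """Average each document's loss over its label count, clamping the
    denominator to avoid 0/0 = nan for padding or fully-masked documents
    (whose numerator is already 0 because loss_mask is 0 there)."""
    return [
        loss / max(n, 1)  # clamp num_labels_in_seq to >= 1
        for loss, n in zip(per_doc_losses, labels_per_document)
    ]
```

With two SDP ranks both holding the same packed sequence, only rank 0 reports a nonzero count, so the SUM reduction yields the true number of documents.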
When `micro_batch_splits > 1`, `_get_model_input` is called once per split on the same rank. Documents that span a split boundary appear in both splits' `cropped_lengths`, so both splits would count them without a guard; since the runner sums `num_docs` across all splits in `context.batch`, boundary documents would be counted multiple times.

Fix: after the loop in `get_model_inputs`, set `num_docs = None` on all splits except the first. The first split already holds the correct count (guarded by `sequence_data_rank == 0` for SDP); subsequent splits get `None`, which the runner treats as 0 via `batch_kwargs[field] or 0`. With `micro_batch_splits = 1` (the default), `model_inputs[1:]` is empty, so there is no behaviour change.
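The split guard and the runner's summation can be sketched as follows. This is a minimal sketch assuming splits are dict-like with a `num_docs` field; the helper names `clear_duplicate_num_docs` and `runner_sum_num_docs` are illustrative, not the actual function names in `get_model_inputs` or the runner.

```python
def clear_duplicate_num_docs(model_inputs):
    """Keep num_docs only on the first split.

    A document straddling a split boundary appears in every split it
    touches, so later splits would re-count it; clearing them leaves the
    first split's (correct) count as the sole contribution.
    """
    for split in model_inputs[1:]:  # empty when micro_batch_splits == 1
        split["num_docs"] = None
    return model_inputs


def runner_sum_num_docs(batch):
    """Sum the denominator over splits, treating None as 0
    (mirrors `batch_kwargs[field] or 0` in the runner)."""
    return sum((split["num_docs"] or 0) for split in batch)
```

With a single split, the slice `model_inputs[1:]` is empty and nothing is modified, matching the no-behaviour-change claim for the default configuration.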