feat: skip logprob and reference logprob computation under certain conditions #1891
guyueh1 wants to merge 23 commits into
Conversation
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
📝 Walkthrough
This PR introduces configurable flags to optimize GRPO training by enabling optional skipping of logprob computations. Configuration files are updated with
Changes
Sequence Diagram
sequenceDiagram
participant TrainingLoop as GRPO Training Loop
participant LogprobCalc as Logprob Calculation
participant LossFunc as Loss Function
participant DataPrep as Data Preparation
TrainingLoop->>DataPrep: Prepare training data
alt skip_reference_policy_logprobs_calculation == false
DataPrep->>LogprobCalc: Compute reference_policy_logprobs
LogprobCalc-->>DataPrep: reference_policy_logprobs
else
DataPrep-->>DataPrep: Skip reference logprob computation
end
alt skip_prev_logprobs == false
DataPrep->>LogprobCalc: Compute prev_logprobs
LogprobCalc-->>DataPrep: prev_logprobs
else
DataPrep-->>DataPrep: Set prev_logprobs to zeros
end
DataPrep->>LossFunc: Pass data with logprobs
alt force_on_policy_ratio == true
LossFunc->>LossFunc: Compute curr_logprobs on-policy
LossFunc->>LossFunc: Override prev_logprobs with curr_logprobs
else
LossFunc->>LossFunc: Use provided prev_logprobs
end
LossFunc-->>TrainingLoop: Compute loss
TrainingLoop->>TrainingLoop: Backpropagate
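The data-preparation branches in the diagram can be sketched in plain Python as follows. This is a minimal illustration, not the code from `grpo.py`: the helper `prepare_logprobs` and its callback parameters are hypothetical, and lists of floats stand in for tensors.

```python
from typing import Callable, Dict, List


def prepare_logprobs(
    num_tokens: int,
    skip_prev_logprobs: bool,
    skip_reference_policy_logprobs_calculation: bool,
    compute_prev: Callable[[], List[float]],
    compute_ref: Callable[[], List[float]],
) -> Dict[str, List[float]]:
    """Hypothetical helper mirroring the two data-preparation branches above."""
    data: Dict[str, List[float]] = {}
    if not skip_reference_policy_logprobs_calculation:
        # Reference logprobs are only computed when the flag is false.
        data["reference_policy_logprobs"] = compute_ref()
    if skip_prev_logprobs:
        # Placeholder zeros; the loss function is expected to override them.
        data["prev_logprobs"] = [0.0] * num_tokens
    else:
        data["prev_logprobs"] = compute_prev()
    return data
```

When both flags are true, neither callback is invoked, which is where the saved inference pass comes from.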
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~50 minutes
Possibly related PRs
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 inconclusive)
✅ Passed checks (3 passed)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/grpo.py`:
- Around line 1579-1597: The code currently zero-fills
train_data["prev_logprobs"] when force_on_policy_ratio is True which leads to
misleading logs and plots (token_mult_prob_error); change the handling so that
when master_config["loss_fn"].get("force_on_policy_ratio", False) is True you
either (A) avoid emitting prev_logprobs into log_data and skip plotting
token_mult_prob_error, or (B) back-fill train_data["prev_logprobs"] with the
actual on-policy probabilities returned by the training step (e.g., use
train_results["curr_logprobs"] / .detach() if present) before any
logging/visualization; update the code paths around prev_logprobs,
train_results, log_data["prev_logprobs"], and token_mult_prob_error to implement
one of these behaviors.
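The two options in the prompt above can be combined into one small selection function. The sketch below is an assumption-laden illustration: `prev_logprobs_for_logging` is a hypothetical name, `train_results["curr_logprobs"]` is only available if the training step actually returns it, and lists of floats stand in for detached tensors.

```python
from typing import Dict, List, Optional


def prev_logprobs_for_logging(
    train_data: Dict[str, List[float]],
    train_results: Dict[str, List[float]],
    force_on_policy_ratio: bool,
) -> Optional[List[float]]:
    """Return the values to log as prev_logprobs, or None to skip logging."""
    if not force_on_policy_ratio:
        # Normal path: prev_logprobs were really computed.
        return train_data["prev_logprobs"]
    curr = train_results.get("curr_logprobs")
    if curr is not None:
        # Option B: back-fill with the on-policy logprobs from the train step.
        return list(curr)
    # Option A: nothing meaningful to log, so the caller should skip
    # prev_logprobs and the token_mult_prob_error plot.
    return None
```

A caller would skip both `log_data["prev_logprobs"]` and the `token_mult_prob_error` plot whenever the function returns `None`.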
🧹 Nitpick comments (3)
nemo_rl/algorithms/grpo.py (3)
337-337: Explain why `NRL_IGNORE_TP_ACCURACY_CHECK` is needed when `force_on_policy_ratio` is enabled.
Setting a global environment variable as a side effect of a config flag is opaque. Consider adding a comment explaining why the TP accuracy check must be disabled here, so future maintainers understand the coupling.
1602-1621: Minor: `logprob_data` is allocated even when both logprob computations are skipped.
When both `skip_prev_logprobs` and `skip_reference_policy_logprobs` are `True`, the `logprob_data` dict on lines 1602–1608 is created but never read. This is lightweight (just references, no tensor copies), so it's not a real concern, just a nit for clarity.
2601-2633: Duplicated skip-logic between `grpo_train` and `async_grpo_train`.
Lines 2601–2633 are nearly identical to lines 1579–1621 in `grpo_train`. Consider extracting the flag resolution and conditional `prepare_for_lp_inference` / logprob gating into a shared helper to keep both paths in sync and reduce maintenance burden.
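One possible shape for the shared helper suggested in the last nitpick is sketched below. The names `LogprobSkipFlags` and `resolve_logprob_skip_flags` are hypothetical, and the skip conditions (prev logprobs skipped when `force_on_policy_ratio` is true, reference logprobs skipped when `reference_policy_kl_penalty` is 0) are taken from the review discussion in this thread rather than from the final code.

```python
from dataclasses import dataclass
from typing import Any, Mapping


@dataclass(frozen=True)
class LogprobSkipFlags:
    """Hypothetical container for the resolved skip decisions."""

    skip_prev_logprobs: bool
    skip_reference_policy_logprobs: bool

    @property
    def needs_logprob_inference(self) -> bool:
        # Inference preparation is only needed if at least one pass will run.
        return not (self.skip_prev_logprobs and self.skip_reference_policy_logprobs)


def resolve_logprob_skip_flags(loss_cfg: Mapping[str, Any]) -> LogprobSkipFlags:
    """Resolve both flags once so sync and async paths cannot drift apart."""
    return LogprobSkipFlags(
        skip_prev_logprobs=bool(loss_cfg["force_on_policy_ratio"]),
        skip_reference_policy_logprobs=loss_cfg["reference_policy_kl_penalty"] == 0,
    )
```

Both `grpo_train` and `async_grpo_train` could call the resolver once per step and gate the "Computing logprobs..." print, `prepare_for_lp_inference`, and the `logprob_data` construction on `needs_logprob_inference`.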
I have two questions related to logprob skipping.
I am still reconciling about 2 but I do agree with 3, when certain conditions are met, we should skip logprob even if user doesn't explicitly specify
Review Summary
The logic for skipping logprobs/reference logprobs looks correct: the conditions are consistent between
Test coverage: The new
See inline comments for minor issues.
Signed-off-by: Guyue Huang <guyueh@login-lyris01.lyris.clusters.nvidia.com>
/ok to test 6abdc51
This reverts commit 6abdc51. Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
0c818bf to 71f0ff6
/ok to test 71f0ff6
/claude review
print("▶ Computing logprobs...", flush=True)
with timer.time("policy_and_reference_logprobs"):
nit: When both skip_prev_logprobs and skip_reference_policy_logprobs are true, this still prints "Computing logprobs..." and constructs logprob_data only to immediately delete it. Consider wrapping the entire block (including the print) in the skip check, or adjusting the log message.
i think it's ok to ignore
Review summary
The core logic looks correct — skipping prev_logprobs when force_on_policy_ratio=True and skipping reference logprobs when reference_policy_kl_penalty==0 is sound, and the loss function properly uses curr_logprobs.detach() as the substitute. The guard against seq_logprob_error_threshold when prev_logprobs are skipped is a good catch. Tests cover the loss function layer well.
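Why substituting the current logprobs is sound can be seen from the importance ratio itself. The sketch below is a simplified pure-Python illustration with a hypothetical function name; plain floats stand in for tensors, so the `.detach()` call (which only affects gradient flow) has no counterpart here.

```python
import math
from typing import List


def importance_ratios(
    curr_logprobs: List[float],
    prev_logprobs: List[float],
    force_on_policy_ratio: bool,
) -> List[float]:
    """Per-token ratio exp(curr - prev), as used in clipped policy-gradient losses."""
    if force_on_policy_ratio:
        # Stand-in for curr_logprobs.detach(): same values, so the ratio is 1.
        prev_logprobs = curr_logprobs
    return [math.exp(c - p) for c, p in zip(curr_logprobs, prev_logprobs)]
```

With the flag on, every ratio is exactly 1.0 regardless of what `prev_logprobs` contains, so zero-filled placeholders never reach the loss. With the flag off, zero placeholders would yield ratios of `exp(curr)`, which is why skipped `prev_logprobs` must not be consumed anywhere else.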
Two minor items flagged inline:
- `.get("force_on_policy_ratio", False)` in grpo.py (both sync and async paths) uses a hidden boolean default, which violates config conventions. Use `.get(key)` without a default, or direct access since the exemplar YAML always provides it.
- When both logprob computations are skipped, the code still prints "Computing logprobs..." and constructs `logprob_data` unnecessarily.
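The first item is easiest to see with a misspelled key. The two reader functions below are hypothetical and use a plain dict in place of the real config object; they only contrast the two access patterns.

```python
from typing import Any, Mapping


def read_flag_with_hidden_default(loss_cfg: Mapping[str, Any]) -> bool:
    # Flagged pattern: a missing or misspelled key silently becomes False.
    return loss_cfg.get("force_on_policy_ratio", False)


def read_flag_direct(loss_cfg: Mapping[str, Any]) -> bool:
    # Direct access: a missing key raises KeyError instead of being masked.
    return loss_cfg["force_on_policy_ratio"]
```

With a config that misspells the key, the first reader quietly disables the feature while the second fails immediately, which is the behavior the convention is meant to guarantee.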
One other note: the removal of skip_reference_policy_logprobs_calculation makes the example in skills/config-conventions/SKILL.md (line 57) stale — worth updating to avoid confusion.
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
… into fuse_logprob_train
What does this PR do ?
Skip logprob and reference logprob computation under certain conditions:
- `loss_fn.skip_reference_policy_logprobs_calculation=true`: skip reference logprob. The requirement is `loss_fn.reference_kl_penalty == 0`, which will be checked whenever `skip_reference_policy_logprobs_calculation` is true.
- `loss_fn.force_on_policy_ratio=true`: skip logprob computation. The requirement is rollout batch size == train global batch size, which will be checked whenever `force_on_policy_ratio` is true.
Issues
List issues that this PR closes (syntax):
Usage
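The two requirements stated in the description above could be enforced with a check along these lines. This is a hedged sketch, not the PR's implementation: `validate_skip_flags` is a hypothetical name, the config is modeled as a plain dict using the key names from the description, and the two batch sizes are passed in explicitly rather than derived from the real config.

```python
from typing import Any, Mapping


def validate_skip_flags(
    loss_cfg: Mapping[str, Any],
    rollout_batch_size: int,
    train_global_batch_size: int,
) -> None:
    """Raise ValueError if a skip flag is enabled without its precondition."""
    if (
        loss_cfg["skip_reference_policy_logprobs_calculation"]
        and loss_cfg["reference_kl_penalty"] != 0
    ):
        raise ValueError(
            "skip_reference_policy_logprobs_calculation requires reference_kl_penalty == 0"
        )
    if loss_cfg["force_on_policy_ratio"] and rollout_batch_size != train_global_batch_size:
        raise ValueError(
            "force_on_policy_ratio requires rollout batch size == train global batch size"
        )
```

Running the check once at setup time would surface a misconfiguration before any rollout is generated.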
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Release Notes