fix: fix skip_reference_policy_logprobs_calculation and skip_prev_logprobs#2443
Merged
yuki-97 merged 13 commits intoMay 12, 2026
Merged
Conversation
When force_on_policy_ratio=True, the importance sampling ratio is forced to 1.0, so prev_logprobs are unnecessary. Skip the expensive prepare_for_lp_inference() and get_logprobs() calls in both sync and async GRPO paths. In the loss function, use curr_logprobs.detach() as prev_logprobs instead of loading placeholder zeros from data. Also guards against incompatible use of seq_logprob_error_threshold with force_on_policy_ratio (the threshold requires real prev_logprobs). Part of NVIDIA-NeMo#1906 Co-Authored-By: Jiaqi Zeng <jiaqiz@nvidia.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com>
Run ruff format 0.9.9 (matches .pre-commit-config.yaml) on the files touched by the previous commit so the rebased branch passes the format hook on current main. Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com>
…t into setup The assert that `loss_fn.reference_policy_kl_penalty == 0` whenever `grpo.skip_reference_policy_logprobs_calculation=True` was previously checked deep inside `grpo_train`, after policy/cluster construction. Move it into `setup()` next to the existing `force_on_policy_ratio` validation so misconfigured runs fail fast, before any expensive initialization. Also attach an explanatory message to the assert so the failure mode is self-describing. Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com>
Fixes NVIDIA-NeMo#1968: Setting skip_reference_policy_logprobs_calculation=true with reference_policy_kl_penalty=0 crashed training in three ways: Bug 1: use_reference_model() context manager crash when reference model was never initialized (AttributeError on reference_state_dict). Fix: Added early-return guard in use_reference_model() for all three worker types (megatron, dtensor v1, dtensor v2) - yields without swapping when reference model is None/missing. Bug 2: Async GRPO path unconditionally called get_reference_policy_logprobs() without checking the skip flag. Fix: Added the same skip guard as the sync path, setting zeros_like for reference_policy_logprobs when skipping. Bug 3: Missing reference_policy_logprobs key in train_data causing shape mismatches downstream in loss computation. Fix: Both sync and async paths now explicitly set train_data['reference_policy_logprobs'] = zeros_like(prev_logprobs) when skipping. Also added a _has_reference_model() helper and zeros fallback in base_policy_worker.get_reference_policy_logprobs() as defense-in-depth. Signed-off-by: Linglin Jing <linglinj@nvidia.com>
Cherry-picked PR NVIDIA-NeMo#2174 didn't run ruff format on the worker files it touched. This commit applies the format pass so subsequent diffs stay clean. No functional changes. Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com>
When reference_policy_kl_penalty is 0, the reference model is unused during GRPO training. Pass init_reference_model=False to avoid allocating memory for the reference model weights. Closes NVIDIA-NeMo#1957 Co-Authored-By: Jiaqi Zeng <jiaqiz@nvidia.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com>
Addresses review on PR NVIDIA-NeMo#2178 (yuki-97, terrykong): - yuki-97: "shall we set skip_reference_policy_logprobs_calculation to True in this situation? otherwise I guess we will get error when calling get_reference_policy_logprobs." - terrykong: lists existing recipes that have reference_policy_kl_penalty=0 without setting the skip flag and would AttributeError after NVIDIA-NeMo#2178. Adds a small auto-derive block right after PR NVIDIA-NeMo#2178's `init_reference_model = ...` line: when the reference model is not loaded, set `skip_reference_policy_logprobs_calculation=True` so the sync/async training loops do not call `get_reference_policy_logprobs()` on a non-existent reference model (issue NVIDIA-NeMo#1968 Bug 1). The existing setup() assert (skip=True => kl_penalty must be 0) is unchanged; together with this auto-derive, the bidirectional invariant kl_penalty == 0 <=> ref model not loaded <=> skip ref logprobs holds for any user-provided combination of the two flags. Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com>
Adds a functional smoke test for the path enabled by PR NVIDIA-NeMo#2178 plus the auto-skip safety net added in response to yuki-97's review: > and I think it's better to add a functional test (or modify one > exist functional test) for reference_policy_kl_penalty == 0. The test runs a 2-step GRPO with reference_policy_kl_penalty=0 and without explicitly setting skip_reference_policy_logprobs_calculation, then asserts: * the auto-skip log line fires (proves setup() override worked); * the existing "Reference policy logprob calculation will be skipped" confirmation log fires; * standard probs_ratio + gen_kl_error metric envelopes pass (PR NVIDIA-NeMo#2174 zeros placeholder keeps loss math valid when KL penalty is zero). Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com>
…ratio Adds two parametrized unit tests in tests/unit/algorithms/test_grpo.py that cover both grpo_train and async_grpo_train: - test_grpo_train_skips_reference_policy_logprobs_when_configured: guards issue NVIDIA-NeMo#1968 / PRs NVIDIA-NeMo#2174, NVIDIA-NeMo#2178 by asserting that policy.get_reference_policy_logprobs is never called when grpo.skip_reference_policy_logprobs_calculation=True. - test_grpo_train_skips_prev_logprobs_when_force_on_policy_ratio: guards PR NVIDIA-NeMo#2177 by asserting that policy.get_logprobs is never called when loss_fn.force_on_policy_ratio=True. Both tests reuse the existing mock_grpo_components fixture and the mock_async_grpo_infrastructure helper so they require no GPU / Ray cluster and run in CI in milliseconds (modulo cold-start import cost). Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com>
Per review on the consolidation PR: the early-return guards added in
nemo_rl/models/policy/workers/{base,dtensor,dtensor_v2,megatron}_policy_worker.py
are redundant.
The grpo.py setup() now auto-enables grpo.skip_reference_policy_logprobs_calculation
when loss_fn.reference_policy_kl_penalty == 0, and the sync/async training
loops both gate the policy.get_reference_policy_logprobs() call on that flag.
This means the worker layer is never asked for reference logprobs when the
reference model is not loaded, so the worker-level guards never fire.
Also removes tests/functional/grpo_kl_zero.sh -- the four parametrized unit
tests in tests/unit/algorithms/test_grpo.py
(test_grpo_train_skips_reference_policy_logprobs_when_configured and
test_grpo_train_skips_prev_logprobs_when_force_on_policy_ratio, each across
grpo_train + async_grpo_train) cover the same skip-paths without needing
GPUs or a real cluster.
Signed-off-by: Linglin Jing <linglinj@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Linglin Jing <linglinj@nvidia.com>
a10b32e to
d5c1532
Compare
Contributor
Author
|
/ok to test d5c1532 |
The auto-skip logic added in setup() (auto-enabling skip_reference_policy_logprobs_calculation when KL=0) reads master_config["loss_fn"]["reference_policy_kl_penalty"], so the mock config in test_setup_sglang_sets_model_path_and_parallel_flag must include this key. Fixes KeyError seen in L0_Unit_Tests_Other CI. Signed-off-by: Linglin Jing <linglinj@nvidia.com> Co-authored-by: Cursor <cursoragent@cursor.com>
Contributor
Author
|
/ok to test 92aa6ba |
The two regression tests added in this PR drive `grpo_train` / `async_grpo_train` through code paths that call `torch.zeros_like(prev_logprobs)` (PRs NVIDIA-NeMo#2174 / NVIDIA-NeMo#2178) and `torch.zeros_like(generation_logprobs)` (PR NVIDIA-NeMo#2177). Under the bare `mock_grpo_components` fixture those inputs are `MagicMock` objects, so CI failed with `TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not MagicMock` at `nemo_rl/algorithms/grpo.py:1801`. Add a `_patched_logprob_phase` context manager that swaps in real tensors for `policy.get_logprobs`, `policy.get_reference_policy_logprobs`, and `batched_message_log_to_flat_message`, and use it in both the sync and async branches of the two new tests. Signed-off-by: Linglin Jing <linglinj@nvidia.com> Co-authored-by: Cursor <cursoragent@cursor.com>
Contributor
Author
|
/ok to test c447a0d |
yuki-97
reviewed
May 11, 2026
Contributor
yuki-97
left a comment
There was a problem hiding this comment.
lgtm, thanks @jinglinglingling . one minor comment.
@yfw could you help to take a review as well?
Co-authored-by: Yuki Huang <yukih@nvidia.com> Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
yfw
approved these changes
May 11, 2026
Contributor
|
/ok to test 2c57451 |
yuki-97
approved these changes
May 12, 2026
zswerth
pushed a commit
to zswerth/RL
that referenced
this pull request
May 12, 2026
…probs (NVIDIA-NeMo#2443) Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com> Co-authored-by: Yi-Fu Wu <yifu.wu@gmail.com> Co-authored-by: Jiaqi Zeng <jiaqiz@nvidia.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Nemo Assist <nemo-assist@nvidia.com> Co-authored-by: Yuki Huang <yukih@nvidia.com> Signed-off-by: zswerth <zwertheimer@nvidia.com>
youngeunkwon0405
pushed a commit
that referenced
this pull request
May 18, 2026
…probs (#2443) Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com> Co-authored-by: Yi-Fu Wu <yifu.wu@gmail.com> Co-authored-by: Jiaqi Zeng <jiaqiz@nvidia.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Nemo Assist <nemo-assist@nvidia.com> Co-authored-by: Yuki Huang <yukih@nvidia.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Related issues
Summary
Consolidates three open PRs onto current
mainand addresses reviewfeedback so they can ship together cleanly:
prev_logprobscomputation whenforce_on_policy_ratio=Truepolicy.get_reference_policy_logprobs()when the reference model was never loaded (issue skip_reference_policy_logprobs_calculation=true crashes training with RuntimeError / NameError #1968 Bug 1)init_reference_model = (kl_penalty > 0)to skip loading the reference model when KL penalty is zeroChanges on top of those three PRs
Move the
skip_reference_policy_logprobs_calculationassert intosetup()so misconfiguration fails before any GPU work (per @yfw on fix: skip prev_logprobs computation when force_on_policy_ratio is true #2177).Auto-enable
grpo.skip_reference_policy_logprobs_calculation=Truewhenloss_fn.reference_policy_kl_penalty == 0, so existing recipes that havekl_penalty=0without explicitly setting the skip flag (e.g.examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml) stop crashing insideuse_reference_model()(per @yuki-97 / @terrykong on fix: skip loading reference model when KL penalty is zero #2178).Add two parametrized unit tests in
tests/unit/algorithms/test_grpo.pycovering bothgrpo_trainandasync_grpo_train:test_grpo_train_skips_reference_policy_logprobs_when_configuredguards skip_reference_policy_logprobs_calculation=true crashes training with RuntimeError / NameError #1968 / fix: skip_reference_policy_logprobs_calculation=true crashes training #2174 / fix: skip loading reference model when KL penalty is zero #2178.test_grpo_train_skips_prev_logprobs_when_force_on_policy_ratioguards fix: skip prev_logprobs computation when force_on_policy_ratio is true #2177.Drop the worker-level guards from fix: skip_reference_policy_logprobs_calculation=true crashes training #2174 — the grpo.py-level skip already prevents the bad call paths, so the worker-layer fallbacks are dead code.