fix: skip loading reference model when KL penalty is zero#2178
Conversation
When reference_policy_kl_penalty is 0, the reference model is unused during GRPO training. Pass init_reference_model=False to avoid allocating memory for the reference model weights. Closes #1957 Co-Authored-By: Jiaqi Zeng <jiaqiz@nvidia.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
| policy_config["megatron_cfg"]["train_iters"] = total_train_iters | ||
|
|
||
| # Define initialization functions that will be used in all paths | ||
| init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0 |
There was a problem hiding this comment.
shall we set skip_reference_policy_logprobs_calculation to True in this situation? otherwise I guess we will get error when calling get_reference_policy_logprobs.
and I think it's better to add a functional test (or modify one exist functional test) for reference_policy_kl_penalty == 0.
| policy_config["megatron_cfg"]["train_iters"] = total_train_iters | ||
|
|
||
| # Define initialization functions that will be used in all paths | ||
| init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0 |
There was a problem hiding this comment.
BUG: Setting init_reference_model=False here prevents reference model weights from being loaded, but the sync training loop (line 1754) still calls policy.get_reference_policy_logprobs() unless grpo.skip_reference_policy_logprobs_calculation is explicitly True.
When reference_policy_kl_penalty=0 and the skip flag is unset, use_reference_model() accesses self.reference_model_state_dict which was never initialized → AttributeError.
Multiple existing configs are affected:
examples/nemo_gym/grpo_nanov3.yamlexamples/configs/recipes/llm/dapo-qwen2.5-7b.yamlexamples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yamlexamples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml
All have reference_policy_kl_penalty: 0 without setting skip_reference_policy_logprobs_calculation: true.
Suggested fix — auto-derive the skip flag:
| init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0 | |
| init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0 | |
| # Auto-skip reference logprob calculation when reference model is not loaded | |
| if not init_reference_model: | |
| master_config["grpo"]["skip_reference_policy_logprobs_calculation"] = True | |
Bug: async GRPO path missing reference logprob skip guardThe async GRPO path at This needs the same guard as the sync path (line 1754). Re: @yuki-97's commentGreat catch — both points are valid:
Generated by Claude Code |
Addresses review on PR NVIDIA-NeMo#2178 (yuki-97, terrykong): - yuki-97: "shall we set skip_reference_policy_logprobs_calculation to True in this situation? otherwise I guess we will get error when calling get_reference_policy_logprobs." - terrykong: lists existing recipes that have reference_policy_kl_penalty=0 without setting the skip flag and would AttributeError after NVIDIA-NeMo#2178. Adds a small auto-derive block right after PR NVIDIA-NeMo#2178's `init_reference_model = ...` line: when the reference model is not loaded, set `skip_reference_policy_logprobs_calculation=True` so the sync/async training loops do not call `get_reference_policy_logprobs()` on a non-existent reference model (issue NVIDIA-NeMo#1968 Bug 1). The existing setup() assert (skip=True => kl_penalty must be 0) is unchanged; together with this auto-derive, the bidirectional invariant kl_penalty == 0 <=> ref model not loaded <=> skip ref logprobs holds for any user-provided combination of the two flags. Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com>
Adds a functional smoke test for the path enabled by PR NVIDIA-NeMo#2178 plus the auto-skip safety net added in response to yuki-97's review: > and I think it's better to add a functional test (or modify one > exist functional test) for reference_policy_kl_penalty == 0. The test runs a 2-step GRPO with reference_policy_kl_penalty=0 and without explicitly setting skip_reference_policy_logprobs_calculation, then asserts: * the auto-skip log line fires (proves setup() override worked); * the existing "Reference policy logprob calculation will be skipped" confirmation log fires; * standard probs_ratio + gen_kl_error metric envelopes pass (PR NVIDIA-NeMo#2174 zeros placeholder keeps loss math valid when KL penalty is zero). Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com>
…ratio Adds two parametrized unit tests in tests/unit/algorithms/test_grpo.py that cover both grpo_train and async_grpo_train: - test_grpo_train_skips_reference_policy_logprobs_when_configured: guards issue NVIDIA-NeMo#1968 / PRs NVIDIA-NeMo#2174, NVIDIA-NeMo#2178 by asserting that policy.get_reference_policy_logprobs is never called when grpo.skip_reference_policy_logprobs_calculation=True. - test_grpo_train_skips_prev_logprobs_when_force_on_policy_ratio: guards PR NVIDIA-NeMo#2177 by asserting that policy.get_logprobs is never called when loss_fn.force_on_policy_ratio=True. Both tests reuse the existing mock_grpo_components fixture and the mock_async_grpo_infrastructure helper so they require no GPU / Ray cluster and run in CI in milliseconds (modulo cold-start import cost). Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Linglin Jing <linglinj@nvidia.com>
The two regression tests added in this PR drive `grpo_train` / `async_grpo_train` through code paths that call `torch.zeros_like(prev_logprobs)` (PRs NVIDIA-NeMo#2174 / NVIDIA-NeMo#2178) and `torch.zeros_like(generation_logprobs)` (PR NVIDIA-NeMo#2177). Under the bare `mock_grpo_components` fixture those inputs are `MagicMock` objects, so CI failed with `TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not MagicMock` at `nemo_rl/algorithms/grpo.py:1801`. Add a `_patched_logprob_phase` context manager that swaps in real tensors for `policy.get_logprobs`, `policy.get_reference_policy_logprobs`, and `batched_message_log_to_flat_message`, and use it in both the sync and async branches of the two new tests. Signed-off-by: Linglin Jing <linglinj@nvidia.com> Co-authored-by: Cursor <cursoragent@cursor.com>
|
closes since included in #2443 |
When reference_policy_kl_penalty is 0, the reference model is unused during GRPO training. Pass init_reference_model=False to avoid allocating memory for the reference model weights.
Closes #1957
What does this PR do ?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use thisBefore your PR is "Ready for review"
Pre checks:
Additional Information