Skip to content

fix: skip loading reference model when KL penalty is zero#2178

Closed
yfw wants to merge 1 commit into
mainfrom
yifu/skip-ref-model-when-no-kl
Closed

fix: skip loading reference model when KL penalty is zero#2178
yfw wants to merge 1 commit into
mainfrom
yifu/skip-ref-model-when-no-kl

Conversation

@yfw
Copy link
Copy Markdown
Contributor

@yfw yfw commented Mar 31, 2026

When reference_policy_kl_penalty is 0, the reference model is unused during GRPO training. Pass init_reference_model=False to avoid allocating memory for the reference model weights.

Closes #1957

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

When reference_policy_kl_penalty is 0, the reference model is unused
during GRPO training. Pass init_reference_model=False to avoid
allocating memory for the reference model weights.

Closes #1957

Co-Authored-By: Jiaqi Zeng <jiaqiz@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
@yfw yfw requested a review from a team as a code owner March 31, 2026 00:01
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Mar 31, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

policy_config["megatron_cfg"]["train_iters"] = total_train_iters

# Define initialization functions that will be used in all paths
init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we set skip_reference_policy_logprobs_calculation to True in this situation? otherwise I guess we will get error when calling get_reference_policy_logprobs.

and I think it's better to add a functional test (or modify one exist functional test) for reference_policy_kl_penalty == 0.

policy_config["megatron_cfg"]["train_iters"] = total_train_iters

# Define initialization functions that will be used in all paths
init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BUG: Setting init_reference_model=False here prevents reference model weights from being loaded, but the sync training loop (line 1754) still calls policy.get_reference_policy_logprobs() unless grpo.skip_reference_policy_logprobs_calculation is explicitly True.

When reference_policy_kl_penalty=0 and the skip flag is unset, use_reference_model() accesses self.reference_model_state_dict which was never initialized → AttributeError.

Multiple existing configs are affected:

  • examples/nemo_gym/grpo_nanov3.yaml
  • examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml
  • examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml
  • examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml

All have reference_policy_kl_penalty: 0 without setting skip_reference_policy_logprobs_calculation: true.

Suggested fix — auto-derive the skip flag:

Suggested change
init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
# Auto-skip reference logprob calculation when reference model is not loaded
if not init_reference_model:
master_config["grpo"]["skip_reference_policy_logprobs_calculation"] = True

@terrykong
Copy link
Copy Markdown
Collaborator

Bug: async GRPO path missing reference logprob skip guard

The async GRPO path at nemo_rl/algorithms/grpo.py:2795 unconditionally calls policy.get_reference_policy_logprobs() — it doesn't check skip_reference_policy_logprobs_calculation. Even if the sync path is fixed, async GRPO with reference_policy_kl_penalty=0 will crash with AttributeError because reference_model_state_dict / reference_state_dict was never initialized.

This needs the same guard as the sync path (line 1754).


Re: @yuki-97's comment

Great catch — both points are valid:

  1. skip_reference_policy_logprobs_calculation: Without this flag set to true, the sync path (line 1754) will still call get_reference_policy_logprobs(), which invokes use_reference_model() and accesses self.reference_model_state_dict that was never initialized → AttributeError. Multiple existing configs with reference_policy_kl_penalty: 0 don't set the skip flag (e.g., grpo_nanov3.yaml, dapo-qwen2.5-7b.yaml, grpo-deepscaler-1.5b-8K.yaml).

  2. Functional test: Agreed — a test for reference_policy_kl_penalty == 0 would catch this. Note the async path needs the same fix.

Generated by Claude Code

jinglinglingling added a commit to jinglinglingling/RL that referenced this pull request May 9, 2026
Addresses review on PR NVIDIA-NeMo#2178 (yuki-97, terrykong):

- yuki-97: "shall we set skip_reference_policy_logprobs_calculation to
  True in this situation? otherwise I guess we will get error when
  calling get_reference_policy_logprobs."
- terrykong: lists existing recipes that have reference_policy_kl_penalty=0
  without setting the skip flag and would AttributeError after NVIDIA-NeMo#2178.

Adds a small auto-derive block right after PR NVIDIA-NeMo#2178's
`init_reference_model = ...` line: when the reference model is not
loaded, set `skip_reference_policy_logprobs_calculation=True` so the
sync/async training loops do not call `get_reference_policy_logprobs()`
on a non-existent reference model (issue NVIDIA-NeMo#1968 Bug 1).

The existing setup() assert (skip=True => kl_penalty must be 0) is
unchanged; together with this auto-derive, the bidirectional invariant
   kl_penalty == 0  <=>  ref model not loaded  <=>  skip ref logprobs
holds for any user-provided combination of the two flags.

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Linglin Jing <linglinj@nvidia.com>
jinglinglingling added a commit to jinglinglingling/RL that referenced this pull request May 9, 2026
Adds a functional smoke test for the path enabled by PR NVIDIA-NeMo#2178 plus the
auto-skip safety net added in response to yuki-97's review:

> and I think it's better to add a functional test (or modify one
>  exist functional test) for reference_policy_kl_penalty == 0.

The test runs a 2-step GRPO with reference_policy_kl_penalty=0 and
without explicitly setting skip_reference_policy_logprobs_calculation,
then asserts:
  * the auto-skip log line fires (proves setup() override worked);
  * the existing "Reference policy logprob calculation will be skipped"
    confirmation log fires;
  * standard probs_ratio + gen_kl_error metric envelopes pass (PR NVIDIA-NeMo#2174
    zeros placeholder keeps loss math valid when KL penalty is zero).

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Linglin Jing <linglinj@nvidia.com>
jinglinglingling added a commit to jinglinglingling/RL that referenced this pull request May 9, 2026
…ratio

Adds two parametrized unit tests in tests/unit/algorithms/test_grpo.py
that cover both grpo_train and async_grpo_train:

- test_grpo_train_skips_reference_policy_logprobs_when_configured:
  guards issue NVIDIA-NeMo#1968 / PRs NVIDIA-NeMo#2174, NVIDIA-NeMo#2178 by asserting that
  policy.get_reference_policy_logprobs is never called when
  grpo.skip_reference_policy_logprobs_calculation=True.

- test_grpo_train_skips_prev_logprobs_when_force_on_policy_ratio:
  guards PR NVIDIA-NeMo#2177 by asserting that policy.get_logprobs is never
  called when loss_fn.force_on_policy_ratio=True.

Both tests reuse the existing mock_grpo_components fixture and the
mock_async_grpo_infrastructure helper so they require no GPU / Ray
cluster and run in CI in milliseconds (modulo cold-start import cost).

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Linglin Jing <linglinj@nvidia.com>
jinglinglingling added a commit to jinglinglingling/RL that referenced this pull request May 9, 2026
The two regression tests added in this PR drive `grpo_train` /
`async_grpo_train` through code paths that call
`torch.zeros_like(prev_logprobs)` (PRs NVIDIA-NeMo#2174 / NVIDIA-NeMo#2178) and
`torch.zeros_like(generation_logprobs)` (PR NVIDIA-NeMo#2177). Under the bare
`mock_grpo_components` fixture those inputs are `MagicMock` objects, so
CI failed with `TypeError: zeros_like(): argument 'input' (position 1)
must be Tensor, not MagicMock` at `nemo_rl/algorithms/grpo.py:1801`.

Add a `_patched_logprob_phase` context manager that swaps in real
tensors for `policy.get_logprobs`, `policy.get_reference_policy_logprobs`,
and `batched_message_log_to_flat_message`, and use it in both the sync
and async branches of the two new tests.

Signed-off-by: Linglin Jing <linglinj@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@yuki-97
Copy link
Copy Markdown
Contributor

yuki-97 commented May 12, 2026

closes since included in #2443

@yuki-97 yuki-97 closed this May 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[super-pr] skip loading ref model when kl>0

3 participants