Closed
3 changes: 3 additions & 0 deletions nemo_rl/algorithms/grpo.py
@@ -554,6 +554,8 @@ def init_train_dataloader(dataset, suffix: str = ""):
policy_config["megatron_cfg"]["train_iters"] = total_train_iters

# Define initialization functions that will be used in all paths
init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
Contributor

Should we set skip_reference_policy_logprobs_calculation to True in this situation? Otherwise, I think we will get an error when calling get_reference_policy_logprobs.

I also think it would be better to add a functional test (or modify an existing one) for reference_policy_kl_penalty == 0.

Collaborator

BUG: Setting init_reference_model=False here prevents reference model weights from being loaded, but the sync training loop (line 1754) still calls policy.get_reference_policy_logprobs() unless grpo.skip_reference_policy_logprobs_calculation is explicitly True.

When reference_policy_kl_penalty=0 and the skip flag is unset, use_reference_model() accesses self.reference_model_state_dict which was never initialized → AttributeError.
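The failure mode described above can be illustrated with a small stand-in. This is only a sketch: `ToyPolicy` and `run_step` are hypothetical names invented for illustration, not NeMo RL classes, and the attribute layout is assumed from the description in this comment.

```python
class ToyPolicy:
    """Hypothetical stand-in for a policy worker; not the NeMo RL class."""

    def __init__(self, init_reference_model: bool):
        if init_reference_model:
            # The attribute only exists when a reference model was requested.
            self.reference_model_state_dict = {"w": 1.0}

    def get_reference_policy_logprobs(self) -> float:
        # Reads the attribute unconditionally, as in the reported bug.
        return self.reference_model_state_dict["w"]


def run_step(policy: ToyPolicy, skip_reference_policy_logprobs_calculation: bool):
    """Toy version of the training-loop branch guarded by the skip flag."""
    if skip_reference_policy_logprobs_calculation:
        return None
    return policy.get_reference_policy_logprobs()


def raises_attribute_error(policy: ToyPolicy, skip: bool) -> bool:
    """Report whether a step fails with AttributeError."""
    try:
        run_step(policy, skip)
    except AttributeError:
        return True
    return False
```

With `init_reference_model=False` and the skip flag unset, the toy step fails with `AttributeError`; setting the flag avoids the call entirely.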

Multiple existing configs are affected:

  • examples/nemo_gym/grpo_nanov3.yaml
  • examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml
  • examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml
  • examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml

All have reference_policy_kl_penalty: 0 without setting skip_reference_policy_logprobs_calculation: true.
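A check along the following lines could flag configs with this combination. It is a minimal sketch operating on already-parsed dictionaries: `lacks_skip_flag` is a hypothetical helper, the sample configs are made up rather than copied from the YAML files listed above, and treating a missing penalty or flag as 0 / False is an assumption.

```python
from typing import Any, Mapping


def lacks_skip_flag(config: Mapping[str, Any]) -> bool:
    """Return True when the penalty is not positive and the skip flag is not set."""
    # Assumption: a missing penalty is treated as 0, a missing flag as False.
    penalty = config.get("loss_fn", {}).get("reference_policy_kl_penalty", 0)
    skip = config.get("grpo", {}).get(
        "skip_reference_policy_logprobs_calculation", False
    )
    # The PR loads the reference model only when the penalty is strictly positive.
    return not (penalty > 0) and not skip


# Hypothetical in-memory configs used only to exercise the helper.
configs = {
    "penalty_zero_no_flag": {
        "loss_fn": {"reference_policy_kl_penalty": 0},
        "grpo": {},
    },
    "penalty_zero_with_flag": {
        "loss_fn": {"reference_policy_kl_penalty": 0},
        "grpo": {"skip_reference_policy_logprobs_calculation": True},
    },
    "penalty_positive": {
        "loss_fn": {"reference_policy_kl_penalty": 0.01},
        "grpo": {},
    },
}

affected = sorted(name for name, cfg in configs.items() if lacks_skip_flag(cfg))
```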

Suggested fix — auto-derive the skip flag:

Suggested change
- init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
+ init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
+ # Auto-skip reference logprob calculation when reference model is not loaded
+ if not init_reference_model:
+     master_config["grpo"]["skip_reference_policy_logprobs_calculation"] = True
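The suggested derivation can be isolated into a small function so that the reference_policy_kl_penalty == 0 case is checkable without launching training. This is a sketch under assumptions: `derive_reference_settings` is a hypothetical helper name, and `master_config` is assumed to be a plain nested dict with the `loss_fn` and `grpo` keys used in the suggestion.

```python
from typing import Any, Dict


def derive_reference_settings(master_config: Dict[str, Any]) -> bool:
    """Decide whether to load the reference model and keep the skip flag consistent.

    Mutates master_config in place, mirroring the suggested change.
    """
    init_reference_model = (
        master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
    )
    # Auto-skip reference logprob calculation when the reference model is not loaded.
    if not init_reference_model:
        master_config["grpo"]["skip_reference_policy_logprobs_calculation"] = True
    return init_reference_model


# Penalty of zero: no reference model, skip flag forced on.
zero_cfg = {"loss_fn": {"reference_policy_kl_penalty": 0}, "grpo": {}}
zero_result = derive_reference_settings(zero_cfg)

# Positive penalty: reference model loaded, grpo section left untouched.
pos_cfg = {"loss_fn": {"reference_policy_kl_penalty": 0.01}, "grpo": {}}
pos_result = derive_reference_settings(pos_cfg)
```

One design consequence worth noting: when the penalty is positive, the function leaves any user-provided skip flag as it is, so only the unsafe combination is overridden.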


def init_policy():
    """Initialize policy training workers."""
    t0 = time.perf_counter()
@@ -565,6 +567,7 @@ def init_policy():
        weights_path=weights_path,
        optimizer_path=optimizer_path,
        init_optimizer=True,
        init_reference_model=init_reference_model,
    )
    return p, time.perf_counter() - t0
