feat: Add linear CE loss fusion for DPO #2139
pengdurice wants to merge 7 commits into NVIDIA-NeMo:main
Conversation
…the loss values being nearly identical between base and exp. Signed-off-by: pengdurice <pengduhit@gmail.com>
📝 Walkthrough
This PR extends DPO training with support for the chunked linear cross-entropy fusion loss optimization. It adds documentation describing the feature, introduces Megatron configuration flags to control activation, and threads the new flags through to the loss computation.

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: 3 passed, 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/unit/models/policy/test_megatron_worker.py (1)
1990-2056: Exercise the reference-policy logprob path in this agreement test.

This test injects `reference_policy_logprobs` and keeps `init_reference_model=False`, so it only compares the actor-side fused loss. A regression in the newly wired reference-model `get_logprobs()` path would still pass here. Please derive the reference logprobs from the policy, or add a small companion assertion that does.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/unit/models/policy/test_megatron_worker.py` around lines 1990-2056: the test injects reference_policy_logprobs while creating Policy with init_reference_model=False, so the reference-model logprob path (Policy.get_logprobs / reference model wiring) is not exercised. Fix by deriving reference_policy_logprobs from the policy under test (call Policy.get_logprobs or the actor/reference logprob method on policy_std/policy_fuse) instead of using torch.randn, or add a small assertion that compares the injected reference_policy_logprobs to policy.get_logprobs(...) output (using the same input_ids/attention_mask/token_mask). Update the test around the reference_policy_logprobs creation and where Policy(policy_std)/Policy(policy_fuse) are used so the generated logprobs come from, or are validated against, the policy's get_logprobs method.

examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml (1)
22-24: MoE-specific settings may be unnecessary for this model.

`Qwen/Qwen2.5-Math-7B` is a dense (non-MoE) model. The MoE-related settings (`freeze_moe_router`, `moe_router_bias_update_rate`, `moe_permute_fusion`) will likely be ignored but add noise to the config. Consider removing them if they were copied from an MoE recipe.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml` around lines 22-24: remove the unnecessary MoE-specific config keys from this dense model recipe (delete freeze_moe_router, moe_router_bias_update_rate, and moe_permute_fusion), since Qwen2.5-Math-7B is not MoE; if any higher-level code relies on their presence, replace them with explicit comments or defaults in the recipe loader rather than leaving these MoE flags in the dense model config.
📒 Files selected for processing (11)

docs/guides/dpo.md
docs/guides/sft.md
examples/configs/dpo.yaml
examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml
nemo_rl/algorithms/dpo.py
nemo_rl/algorithms/loss/loss_functions.py
nemo_rl/models/megatron/train.py
nemo_rl/models/policy/workers/megatron_policy_worker.py
tests/test_suites/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh
tests/test_suites/nightly.txt
tests/unit/models/policy/test_megatron_worker.py
Add the following to your Megatron config in your YAML file:

```yaml
policy:
  megatron_cfg:
    enabled: true
    use_linear_ce_fusion_loss: true
    linear_ce_fusion_chunk_size: 256  # tokens per chunk; smaller = less memory, larger = more throughput
```
Show the backend switch in the YAML snippet.

examples/configs/dpo.yaml keeps `policy.dtensor_cfg.enabled: true` by default, so copying only this block can leave both backends enabled. Please include `policy.dtensor_cfg.enabled: false` here, or call it out explicitly, so the enablement instructions switch to Megatron cleanly.

✏️ Suggested doc fix

```diff
 policy:
+  dtensor_cfg:
+    enabled: false
   megatron_cfg:
     enabled: true
     use_linear_ce_fusion_loss: true
     linear_ce_fusion_chunk_size: 256  # tokens per chunk; smaller = less memory, larger = more throughput
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/guides/dpo.md` around lines 212 - 220, Update the YAML snippet so
enabling Megatron is explicit and disables the other backend: add
policy.dtensor_cfg.enabled: false alongside the policy.megatron_cfg block (or
clearly call out to set policy.dtensor_cfg.enabled to false) so the final
snippet contains policy.megatron_cfg.enabled: true and
policy.dtensor_cfg.enabled: false, ensuring the switch activates Megatron
cleanly without leaving dtensor enabled.
Thanks @pengdurice, LGTM! I have one small question.
Also, you'll need to rebase onto main since there's a conflict.
@terrykong could you also take a review when you have a chance?
```python
if self.use_linear_ce_fusion:
    token_logprobs = output_tensor.to(torch.float32)
    token_logprobs = token_logprobs[:, : original_seq_length - 1]
```
Am I understanding correctly that `use_linear_ce_fusion` here works both when `sequence_packing` is True and when it is False?
Oh I see, thanks for pointing this out.
What do you think about adding an assert after that? That way, if we support sequence packing in DPO later, people will know that sequence packing and loss fusion can't be used together:

```python
if master_config["policy"]["sequence_packing"]["enabled"]:
    assert xxx  # assert not using loss fusion
```
Sounds like a great idea! Added this guardrail; we can remove it once all three of DPO, sequence packing, and linear fusion loss are compatible.
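For illustration, such a guardrail could be sketched as below. This is a hypothetical helper, not the exact code merged in the PR; the config keys mirror the YAML snippet from the docs:

```python
def check_dpo_loss_fusion_compat(master_config: dict) -> None:
    """Guardrail: DPO sequence packing and linear CE loss fusion are
    mutually exclusive until all three features are made compatible."""
    policy_cfg = master_config["policy"]
    packing = policy_cfg.get("sequence_packing", {}).get("enabled", False)
    fusion = policy_cfg.get("megatron_cfg", {}).get("use_linear_ce_fusion_loss", False)
    if packing:
        # Fail fast at config-validation time rather than mid-training.
        assert not fusion, (
            "Sequence packing and linear CE loss fusion cannot be enabled "
            "together for DPO yet; disable one of them."
        )
```

Running the check at startup surfaces the incompatibility before any GPU work begins.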
Force-pushed 7cd7789 to 7e4d725, then 7e4d725 to dccf08e.
@yuki-97, rebased and resolved the conflict, thanks!
…patible Signed-off-by: pengdurice <pengduhit@gmail.com>
What does this PR do?

Support Linear CE Loss Fusion for DPO.

Building on #2036, which added linear CE loss fusion support for SFT, this PR extends the support to the DPO loss.
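For context, the DPO objective consumes only summed per-sequence logprobs from the policy and reference models; the full logits are never needed once per-token logprobs are gathered, which is what makes the fusion safe. A standard-form sketch of the loss (the function name and default beta are illustrative, not this repo's API):

```python
import math

import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_lp, policy_rejected_lp,
             ref_chosen_lp, ref_rejected_lp, beta=0.1):
    """Standard DPO loss: -log sigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r))).

    Each argument is a [batch] tensor of summed sequence logprobs."""
    margins = beta * ((policy_chosen_lp - ref_chosen_lp)
                      - (policy_rejected_lp - ref_rejected_lp))
    return -F.logsigmoid(margins).mean()
```

With all inputs zero, the margin is zero and the loss reduces to log 2, a handy sanity check.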
Optimizations
Chunked Linear Cross-Entropy Fusion Loss
During standard DPO training, the model materializes a full logit tensor of shape `[batch_size, seq_length, vocab_size]` (up to parallelism) for both the policy forward-backward pass and the reference-model logprob computation. This can cause out-of-memory (OOM) errors for long sequences or large vocabularies. The chunked linear cross-entropy fusion loss avoids this by computing log probabilities directly from the hidden states: it chunks the sequence dimension, projects each chunk to logits on the fly, gathers per-token log probabilities, and discards the logits before moving to the next chunk.

Benefits:
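The chunking described above can be sketched as follows. This is a simplified PyTorch illustration of the idea, not the fused Megatron kernel the PR actually wires in, and the tensor names are hypothetical:

```python
import torch
import torch.nn.functional as F


def chunked_token_logprobs(hidden, lm_head_weight, target_ids, chunk_size=256):
    """Compute per-token logprobs without materializing the full
    [batch, seq, vocab] logit tensor: project one sequence chunk at a
    time and keep only the gathered target-token logprobs."""
    batch, seq, _ = hidden.shape
    out = torch.empty(batch, seq, dtype=torch.float32, device=hidden.device)
    for start in range(0, seq, chunk_size):
        end = min(start + chunk_size, seq)
        # Logits for this chunk only: [batch, chunk, vocab].
        logits = hidden[:, start:end] @ lm_head_weight.t()
        logprobs = F.log_softmax(logits.float(), dim=-1)
        out[:, start:end] = logprobs.gather(
            -1, target_ids[:, start:end].unsqueeze(-1)
        ).squeeze(-1)
        # `logits`/`logprobs` go out of scope here, so peak memory is
        # bounded by [batch, chunk_size, vocab] instead of the full sequence.
    return out
```

The result is numerically identical to the unchunked computation; only the peak activation memory changes.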
Issues
NA
Tests
Usage
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Release Notes
Documentation
New Features
Tests