30 changes: 30 additions & 0 deletions docs/guides/dpo.md
@@ -195,6 +195,36 @@ The DPO implementation in NeMo RL supports several key parameters that can be ad

These parameters can be adjusted in the config file or via command-line overrides to optimize training for your specific use case.

## Optimizations

### Chunked Linear Cross-Entropy Fusion Loss

During standard DPO training the model materializes a full logit tensor of shape `[batch_size, seq_length, vocab_size]` for both the policy forward-backward pass and the reference model logprob computation. This can cause out-of-memory (OOM) errors for long sequences or large vocabularies. The **chunked linear cross-entropy fusion loss** avoids this by computing log probabilities directly from the hidden states: it chunks the sequence dimension, projects each chunk to logits on the fly, gathers per-token log probabilities, and discards the logits before moving to the next chunk.
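Conceptually, the chunked computation behaves like the following PyTorch sketch (the names, shapes, and the plain `lm_head` matmul are illustrative assumptions for exposition, not NeMo RL's actual internals):

```python
import torch

def chunked_token_logprobs(hidden, lm_head_weight, targets, chunk_size=256):
    """Per-token log-probs without materializing the full [B, S, V] logit tensor.

    hidden: [B, S, H] final hidden states; lm_head_weight: [V, H];
    targets: [B, S] target token ids (assumed already shifted by one).
    """
    chunks = []
    for start in range(0, hidden.shape[1], chunk_size):
        h = hidden[:, start : start + chunk_size]    # [B, C, H]
        logits = h @ lm_head_weight.T                # [B, C, V], alive only for this iteration
        logp = torch.log_softmax(logits.float(), dim=-1)
        tgt = targets[:, start : start + chunk_size]
        # Keep only the log-prob of each target token; the logits can now be freed.
        chunks.append(logp.gather(-1, tgt.unsqueeze(-1)).squeeze(-1))
    return torch.cat(chunks, dim=1)                  # [B, S]
```

Only one `[B, C, V]` chunk of logits exists at a time, so peak activation memory scales with `chunk_size` rather than with the full sequence length.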

**Benefits:**

- Extends the maximum trainable sequence length significantly by eliminating the large logit tensor from GPU memory.
- Applies to both the training forward-backward pass and the reference model logprob computation.
- Produces numerically equivalent loss values to the standard path.

**How to enable:**

Add the following to your Megatron config in your YAML file:

```yaml
policy:
  dtensor_cfg:
    enabled: false  # disable the DTensor backend so the Megatron backend takes effect
  megatron_cfg:
    enabled: true
    use_linear_ce_fusion_loss: true
    linear_ce_fusion_chunk_size: 256  # tokens per chunk; smaller = less memory, larger = more throughput
```
Comment on lines +212 to +220
Contributor:

⚠️ Potential issue | 🟡 Minor

Show the backend switch in the YAML snippet.

examples/configs/dpo.yaml keeps policy.dtensor_cfg.enabled: true by default, so copying only this block can leave both backends enabled. Please include policy.dtensor_cfg.enabled: false here, or call it out explicitly, so the enablement instructions switch to Megatron cleanly.

✏️ Suggested doc fix
 policy:
+  dtensor_cfg:
+    enabled: false
   megatron_cfg:
     enabled: true
     use_linear_ce_fusion_loss: true
     linear_ce_fusion_chunk_size: 256  # tokens per chunk; smaller = less memory, larger = more throughput

**Notes:**

- Context parallelism is not supported when linear CE fusion is enabled.
- Sequence packing is not supported with DPO regardless of this setting (see [#719](https://github.com/NVIDIA-NeMo/RL/issues/719)).
- The `linear_ce_fusion_chunk_size` parameter controls the trade-off between memory savings and compute throughput. The default value of 256 is a good starting point.
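
As a back-of-envelope illustration of why chunking helps, compare the full logit tensor against a single 256-token chunk (bf16; sequence length and vocabulary size taken from the Qwen2.5-Math-7B recipe in this PR, micro batch size 1 — the numbers are illustrative only, and actual savings depend on parallelism settings and dtype):

```python
# Rough memory for the materialized logits vs. one chunk (illustrative only).
seq, vocab, chunk, bytes_per_el = 6000, 151936, 256, 2  # bf16 = 2 bytes/element
full_gib = seq * vocab * bytes_per_el / 2**30
chunk_gib = chunk * vocab * bytes_per_el / 2**30
print(f"full logits: {full_gib:.2f} GiB, per-chunk: {chunk_gib:.3f} GiB")
```

The chunked path keeps only the smaller per-chunk buffer alive at any moment, which is what extends the trainable sequence length.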

## Evaluate the Trained Model

Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities.
2 changes: 1 addition & 1 deletion docs/guides/sft.md
@@ -347,6 +347,6 @@ policy:

 **Notes:**
 
-- This optimization only applies to SFT training with `NLLLoss`. It does not affect other algorithms (GRPO, DPO, etc.).
+- This optimization applies to SFT training with `NLLLoss` and DPO training. See the [DPO guide](dpo.md#chunked-linear-cross-entropy-fusion-loss) for DPO-specific details.
 - Context parallelism is not supported when linear CE fusion is enabled.
 - The `linear_ce_fusion_chunk_size` parameter controls the trade-off between memory savings and compute throughput. The default value of 256 is a good starting point.
2 changes: 2 additions & 0 deletions examples/configs/dpo.yaml
@@ -119,6 +119,8 @@ policy:
   ## ignored since enabled=false, but needed for testing purposes
   megatron_cfg:
     enabled: false
+    use_linear_ce_fusion_loss: false
+    linear_ce_fusion_chunk_size: 256
     empty_unused_memory_level: 1
     activation_checkpointing: false
     tensor_model_parallel_size: 2
@@ -0,0 +1,47 @@
defaults: ../../dpo.yaml
dpo:
  max_num_steps: 10
checkpointing:
  enabled: false
policy:
  model_name: Qwen/Qwen2.5-Math-7B
  tokenizer:
    name: ${policy.model_name}
  train_global_batch_size: 32
  train_micro_batch_size: 1
  max_total_sequence_length: 6000
  dtensor_cfg:
    enabled: false
  megatron_cfg:
    enabled: true
    use_linear_ce_fusion_loss: true
    linear_ce_fusion_chunk_size: 128
    tensor_model_parallel_size: 4
    pipeline_model_parallel_size: 2
    attention_backend: unfused
    freeze_moe_router: true
    moe_router_bias_update_rate: 0.0
    moe_permute_fusion: true
    optimizer:
      lr: 1.0e-06
      min_lr: 1.0e-06
      adam_beta2: 0.999
      use_distributed_optimizer: false
      use_precision_aware_optimizer: false
    scheduler:
      lr_warmup_iters: 10
      lr_warmup_init: 1.0e-11
      lr_decay_iters: 32
  make_sequence_length_divisible_by: 8
data:
  num_workers: 8
logger:
  wandb:
    project: nemo-rl
    name: dpo-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g
  tensorboard:
    log_dir: tb_logs-dpo-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g
  mlflow:
    run_name: dpo-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g
cluster:
  gpus_per_node: 8
19 changes: 17 additions & 2 deletions nemo_rl/algorithms/dpo.py
@@ -135,10 +135,21 @@ def setup(
         "See https://github.com/NVIDIA-NeMo/RL/issues/719"
     )
 
+    policy_config = master_config["policy"]
+    # Add a guardrail for linear CE fusion loss: if sequence packing is enabled for DPO in the future,
+    # we need to validate the fusion path with cu_seqlens-based logprob aggregation first and then remove this guardrail.
+    if policy_config["sequence_packing"]["enabled"]:
+        assert not (
+            policy_config["megatron_cfg"]["enabled"]
+            and policy_config["megatron_cfg"]["use_linear_ce_fusion_loss"]
+        ), (
+            "Linear CE fusion loss is not supported with sequence packing in DPO. "
+            "The fusion path has not been validated with cu_seqlens-based logprob aggregation."
+        )
+
     set_seed(master_config["dpo"]["seed"])
 
     # Extract individual configs for easier access
-    policy_config = master_config["policy"]
     data_config = master_config["data"]
     logger_config = master_config["logger"]
     cluster_config = master_config["cluster"]
@@ -249,7 +260,11 @@ def setup(
     # print the node IP and GPU ID of the policy workers for debugging
     policy.print_node_ip_and_gpu_id()
 
-    loss_fn = DPOLossFn(master_config["dpo"])
+    loss_fn = DPOLossFn(
+        master_config["dpo"],
+        use_linear_ce_fusion=policy_config["megatron_cfg"]["enabled"]
+        and policy_config["megatron_cfg"]["use_linear_ce_fusion_loss"],
+    )
     print(" ✓ Model initialized")
 
     print("\n" + "=" * 60)
5 changes: 3 additions & 2 deletions nemo_rl/algorithms/loss/loss_functions.py
@@ -797,13 +797,14 @@ class DPOLossFn(PreferenceLossFn):
     loss_type = LossType.SEQUENCE_LEVEL
     input_type = LossInputType.LOGPROB
 
-    def __init__(self, cfg: DPOLossConfig):
+    def __init__(self, cfg: DPOLossConfig, use_linear_ce_fusion: bool = False):
         self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"]
         self.preference_loss_weight = cfg["preference_loss_weight"]
         self.sft_loss_weight = cfg["sft_loss_weight"]
         self.preference_average_log_probs = cfg["preference_average_log_probs"]
         self.sft_average_log_probs = cfg["sft_average_log_probs"]
-        self.sft_loss = NLLLossFn()
+        self.use_linear_ce_fusion = use_linear_ce_fusion
+        self.sft_loss = NLLLossFn(use_linear_ce_fusion=use_linear_ce_fusion)
 
     def _dpo_loss(
         self,
16 changes: 12 additions & 4 deletions nemo_rl/models/megatron/train.py
@@ -400,9 +400,11 @@ def __init__(
         self,
         cfg: PolicyConfig,
         sampling_params: Optional[TrainingSamplingParams] = None,
+        use_linear_ce_fusion: bool = False,
     ):
         self.cfg = cfg
         self.sampling_params = sampling_params
+        self.use_linear_ce_fusion = use_linear_ce_fusion
 
     def __call__(
         self,
@@ -427,10 +429,13 @@ def __call__(
         original_seq_length = unpacked_input_ids.shape[1]
 
         def processor_fn_inner(output_tensor):
-            tp_grp = get_tensor_model_parallel_group()
-            tp_rank = get_tensor_model_parallel_rank()
-            logprob_chunk_size = self.cfg.get("logprob_chunk_size", None)
-            if self.cfg["sequence_packing"]["enabled"]:
+            if self.use_linear_ce_fusion:
+                token_logprobs = output_tensor.to(torch.float32)
+                token_logprobs = token_logprobs[:, : original_seq_length - 1]
Comment on lines +432 to +434

Contributor (@yuki-97):
Am I understanding correctly that use_linear_ce_fusion here works regardless of whether sequence_packing is True or False?

Contributor Author:
Hi @yuki-97, thank you for your question! No — currently, DPO and sequence packing are mutually exclusive:

    assert not master_config["policy"]["sequence_packing"]["enabled"], (

So here we first check whether loss fusion is enabled; if so, we take that path, while the elif and else branches follow the original path. Let me know if that makes sense ;-)

Contributor (@yuki-97):
Oh I see, thanks for pointing this out.

What do you think about adding an assert after that? Then, if we later support sequence packing in DPO, people will know that sequence packing and loss fusion can't be used together:

    if master_config["policy"]["sequence_packing"]["enabled"]:
        assert xxx  # assert not using loss fusion

Contributor Author:
Sounds like a great idea! I added this guardrail; we can remove it once all three — DPO, sequence packing, and linear fusion loss — are compatible.
+            elif self.cfg["sequence_packing"]["enabled"]:
+                tp_grp = get_tensor_model_parallel_group()
+                tp_rank = get_tensor_model_parallel_rank()
+                logprob_chunk_size = self.cfg.get("logprob_chunk_size", None)
                 token_logprobs = from_parallel_logits_to_logprobs_packed_sequences(
                     output_tensor,
                     target=input_ids,
@@ -445,6 +450,9 @@ def processor_fn_inner(output_tensor):
                     sampling_params=self.sampling_params,
                 )
             else:
+                tp_grp = get_tensor_model_parallel_group()
+                tp_rank = get_tensor_model_parallel_rank()
+                logprob_chunk_size = self.cfg.get("logprob_chunk_size", None)
                 token_logprobs = from_parallel_logits_to_logprobs(
                     output_tensor,
                     target=unpacked_input_ids,
5 changes: 5 additions & 0 deletions nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -490,9 +490,13 @@ def get_logprobs(
             straggler_timer=self.mcore_state.straggler_timer,
         )
 
+        use_linear_ce_fusion = self.cfg["megatron_cfg"].get(
+            "use_linear_ce_fusion_loss", False
+        )
         logprobs_post_processor = LogprobsPostProcessor(
             cfg=self.cfg,
             sampling_params=self.sampling_params,
+            use_linear_ce_fusion=use_linear_ce_fusion,
         )
 
         list_of_logprobs = megatron_forward_backward(
@@ -506,6 +510,7 @@ def get_logprobs(
             defer_fp32_logits=self.defer_fp32_logits,
             sampling_params=self.sampling_params,
             straggler_timer=self.mcore_state.straggler_timer,
+            use_linear_ce_fusion_loss=use_linear_ce_fusion,
         )
 
         if is_pipeline_last_stage(ignore_virtual=True):
@@ -0,0 +1,43 @@
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source "$SCRIPT_DIR/common.env"

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=10
MAX_STEPS=10
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
NUM_MINUTES=25
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_dpo.py \
    --config $CONFIG_PATH \
    dpo.max_num_steps=$MAX_STEPS \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name=$EXP_NAME \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=true \
    checkpointing.checkpoint_dir=$CKPT_DIR \
    "$@" \
    2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
    # Smoke checks: run completed and loss is finite/reasonable.
    uv run tests/check_metrics.py $JSON_METRICS \
        'data["train/loss"]["10"] > 0.0' \
        'data["train/loss"]["10"] < 20.0'

    # Clean up checkpoint directory after successful run to save space.
    rm -rf "$CKPT_DIR"
fi
1 change: 1 addition & 0 deletions tests/test_suites/nightly.txt
@@ -108,6 +108,7 @@ tests/test_suites/llm/sft-llama3.1-8b-1n8g-megatron-seqpack.sh
 tests/test_suites/llm/sft-qwen2.5-math7b-2n8g-megatron.sh
 # chunked linear CE loss
 tests/test_suites/llm/sft-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh
+tests/test_suites/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh
 
 # Nemotron Super 49B SFT tests
 # Issue with details: https://github.com/NVIDIA-NeMo/RL/issues/1571
102 changes: 102 additions & 0 deletions tests/unit/models/policy/test_megatron_worker.py
@@ -1988,6 +1988,108 @@ def test_megatron_sft_linear_ce_fusion_agreement(tiny_qwen2_model_path):
    torch.testing.assert_close(loss_std, loss_fuse, rtol=1e-2, atol=1e-2)


@pytest.mark.timeout(600)
def test_megatron_dpo_linear_ce_fusion_agreement(tiny_qwen2_model_path):
    """Test that linear CE fusion loss produces the same results as the standard path for DPO."""
    import time

    num_gpus = 2
    batch_size = 4
    seq_len = 64
    vocab_size = 151936

    torch.manual_seed(42)
    input_ids = torch.randint(0, vocab_size, (batch_size * 2, seq_len))
    attention_mask = torch.ones(batch_size * 2, seq_len)
    input_lengths = attention_mask.sum(dim=1).to(torch.int32)
    token_mask = torch.triu(torch.ones(batch_size * 2, seq_len), diagonal=1)
    sample_mask = torch.ones(batch_size * 2)
    reference_policy_logprobs = torch.randn(batch_size * 2, seq_len)

    data = BatchedDataDict(
        {
            "input_ids": input_ids,
            "input_lengths": input_lengths,
            "attention_mask": attention_mask,
            "token_mask": token_mask,
            "sample_mask": sample_mask,
            "reference_policy_logprobs": reference_policy_logprobs,
        }
    )

    dpo_cfg = {
        "reference_policy_kl_penalty": 0.1,
        "preference_loss_weight": 1.0,
        "sft_loss_weight": 0.5,
        "preference_average_log_probs": False,
        "sft_average_log_probs": False,
    }

    # --- Standard DPO (no linear CE fusion) ---
    cluster_std = RayVirtualCluster(
        name="test-dpo-std",
        bundle_ct_per_node_list=[num_gpus],
        use_gpus=True,
        num_gpus_per_node=num_gpus,
        max_colocated_worker_groups=1,
    )
    config_std = create_megatron_test_config(tiny_qwen2_model_path)
    tokenizer = get_tokenizer(config_std["tokenizer"])
    policy_std = Policy(
        cluster=cluster_std,
        config=config_std,
        tokenizer=tokenizer,
        init_reference_model=False,
    )
    dpo_loss_std = DPOLossFn(dpo_cfg)

    try:
        policy_std.prepare_for_training()
        results_std = policy_std.train(data, dpo_loss_std)
        loss_std = results_std["loss"]
    finally:
        policy_std.shutdown()
        cluster_std.shutdown()

    time.sleep(10)

    # --- DPO with linear CE fusion ---
    cluster_fuse = RayVirtualCluster(
        name="test-dpo-fuse",
        bundle_ct_per_node_list=[num_gpus],
        use_gpus=True,
        num_gpus_per_node=num_gpus,
        max_colocated_worker_groups=1,
    )
    config_fuse = create_megatron_test_config(tiny_qwen2_model_path)
    config_fuse["megatron_cfg"]["use_linear_ce_fusion_loss"] = True
    config_fuse["megatron_cfg"]["linear_ce_fusion_chunk_size"] = 256
    policy_fuse = Policy(
        cluster=cluster_fuse,
        config=config_fuse,
        tokenizer=tokenizer,
        init_reference_model=False,
    )
    dpo_loss_fuse = DPOLossFn(dpo_cfg, use_linear_ce_fusion=True)

    try:
        policy_fuse.prepare_for_training()
        results_fuse = policy_fuse.train(data, dpo_loss_fuse)
        loss_fuse = results_fuse["loss"]
    finally:
        policy_fuse.shutdown()
        cluster_fuse.shutdown()

    # Verify both produce valid losses
    assert not torch.isnan(loss_std).any(), "Standard DPO loss should not be NaN"
    assert not torch.isnan(loss_fuse).any(), "Fusion DPO loss should not be NaN"
    assert not torch.isinf(loss_std).any(), "Standard DPO loss should not be Inf"
    assert not torch.isinf(loss_fuse).any(), "Fusion DPO loss should not be Inf"

    # Verify losses are numerically close
    torch.testing.assert_close(loss_std, loss_fuse, rtol=1e-2, atol=1e-2)


@pytest.mark.skip(
    reason="transformers-v5: Ray ActorAlreadyExistsError (megatron actor cleanup issue)"
)