2 changes: 1 addition & 1 deletion .github/workflows/run_jupyter_notebooks.yml
@@ -106,7 +106,7 @@ jobs:
# Run Hugging Face authentication
hf auth login --token "$HF_TOKEN"

for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl}*.ipynb; do
for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl,dpo,orpo}*.ipynb; do
filename=$(basename "$notebook")
# TODO: Update runner to v6e-8 as RL with Llama3.1-8b doesn't fit on v6e-4
if [[ "$filename" == "sft_llama3_demo_gpu.ipynb" || "$filename" == "maxtext_with_gepa.ipynb" ]]; then
5 changes: 5 additions & 0 deletions docs/guides/run_python_notebook.md
@@ -186,6 +186,11 @@ jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root
- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B SFT training and evaluation on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k). This notebook is friendly for beginners and runs successfully on Google Colab's free-tier v5e-1 TPU runtime.
- **`sft_llama3_demo_tpu.ipynb`** → Llama3.1-8B SFT training on [Hugging Face ultrachat_200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k). We recommend running this on a v5p-8 TPU VM using [Method 2](#method-2-visual-studio-code-with-tpu-recommended) or [Method 3](#method-3-local-jupyter-lab-with-tpu-recommended).

### Preference Optimization (DPO & ORPO) Training

- **`dpo_qwen3_demo.ipynb`** → Direct Preference Optimization (DPO) training on [Hugging Face argilla/distilabel-intel-orca-dpo-pairs dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs). Friendly for beginners and runs successfully on single-host TPU environments.
- **`orpo_qwen3_demo.ipynb`** → Odds Ratio Preference Optimization (ORPO) training on [Hugging Face argilla/distilabel-intel-orca-dpo-pairs dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs). Friendly for beginners and runs successfully on single-host TPU environments.

### Reinforcement Learning (GRPO/GSPO) Training

- **`rl_llama3_demo.ipynb`** → GRPO/GSPO training on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k). We recommend running this on a v5p-8 TPU VM using [Method 2](#method-2-visual-studio-code-with-tpu-recommended) or [Method 3](#method-3-local-jupyter-lab-with-tpu-recommended).
3 changes: 3 additions & 0 deletions docs/tutorials/post_training_index.md
@@ -26,6 +26,8 @@ MaxText was co-designed with key Google led innovations to provide a unified pos
- **SFT (Supervised Fine-Tuning)**
- [SFT on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html)
- [SFT on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft_on_multi_host.html)
- **DPO (Direct Preference Optimization) and ORPO (Odds Ratio Preference Optimization)**
- [DPO/ORPO on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/dpo.html)
- **Multimodal SFT**
- [Multimodal Support](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/multimodal.html)
- **Reinforcement Learning (RL)**
@@ -65,6 +67,7 @@ maxdepth: 1
---
posttraining/sft.md
posttraining/sft_on_multi_host.md
posttraining/dpo.md
posttraining/rl.md
posttraining/rl_on_multi_host.md
posttraining/knowledge_distillation.md
110 changes: 110 additions & 0 deletions docs/tutorials/posttraining/dpo.md
Collaborator

Can we also add a Jupyter notebook for DPO/ORPO that can run on GitHub CI to validate the implementation?

Collaborator Author

I just added the notebook (vibe-coded), but I don't want to add it to GitHub CI yet. I still have a few follow-up PRs to clean things up.

@@ -0,0 +1,110 @@
<!--
Copyright 2026 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Preference Optimization (DPO & ORPO) on Single-Host TPUs

MaxText supports two primary methods for aligning models with human preferences: **Direct Preference Optimization (DPO)** and **Odds Ratio Preference Optimization (ORPO)**. Both methods avoid the complexity of traditional Reinforcement Learning from Human Feedback (RLHF) by optimizing directly on preference data.

## DPO vs. ORPO

- **Direct Preference Optimization (DPO):** Optimizes the policy by maximizing the relative log-probability of preferred responses over rejected ones. DPO requires a **reference model** (a frozen copy of the base model) to regularize the training and ensure the policy does not drift too far from the original model's distribution.
- **Odds Ratio Preference Optimization (ORPO):** A newer, reference-free alignment method that integrates the preference loss directly into the supervised fine-tuning objective using an odds ratio. Because it **does not require a reference model**, ORPO is more memory-efficient and faster than DPO.
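
For illustration, here is a minimal sketch of the two objectives. It assumes per-sequence log-probabilities that are length-normalized (averaged per token); the function and argument names are illustrative, not MaxText's actual API:

```python
import jax.numpy as jnp
from jax.nn import log_sigmoid


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
  """DPO: maximize the chosen-vs-rejected margin relative to a frozen reference model."""
  chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
  rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
  return -log_sigmoid(chosen_reward - rejected_reward).mean()


def orpo_loss(chosen_logp, rejected_logp, sft_nll, orpo_lambda=0.1):
  """ORPO: reference-free; adds an odds-ratio penalty to the SFT loss on the chosen response."""
  # log-odds of a response: log p - log(1 - p), from the average per-token log-prob.
  log_odds_chosen = chosen_logp - jnp.log1p(-jnp.exp(chosen_logp))
  log_odds_rejected = rejected_logp - jnp.log1p(-jnp.exp(rejected_logp))
  return (sft_nll - orpo_lambda * log_sigmoid(log_odds_chosen - log_odds_rejected)).mean()
```

Here `beta` and `orpo_lambda` correspond to the `dpo_beta` and `orpo_lambda` config fields; only DPO needs the `ref_*` log-probabilities, which come from the frozen reference model.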

## Data Requirements

Both methods consume preference data in a **triplet format** consisting of a Prompt, a Chosen response, and a Rejected response. MaxText supports two ways to provide this data via the `train_data_columns` configuration:

1. **Explicit Triplets (3 Columns):** The dataset provides three distinct columns for the prompt, chosen response, and rejected response.
2. **Shared Prefix (2 Columns):** For datasets like `Anthropic/hh-rlhf`, where the prompt is embedded at the beginning of the responses, you can provide just two columns (e.g., `chosen` and `rejected`). MaxText will automatically extract the shared common prefix as the **Prompt** and treat the differing suffixes as the responses.

During the input pipeline, prompts are left-padded and responses are right-padded to maintain optimal context for the model.
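
For the two-column case, the sketch below shows roughly what shared-prefix extraction does; `split_shared_prefix` is a hypothetical helper for illustration, not the name used in MaxText's input pipeline:

```python
def split_shared_prefix(chosen: str, rejected: str) -> tuple[str, str, str]:
  """Split two full conversations into (prompt, chosen_response, rejected_response).

  The longest common prefix of the two records is treated as the prompt,
  mirroring the 2-column behavior described above.
  """
  n = 0
  for a, b in zip(chosen, rejected):
    if a != b:
      break
    n += 1
  return chosen[:n], chosen[n:], rejected[n:]


# Example with an hh-rlhf-style record, where the prompt is embedded in both columns:
chosen = "Human: What is 2+2?\n\nAssistant: 2+2 equals 4."
rejected = "Human: What is 2+2?\n\nAssistant: I am not sure."
prompt, chosen_resp, rejected_resp = split_shared_prefix(chosen, rejected)
# prompt        == "Human: What is 2+2?\n\nAssistant: "
# chosen_resp   == "2+2 equals 4."
# rejected_resp == "I am not sure."
```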

## Prerequisites

For instructions on installing MaxText with post-training dependencies on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) and use the `maxtext[tpu-post-train]` installation path.

## Local run on a single-host TPU VM

### Setup environment variables

Login to Hugging Face:

```bash
hf auth login
```

Set up your training environment:

```bash
# -- Model configuration --
# The MaxText model name. See `ModelName` in `src/maxtext/configs/types.py` for a
# full list of supported models.
export MODEL=<MaxText Model> # e.g., "qwen3-0.6b"

# -- MaxText configuration --
# Use a GCS bucket you own to store logs and checkpoints. Ideally in the same
# region as your TPUs to minimize latency and costs.
# You can list your buckets and their locations in the
# [Cloud Console](https://console.cloud.google.com/storage/browser).
export BASE_OUTPUT_DIRECTORY=<gcs bucket path> # e.g., gs://my-bucket/maxtext-runs

# An arbitrary string to identify this specific run.
# We recommend including the model, user, and timestamp.
# Note: Kubernetes requires workload names to be valid DNS labels (lowercase, no underscores or periods).
export RUN_NAME=<Name for this run>

export STEPS=<number of DPO steps to run> # e.g., 1000
export PER_DEVICE_BATCH_SIZE=<batch size per device> # e.g., 1

export ALGORITHM=<"dpo" or "orpo"> # Set to either "orpo" or "dpo"

# -- Dataset configuration --
export DATASET_NAME=<Hugging Face dataset name> # e.g., "argilla/distilabel-intel-orca-dpo-pairs"
export TRAIN_SPLIT=<data split for train> # e.g., train

# Map your dataset columns to [Prompt, Chosen, Rejected]
# For 3-column datasets:
export TRAIN_DATA_COLUMNS="['input', 'chosen', 'rejected']"

# For 2-column datasets (Prefix Extraction):
# export TRAIN_DATA_COLUMNS="['chosen', 'rejected']"
```

## Running DPO/ORPO Training

You can run DPO or ORPO training using the specialized post-training script:

```{note}
The command below sets `eval_interval=0` because the default "argilla/distilabel-intel-orca-dpo-pairs" dataset has only a "train" split.
To evaluate on that same split, set `eval_interval` to a non-zero value and add `hf_eval_split=train`.
```

```bash
python3 -m maxtext.trainers.post_train.dpo.train_dpo \
run_name=${RUN_NAME?} \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
model_name=${MODEL?} \
dataset_type=hf \
hf_path=${DATASET_NAME?} \
train_split=${TRAIN_SPLIT?} \
train_data_columns="${TRAIN_DATA_COLUMNS?}" \
steps=${STEPS?} \
eval_interval=0 \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \
max_target_length=1024 \
use_dpo=$([ "${ALGORITHM?}" = "dpo" ] && echo 1 || echo 0) \
use_orpo=$([ "${ALGORITHM?}" = "orpo" ] && echo 1 || echo 0)
```
19 changes: 9 additions & 10 deletions src/maxtext/common/metric_logger.py
@@ -1,4 +1,4 @@
# Copyright 2023–2025 Google LLC
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -180,9 +180,10 @@ def _log_training_metrics(self, metrics, step):
f"perplexity: {perplexity:.3f}",
]
)
if self.config.use_dpo:
dpo_loss = scalars.get("learning/dpo_loss", 0.0)
log_parts.append(f"dpo_loss: {dpo_loss:.3f}")
if "learning/dpo_loss" in scalars:
log_parts.append(f"dpo_loss: {scalars['learning/dpo_loss']:.3f}")
if "learning/reward_accuracy" in scalars:
log_parts.append(f"reward_accuracy: {scalars['learning/reward_accuracy']:.3f}")

if self.config.num_experts > 1:
moe_lb_loss = scalars.get("learning/moe_lb_loss", 0.0)
@@ -374,9 +375,9 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None):
metrics["scalar"].get("evaluation/mtp_acceptance_rate_percent", 0.0)
)
self.cumulative_eval_metrics["scalar"]["eval/z_loss"] += float(metrics["scalar"].get("evaluation/z_loss", 0.0))
if self.config.use_dpo:
if "evaluation/dpo_reward_accuracy" in metrics["scalar"]:
self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] += float(
metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0)
metrics["scalar"]["evaluation/dpo_reward_accuracy"]
)

if eval_step_count:
@@ -400,10 +401,8 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None):
self.cumulative_eval_metrics["scalar"]["eval/avg_z_loss"] = (
self.cumulative_eval_metrics["scalar"]["eval/z_loss"] / eval_step_count
)
if self.config.use_dpo:
self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = (
self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] / eval_step_count
)
if "eval/dpo_reward_accuracy" in self.cumulative_eval_metrics["scalar"]:
self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] /= eval_step_count

self.write_metrics(self.cumulative_eval_metrics, step, is_training=False)

9 changes: 5 additions & 4 deletions src/maxtext/configs/base.yml
@@ -85,7 +85,7 @@ checkpoint_conversion_fn: none
# optional checkpoint context to use for loading. options: "orbax", "safetensors"
source_checkpoint_layout: "orbax"

# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing
# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing
colocated_python_checkpointing: False

# enables autocheckpoint, which saves a checkpoint at the preemption step.
@@ -449,7 +449,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess'
# internal_compile allows bypassing open-source topology name mappings when using internal topologies directly via get_topology_desc.
internal_compile: False
internal_compile_num_devices: -1 # You must specify the number of devices when using internal_compile.
compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_sparse_cores_for_gather_offloading=1 --xla_tpu_scoped_vmem_limit_kib=65536"
compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_sparse_cores_for_gather_offloading=1 --xla_tpu_scoped_vmem_limit_kib=65536"

# Parallelism
shard_mode: "auto" # can be either auto or explicit
@@ -678,6 +678,7 @@ global_rampup_samples: 500

# direct preference optimization (DPO)
use_dpo: False
use_orpo: False
dpo_label_smoothing: 0.0
dpo_beta: 0.1

@@ -1204,7 +1205,7 @@ use_jax_splash: false
# Path to the HuggingFace-style config directory for the adapter (e.g. src/maxtext/integration/vllm/maxtext_vllm_adapter)
vllm_hf_config_path: ""
# A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter.
# This can be used to override specific settings without modifying the original config file.
# This can be used to override specific settings without modifying the original config file.
vllm_hf_overrides: {}
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
vllm_additional_config: {}
@@ -1219,7 +1220,7 @@ sinkhorn_iterations: 20

################################## DeepSeek Engram ##################################
# Indices of transformer layers where Engram are integrated; leave empty [] to disable.
# Example: [1, 4] attaches to the 2nd and 5th layer.
# Example: [1, 4] attaches to the 2nd and 5th layer.
engram_layers: []
# The max 'n' in N-gram. Example: n=3 means it covers both 2-grams and 3-grams.
engram_max_ngram_size: 3
18 changes: 6 additions & 12 deletions src/maxtext/configs/post_train/dpo.yml
@@ -1,9 +1,13 @@
base_config: "base.yml"

# To use ORPO, flip these flags.
use_dpo: true
use_orpo: false
Collaborator
@SurbhiJainUSC Apr 16, 2026

Instead of use_orpo, can we have a nested config like

dpo:
   algo: dpo or orpo
   dpo_beta
   orpo_lambda
    ....

Collaborator Author

This is a great idea, but I would rather postpone it until a later PR. First I want to clean up all the legacy DPO code, which will come in a follow-up PR.

orpo_lambda: 0.1

packing: false
train_data_columns: ['chosen', 'rejected']
eval_data_columns: ['chosen', 'rejected']
train_data_columns: ['input', 'chosen', 'rejected']
eval_data_columns: ['input', 'chosen', 'rejected']
base_output_directory: 'gs://maxtext-external/logs'

per_device_batch_size: 2.0
@@ -12,16 +16,6 @@ max_target_length: 512
eval_interval: 5 # test eval once, in the middle of 10 training steps
eval_steps: 2

# TFDS Pipeline ----------------------
dataset_type: tfds
dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf'
dataset_name: 'tfds:1.0.0'
eval_dataset_name: 'tfds:1.0.0'
eval_split: 'test'

# HF Pipeline -------------------------
hf_eval_split: 'test'

gradient_clipping_threshold: 10.0
learning_rate: 5.0e-7
dpo_label_smoothing: 0.0
1 change: 1 addition & 0 deletions src/maxtext/configs/pyconfig.py
@@ -52,6 +52,7 @@
"maxtext.trainers.pre_train.train": "base.yml",
"maxtext.trainers.pre_train.train_compile": "base.yml",
"maxtext.trainers.post_train.distillation.train_distill": "post_train/distillation.yml",
"maxtext.trainers.post_train.dpo.train_dpo": "post_train/dpo.yml",
"maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml",
"maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml",
"maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml",
9 changes: 8 additions & 1 deletion src/maxtext/configs/types.py
@@ -1187,11 +1187,13 @@ class OlmoGrainDataset(BaseModel):


class FineTuning(BaseModel):
"""Configuration for fine-tuning methods like DPO, SFT, and GRPO."""
"""Configuration for fine-tuning methods like DPO, ORPO, SFT, and GRPO."""

use_dpo: bool = Field(False, description="If True, enables Direct Preference Optimization training.")
use_orpo: bool = Field(False, description="If True, enables Odds Ratio Preference Optimization training.")
dpo_label_smoothing: float = Field(0.0, ge=0.0, le=1.0, description="Label smoothing for DPO.")
dpo_beta: float = Field(0.1, description="Beta parameter for DPO.")
orpo_lambda: float = Field(0.1, description="Weight for the preference loss (ORPO only).")
use_sft: bool = Field(False, description="If True, enables Supervised Fine-Tuning.")
sft_train_on_completion_only: bool = Field(
False, description="If True, trains only on the completion part of the text."
@@ -2819,6 +2821,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
raise ValueError("For multimodal SFT, `sft_train_on_completion_only` must be True.")
if self.packing:
raise ValueError("For multimodal SFT, `packing` is not yet supported.")
if self.use_dpo or self.use_orpo:
if self.packing:
raise ValueError("For DPO/ORPO, `packing` is not supported.")
if sum([self.use_sft, self.use_dpo, self.use_orpo]) > 1:
raise ValueError("Only one of `use_sft`, `use_dpo`, or `use_orpo` can be True.")
if self.shard_mode == ShardMode.EXPLICIT:
supported_decoders = {"simple", "simple_mlp", "llama2", "deepseek"}
if self.decoder_block.value not in supported_decoders: