Implement Tunix-based DPO/ORPO integration. #3668
base: main
Changes from all commits: 99e947e, 46a9ab3, ab0e579
New file (@@ -0,0 +1,110 @@):
<!--
Copyright 2026 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Preference Optimization (DPO & ORPO) on Single-Host TPUs

MaxText supports two primary methods for aligning models with human preferences: **Direct Preference Optimization (DPO)** and **Odds Ratio Preference Optimization (ORPO)**. Both methods avoid the complexity of traditional Reinforcement Learning from Human Feedback (RLHF) by optimizing directly on preference data.
## DPO vs. ORPO

- **Direct Preference Optimization (DPO):** Optimizes the policy by maximizing the relative log-probability of preferred responses over rejected ones. DPO requires a **reference model** (a frozen copy of the base model) to regularize training and keep the policy from drifting too far from the original model's distribution.
- **Odds Ratio Preference Optimization (ORPO):** A newer, reference-free alignment method that integrates the preference loss directly into the supervised fine-tuning objective using an odds ratio. Because it **does not require a reference model**, ORPO is more memory-efficient and faster than DPO.
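For reference, the two objectives can be written as follows (these are the standard formulations from the DPO and ORPO papers, not MaxText-specific notation; $y_w$/$y_l$ are the chosen/rejected responses, $\beta$ is the DPO temperature, and $\lambda$ weights the odds-ratio term, which the config below exposes as `orpo_lambda`):

```latex
% DPO: contrast policy vs. frozen-reference log-ratios for chosen (y_w) and rejected (y_l)
\mathcal{L}_{\mathrm{DPO}} = -\log \sigma\!\left(
    \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
  - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \right)

% ORPO: SFT loss on the chosen response plus a reference-free odds-ratio penalty
\mathcal{L}_{\mathrm{ORPO}} = \mathcal{L}_{\mathrm{SFT}}
  - \lambda \, \log \sigma\!\left( \log \frac{\mathrm{odds}_\theta(y_w \mid x)}{\mathrm{odds}_\theta(y_l \mid x)} \right),
\qquad \mathrm{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)}
```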
## Data Requirements

Both methods consume preference data in a **triplet format** consisting of a Prompt, a Chosen response, and a Rejected response. MaxText supports two ways to provide this data via the `train_data_columns` configuration:

1. **Explicit Triplets (3 Columns):** The dataset provides three distinct columns for the prompt, chosen response, and rejected response.
2. **Shared Prefix (2 Columns):** For datasets like `Anthropic/hh-rlhf`, where the prompt is embedded at the beginning of the responses, you can provide just two columns (e.g., `chosen` and `rejected`). MaxText will automatically extract the shared common prefix as the **Prompt** and treat the differing suffixes as the responses.

During the input pipeline, prompts are left-padded and responses are right-padded to maintain optimal context for the model.
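Before setting `train_data_columns`, it can help to confirm which columns your dataset actually exposes. A minimal sketch, assuming the Hugging Face `datasets` package is installed:

```bash
# Print the column names of the first training example so you can map them
# onto [Prompt, Chosen, Rejected].
python3 - <<'EOF'
from datasets import load_dataset

# Streaming avoids downloading the whole dataset just to inspect its schema.
ds = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train", streaming=True)
print(sorted(next(iter(ds)).keys()))
EOF
```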
## Prerequisites

For instructions on installing MaxText with post-training dependencies on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) and use the `maxtext[tpu-post-train]` installation path.
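As a rough sketch of what that step can look like (the linked guide is authoritative; this assumes you are installing from inside a MaxText checkout):

```bash
# Install MaxText together with the post-training extras (TPU variant).
pip install '.[tpu-post-train]'
```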
## Local run on a single-host TPU VM

### Set up environment variables

Log in to Hugging Face:

```bash
hf auth login
```
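If you need a non-interactive alternative (e.g., in automation), `huggingface_hub` also reads a token directly from the environment; the value below is a placeholder for your own access token:

```bash
# Alternative to the interactive login above.
export HF_TOKEN=<your Hugging Face access token>
```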
Set up your training environment:

```bash
# -- Model configuration --
# The MaxText model name. See `ModelName` in `src/maxtext/configs/types.py`
# for the full list of supported models.
export MODEL=<MaxText Model> # e.g., "qwen3-0.6b"

# -- MaxText configuration --
# Use a GCS bucket you own to store logs and checkpoints, ideally in the same
# region as your TPUs to minimize latency and costs.
# You can list your buckets and their locations in the
# [Cloud Console](https://console.cloud.google.com/storage/browser).
export BASE_OUTPUT_DIRECTORY=<gcs bucket path> # e.g., gs://my-bucket/maxtext-runs

# An arbitrary string to identify this specific run.
# We recommend including the model, user, and timestamp.
# Note: Kubernetes requires workload names to be valid DNS labels (lowercase, no underscores or periods).
export RUN_NAME=<Name for this run>

export STEPS=<number of training steps to run> # e.g., 1000
export PER_DEVICE_BATCH_SIZE=<batch size per device> # e.g., 1

export ALGORITHM=<algorithm> # Set to either "dpo" or "orpo"

# -- Dataset configuration --
export DATASET_NAME=<Hugging Face dataset name> # e.g., "argilla/distilabel-intel-orca-dpo-pairs"
export TRAIN_SPLIT=<data split for train> # e.g., train

# Map your dataset columns to [Prompt, Chosen, Rejected].
# For 3-column datasets:
export TRAIN_DATA_COLUMNS="['input', 'chosen', 'rejected']"

# For 2-column datasets (prefix extraction):
# export TRAIN_DATA_COLUMNS="['chosen', 'rejected']"
```
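As a concrete illustration, the variables might be filled in like this, reusing the example values from the comments above (the bucket is a placeholder you must replace, and the timestamped run-name convention is just one option):

```bash
export MODEL="qwen3-0.6b"
export BASE_OUTPUT_DIRECTORY="gs://my-bucket/maxtext-runs"  # a bucket you own
export RUN_NAME="qwen3-0p6b-dpo-$(date +%Y%m%d-%H%M%S)"     # DNS-label safe: lowercase, no underscores/periods
export STEPS=1000
export PER_DEVICE_BATCH_SIZE=1
export ALGORITHM="dpo"
export DATASET_NAME="argilla/distilabel-intel-orca-dpo-pairs"
export TRAIN_SPLIT="train"
export TRAIN_DATA_COLUMNS="['input', 'chosen', 'rejected']"
```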
## Running DPO/ORPO Training

You can run DPO or ORPO training using the specialized post-training script:

```{note}
The script below uses `eval_interval=0` because the default "argilla/distilabel-intel-orca-dpo-pairs" dataset only has a "train" split.
To evaluate on that same split, set a non-zero `eval_interval` and add `hf_eval_split=train`.
```
```bash
python3 -m maxtext.trainers.post_train.dpo.train_dpo \
  run_name=${RUN_NAME?} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
  model_name=${MODEL?} \
  dataset_type=hf \
  hf_path=${DATASET_NAME?} \
  train_split=${TRAIN_SPLIT?} \
  train_data_columns="${TRAIN_DATA_COLUMNS?}" \
  steps=${STEPS?} \
  eval_interval=0 \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \
  max_target_length=1024 \
  use_dpo=$([ "${ALGORITHM?}" = "dpo" ] && echo 1 || echo 0) \
  use_orpo=$([ "${ALGORITHM?}" = "orpo" ] && echo 1 || echo 0)
```
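The two command substitutions at the end select the objective: `use_dpo` expands to `1` only when `ALGORITHM` is `dpo`, and `use_orpo` expands to `1` only when it is `orpo`, so exactly one of the two flags is enabled for any given run.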
Changes to the existing DPO config file:

@@ -1,9 +1,13 @@
 base_config: "base.yml"

+# To use ORPO, flip these flags.
+use_dpo: true
+use_orpo: false
**Collaborator:** Instead of

**Author:** This is a great idea, but I would rather postpone this until a later PR. First I want to clean up all the legacy DPO code, which would come as a follow-up PR.
+orpo_lambda: 0.1

 packing: false
-train_data_columns: ['chosen', 'rejected']
-eval_data_columns: ['chosen', 'rejected']
+train_data_columns: ['input', 'chosen', 'rejected']
+eval_data_columns: ['input', 'chosen', 'rejected']
 base_output_directory: 'gs://maxtext-external/logs'

 per_device_batch_size: 2.0
@@ -12,16 +16,6 @@ max_target_length: 512
 eval_interval: 5 # test eval once, in the middle of 10 training steps
 eval_steps: 2

-# TFDS Pipeline ----------------------
-dataset_type: tfds
-dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf'
-dataset_name: 'tfds:1.0.0'
-eval_dataset_name: 'tfds:1.0.0'
-eval_split: 'test'

 # HF Pipeline -------------------------
 hf_eval_split: 'test'

 gradient_clipping_threshold: 10.0
 learning_rate: 5.0e-7
 dpo_label_smoothing: 0.0
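Here `orpo_lambda` corresponds to the λ weighting in the ORPO objective sketched earlier: it scales the odds-ratio penalty relative to the SFT loss, so larger values push harder on preference separation. The diff also removes the TFDS pipeline settings in favor of the Hugging Face pipeline used throughout this PR, in line with the author's planned cleanup of legacy DPO code.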
**Collaborator:** Can we also add a Jupyter notebook for DPO/ORPO, which can run on GitHub CI to validate the implementation?

**Author:** I just added the notebook (vibe-coded), but I don't want to add it to GitHub CI yet. I still have a few follow-up PRs to clean things up.