Implement Tunix-based DPO/ORPO integration #3668
Conversation
Force-pushed from 96406ac to 9e02c1d
Force-pushed from 9e02c1d to 0d115b4
```python
return jnp.sum(batch["targets_segmentation"] != 0)


class DPOTrainingHooks(SFTTrainingHooks):
```
Reviewer: Please move this class to `src/maxtext/trainers/post_train/dpo/hooks.py`.
```python
class DPOTrainingHooks(SFTTrainingHooks):
  """Training hooks for DPO."""
```
Reviewer: Also, we could move the common functionality from SFTTrainingHooks and SFTDataHooks to maxtext/trainers/post_train/hooks.py and create child classes:
- SFTTrainingHooks and SFTDataHooks in post_train/sft that inherit from them
- DPOTrainingHooks and DPODataHooks in post_train/dpo that inherit from them

This way, we could reuse the hook implementations for RL hooks as well in the future.
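A minimal sketch of the suggested hierarchy (only the module paths and SFT/DPO class names come from the comment; the base class and its method names are hypothetical):

```python
# maxtext/trainers/post_train/hooks.py (proposed shared base module)
import abc


class PostTrainingHooks(abc.ABC):
  """Training-lifecycle callbacks shared by all post-training recipes."""

  def on_train_start(self, state):  # hypothetical hook point
    pass

  def on_train_step_end(self, state, metrics):  # hypothetical hook point
    pass


# maxtext/trainers/post_train/sft/hooks.py
class SFTTrainingHooks(PostTrainingHooks):
  """SFT-specific overrides of the shared hooks."""


# maxtext/trainers/post_train/dpo/hooks.py
class DPOTrainingHooks(PostTrainingHooks):
  """DPO-specific overrides, e.g. logging chosen/rejected metrics."""
```

Under this layout, a future RL recipe would add its own child classes in post_train/rl the same way.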
```yaml
use_dpo: true
train_data_columns: ['chosen', 'rejected']
eval_data_columns: ['chosen', 'rejected']
use_orpo: false
```
Reviewer: Instead of use_orpo, can we have a nested config like:

```yaml
dpo:
  algo: dpo   # or orpo
  dpo_beta: ...
  orpo_lambda: ...
  ...
```
Author: This is a great idea, but I would rather postpone it until a later PR. First I want to clean up all the legacy DPO code, which will come as a follow-up PR.
Reviewer: This file is growing. Let's keep the DPO-specific changes in a new module, input_pipeline/dpo_utils.py.
```python
def map(self, element):
  """Maps original DPO columns to Tunix-compatible pre-tokenized format."""
  # Handle the 'input' -> 'prompt_ids' mapping
  prompt_ids = self._pad(element.pop("input"), self.max_prompt_length, left=True)
```
Reviewer: Please raise a KeyError when a required key is not found in element.
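A sketch of that check, assuming the column names from the quoted diff (the helper name is hypothetical):

```python
REQUIRED_COLUMNS = ("input", "chosen", "rejected")


def check_required_columns(element: dict) -> None:
  """Raises a KeyError that names the missing DPO column and lists what was
  found, instead of relying on dict.pop's terse default message."""
  for key in REQUIRED_COLUMNS:
    if key not in element:
      raise KeyError(
          f"DPO example is missing required column {key!r}; "
          f"available columns: {sorted(element)}")
```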
```python
# Handle the 'input' -> 'prompt_ids' mapping
prompt_ids = self._pad(element.pop("input"), self.max_prompt_length, left=True)
chosen_ids = self._pad(element.pop("chosen"), self.max_response_length, left=False)
rejected_ids = self._pad(element.pop("rejected"), self.max_response_length, left=False)
```
Reviewer: Can we avoid hardcoding keys like chosen and rejected?
Author: Yes, in fact I had a follow-up PR for that, which also allows datasets with only two columns, "chosen" and "rejected", where the "prompt" is inferred automatically from the common prefix of "chosen" and "rejected". I cherry-picked that PR into this one.
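A minimal sketch of that inference on token-id sequences (hypothetical helper; the cherry-picked implementation may differ):

```python
def split_common_prefix(chosen: list[int], rejected: list[int]):
  """Treats the longest common token prefix of the chosen and rejected
  sequences as the prompt; returns (prompt, chosen_resp, rejected_resp)."""
  n = 0
  limit = min(len(chosen), len(rejected))
  while n < limit and chosen[n] == rejected[n]:
    n += 1
  return chosen[:n], chosen[n:], rejected[n:]


# e.g. split_common_prefix([1, 2, 3, 4], [1, 2, 9]) -> ([1, 2], [3, 4], [9])
```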
```python
  return element


class DPOTunixPrep(grain.MapTransform):
```
Reviewer: Use @dataclasses.dataclass to simplify the initialization.
Author: Done (moved to a new file now).
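For reference, the dataclass form might look roughly like this (field names are guesses from the quoted diff):

```python
import dataclasses

import grain.python as grain


@dataclasses.dataclass
class DPOTunixPrep(grain.MapTransform):
  """Maps DPO columns to the Tunix pre-tokenized format.

  The @dataclasses.dataclass decorator generates __init__ from the fields,
  removing hand-written constructor boilerplate.
  """
  max_prompt_length: int
  max_response_length: int
  pad_id: int = 0

  def map(self, element):
    ...  # padding / renaming as in the quoted diff
```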
Reviewer: Can we also add a Jupyter notebook for DPO/ORPO that runs on GitHub CI to validate the implementation?
Author: I just added the notebook (vibe-coded), but I don't want to add it to GitHub CI yet. I still have a few follow-up PRs to clean things up.
Force-pushed from aca4891 to 3351239
Force-pushed from 67f3373 to 1544c1b
Force-pushed from 0f03a83 to b8b2171
Force-pushed from 714d6e9 to db68194
Force-pushed from db68194 to 46a9ab3
…I validation loop
Description
Tunix-based DPO/ORPO implementation.
This DPO implementation is based on train_sft.py. The hooks are shared with SFT (see #3862). Only HF datasets are currently supported; there is no Grain or TFDS support yet.
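For context, here is a schematic of the two objectives in JAX (standard formulations from the DPO and ORPO papers, not the code added in this PR; per-sequence log-probabilities are assumed as inputs):

```python
import jax.numpy as jnp
from jax.nn import log_sigmoid


def dpo_loss(logp_chosen, logp_rejected, ref_chosen, ref_rejected, beta=0.1):
  """DPO: push the policy's chosen-vs-rejected log-ratio above the
  frozen reference model's."""
  margin = (logp_chosen - ref_chosen) - (logp_rejected - ref_rejected)
  return -jnp.mean(log_sigmoid(beta * margin))


def orpo_loss(nll_chosen, logp_chosen, logp_rejected, lam=0.5):
  """ORPO: SFT loss on the chosen response plus an odds-ratio penalty;
  no reference model is required."""
  def log_odds(logp):
    return logp - jnp.log1p(-jnp.exp(logp))  # log(p / (1 - p)) from log p
  ratio = log_odds(logp_chosen) - log_odds(logp_rejected)
  return jnp.mean(nll_chosen) - lam * jnp.mean(log_sigmoid(ratio))
```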
After this PR I plan to follow up with:
- cleaning up the legacy DPO code
- restructuring the DPO/ORPO options into a nested config block
FIXES: b/485626968
Tests
Ran DPO training on qwen2.5-1.5b. Performed qualitative validation by decoding responses for targeted conceptual prompts, e.g. "What is DPO (Direct Preference Optimization)?".

Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [ ] Add the gemini-review label.