
Implement Tunix-based DPO/ORPO integration. #3668

Open
igorts-git wants to merge 3 commits into main from igorts/dpo-feature-integration

Conversation


@igorts-git igorts-git commented Apr 15, 2026

Description

Tunix-based DPO/ORPO implementation.

This DPO implementation is based on train_sft.py. The hooks are shared with SFT (see #3862).

Only HF datasets are currently supported; Grain and TFDS are not supported yet.
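For reference, the objective being integrated here is the standard DPO loss from the literature, not this PR's actual code. A minimal scalar sketch (sequence log-probabilities are passed in as plain floats for clarity):

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
  """Per-example DPO loss: -log sigmoid(beta * (policy margin - reference margin)).

  The margin is how much more the policy prefers the chosen response over the
  rejected one, relative to the frozen reference model.
  """
  margin = ((policy_chosen_logp - policy_rejected_logp)
            - (ref_chosen_logp - ref_rejected_logp))
  return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))
```

When the policy and reference agree (zero margin), the loss is log 2; it shrinks as the policy learns to prefer the chosen response more strongly than the reference does.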

After this PR I plan to follow up with:

  • Run on a real dataset with eval to prove that it converges.
  • Additional documentation in the form of a Jupyter notebook and example running scripts.
  • Potentially code refactoring to share more functionality with SFT.
  • End-to-end tests.
  • Delete the legacy DPO implementation.
  • Add support for Grain and TFDS datasets.

FIXES: b/485626968

Tests

Ran DPO training on qwen2.5-1.5b. Performed qualitative validation by inspecting decoded responses for targeted conceptual prompts:

  • Prompt: "What is DPO (Direct Preference Optimization)?"
    • Baseline Response: Outlined broad marketing behaviors, describing DPO as structural strategies centered on audience demographics.
    • DPO Model Response: Correctly aligned the concept with mathematical post-training optimization, defining DPO as a targeted optimization algorithm and showing a clear steer from the parameter updates.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Apr 15, 2026

@igorts-git igorts-git force-pushed the igorts/dpo-feature-integration branch 2 times, most recently from 96406ac to 9e02c1d Compare April 15, 2026 20:29
@igorts-git igorts-git changed the title [WIP. DO NOT REVIEW YET] Implement Tunix-DPO/ORPO integration. Implement Tunix-based DPO/ORPO integration. Apr 15, 2026
@igorts-git igorts-git marked this pull request as ready for review April 15, 2026 22:12
@igorts-git igorts-git force-pushed the igorts/dpo-feature-integration branch from 9e02c1d to 0d115b4 Compare April 15, 2026 22:25
return jnp.sum(batch["targets_segmentation"] != 0)


class DPOTrainingHooks(SFTTrainingHooks):
Collaborator
Please move this class to src/maxtext/trainers/post_train/dpo/hooks.py

Collaborator Author
done



class DPOTrainingHooks(SFTTrainingHooks):
"""Training hooks for DPO.
Collaborator
Also, we can move common functionalities from SFTTrainingHooks and SFTDataHooks to maxtext/trainers/post_train/hooks.py and create child classes:

  1. SFTTrainingHooks and SFTDataHooks in post_train/sft that inherits
  2. DPOTrainingHooks and DPODataHooks in post_train/dpo that inherits
    This way, we could reuse those hooks implementation for RL hooks as well in future.

Collaborator Author
done
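The refactor the reviewer suggests could look like this hypothetical sketch; the module path and method names are illustrative, not the PR's actual code:

```python
# Hypothetical layout for maxtext/trainers/post_train/hooks.py and children.

class PostTrainHooks:
  """Common hook surface shared across post-training trainers."""

  def on_train_start(self, state):
    # Shared setup (e.g. metric writers) would live here.
    pass

  def on_step_end(self, state, metrics):
    # Shared per-step logging would live here.
    pass


class SFTTrainingHooks(PostTrainHooks):
  """SFT-specific hooks (post_train/sft); overrides only what differs."""


class DPOTrainingHooks(PostTrainHooks):
  """DPO-specific hooks (post_train/dpo); the shared base lets future
  RL hooks reuse the same implementation."""
```

The point of the shared base is that a future `RLTrainingHooks` only overrides the hooks that actually differ.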

use_dpo: true
train_data_columns: ['chosen', 'rejected']
eval_data_columns: ['chosen', 'rejected']
use_orpo: false
Collaborator
@SurbhiJainUSC SurbhiJainUSC Apr 16, 2026
Instead of use_orpo, can we have a nested config like

dpo:
   algo: dpo or orpo
   dpo_beta
   orpo_lambda
    ....

Collaborator Author
This is a great idea, but I would rather postpone this until a later PR. First I want to clean up all the legacy DPO code, which would come as a follow up PR.

Collaborator

This file is growing. Let's keep DPO specific changes in a new module input_pipeline/dpo_utils.py

Collaborator Author
done

def map(self, element):
"""Maps original DPO columns to Tunix-compatible pre-tokenized format."""
# Handle the 'input' -> 'prompt_ids' mapping
prompt_ids = self._pad(element.pop("input"), self.max_prompt_length, left=True)
Collaborator
Please raise KeyError exception when the required key is not found in element

Collaborator Author
done

# Handle the 'input' -> 'prompt_ids' mapping
prompt_ids = self._pad(element.pop("input"), self.max_prompt_length, left=True)
chosen_ids = self._pad(element.pop("chosen"), self.max_response_length, left=False)
rejected_ids = self._pad(element.pop("rejected"), self.max_response_length, left=False)
Collaborator
Can we not hardcode the keys like chosen, rejected?

Collaborator Author
Yes, in fact I had a follow up PR for that and also to allow for datasets where there are only 2 columns: "chosen" and "rejected" where the "prompt" is inferred automatically by looking at the common prefix in "chosen" and "rejected". I cherry-picked that PR into this one.
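The prompt-inference idea described above can be sketched as a longest-common-prefix computation over the two tokenized columns; this is a hypothetical illustration of the approach, not the cherry-picked PR's code:

```python
def infer_prompt_length(chosen_ids, rejected_ids):
  """Length of the longest common token prefix of chosen/rejected sequences.

  For two-column preference datasets, the shared prefix is treated as the
  implicit prompt and split off from both responses.
  """
  n = 0
  for a, b in zip(chosen_ids, rejected_ids):
    if a != b:
      break
    n += 1
  return n
```

For example, with `chosen_ids = [5, 7, 9, 2]` and `rejected_ids = [5, 7, 3]` the shared prefix `[5, 7]` would become the prompt and the remainders the two responses.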

return element


class DPOTunixPrep(grain.MapTransform):
Collaborator
Use @dataclasses.dataclass to simplify the initialization

Collaborator Author
done (moved to a new file now)
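A dataclass-based version of the transform could look like this hypothetical sketch (field names and padding logic are illustrative; in the PR the class subclasses `grain.MapTransform`):

```python
import dataclasses

@dataclasses.dataclass
class DPOTunixPrep:  # in the PR this would subclass grain.MapTransform
  """Maps DPO columns to a Tunix-compatible pre-tokenized format."""
  max_prompt_length: int
  max_response_length: int
  pad_id: int = 0

  def _pad(self, ids, length, left):
    # Truncate to length, then pad on the left (prompts) or right (responses).
    ids = list(ids)[:length]
    pad = [self.pad_id] * (length - len(ids))
    return pad + ids if left else ids + pad

  def map(self, element):
    element["prompt_ids"] = self._pad(
        element.pop("input"), self.max_prompt_length, left=True)
    return element
```

The `@dataclasses.dataclass` decorator generates `__init__` from the field declarations, which is what the reviewer's simplification buys.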

Collaborator
Can we also add a jupyter notebook for DPO/ORPO, which can run on Github CI to validate the implementation.

Collaborator Author
I just added the notebook (vibe-coded), but I don't want to add it to Github CI yet. I still have a few follow up PRs to clean things up.

@igorts-git igorts-git force-pushed the igorts/dpo-feature-integration branch 2 times, most recently from aca4891 to 3351239 Compare April 17, 2026 22:43
@shralex shralex mentioned this pull request Apr 18, 2026
@igorts-git igorts-git force-pushed the igorts/dpo-feature-integration branch 2 times, most recently from 67f3373 to 1544c1b Compare April 20, 2026 17:08
@igorts-git igorts-git force-pushed the igorts/dpo-feature-integration branch 6 times, most recently from 0f03a83 to b8b2171 Compare May 10, 2026 18:43
@igorts-git igorts-git force-pushed the igorts/dpo-feature-integration branch 6 times, most recently from 714d6e9 to db68194 Compare May 11, 2026 03:50
@igorts-git igorts-git force-pushed the igorts/dpo-feature-integration branch from db68194 to 46a9ab3 Compare May 11, 2026 16:59
@igorts-git igorts-git requested a review from parambole as a code owner May 11, 2026 23:50