
Added SFT Pre-Processing for Grain Input Pipeline #3437

Open
ajkv-google wants to merge 8 commits into main from ajkv/sft-grain-implementation

Conversation

ajkv-google (Collaborator) commented Mar 18, 2026

Description

This PR introduces SFT support to the Grain input pipeline by adding a separate sft_preprocessing_pipeline function. Rather than cluttering the existing pretrain code, it uses simple conditionals inside the train and eval iterators to route to this new SFT logic. I followed the existing Hugging Face SFT implementation and adapted its logic to be compatible with Grain's element-wise datasets.
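As a rough illustration of the routing described above (all names, signatures, and the `Config` class here are placeholders for this sketch, not the actual MaxText API), the conditional dispatch inside the iterators might look like:

```python
from dataclasses import dataclass


@dataclass
class Config:
  # Illustrative config flag; the real MaxText config is richer than this.
  use_sft: bool = False


def sft_preprocessing_pipeline(dataset, config):
  # Stand-in for the real SFT path: chat-template formatting, tokenization,
  # packing/padding, and batching would happen here.
  return [("sft", x) for x in dataset]


def pretrain_preprocessing_pipeline(dataset, config):
  # Stand-in for the existing pretrain path, left untouched by the SFT work.
  return [("pretrain", x) for x in dataset]


def make_train_iterator(dataset, config):
  # A simple conditional routes SFT data to the dedicated pipeline instead of
  # interleaving SFT logic with the pretrain code.
  if config.use_sft:
    return iter(sft_preprocessing_pipeline(dataset, config))
  return iter(pretrain_preprocessing_pipeline(dataset, config))
```

Keeping the two pipelines as separate functions behind one conditional keeps the pretrain path unchanged while letting both share the same iterator entry point.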

Tests

I added a unit test to verify end-to-end that the Grain SFT pipeline formats and outputs the data correctly. Ran this command to execute the unit test:

  • pytest tests/unit/grain_data_processing_test.py::GrainSFTParquetProcessingTest -v

This is the output of the test: Test Passed Output

Also ran the training pipeline in MaxText with SFT enabled using a Grain dataset with this command:

  • python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=test_grain_sft dataset_type=grain grain_file_type=parquet grain_train_files=gs://maxtext-dataset/hf/ultrachat_200k/train_sft-*.parquet steps=10 tokenizer_type=huggingface tokenizer_path=HuggingFaceH4/zephyr-7b-beta

Verified that the SFT processing changes worked and the model trained successfully: Logs

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Mar 18, 2026

Codecov Report

❌ Patch coverage is 40.57971% with 41 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
...rc/maxtext/input_pipeline/grain_data_processing.py | 40.57% | 33 Missing and 8 partials ⚠️


vlad-karp (Collaborator) left a comment:

It would also be great to test not only with the MaxText general SFT pipeline but with the distillation SFT pipeline as well.

messages = [{"role": "user", "content": element["prompt"]}, {"role": "assistant", "content": element["completion"]}]
elif set(data_columns) == {"question", "answer"}:
messages = [{"role": "user", "content": element["question"]}, {"role": "assistant", "content": element["answer"]}]
else:
Collaborator:

The HF pipeline asserts that SFT is running on a conversational format.

dataset = dataset.batch(batch_size, batch_fn=batch_fn)

# Shift inputs for teacher-forced training
dataset = dataset.map(
Collaborator:

Should it always be executed in a generic sft_preprocessing_pipeline()?

Collaborator (Author):

I extracted the shifting logic into a generic shift_dataset() helper in data_processing_utils.py and applied it uniformly across both the SFT and Pretrain pipelines.
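For illustration, a generic shift helper along these lines (a hedged sketch, not the actual shift_dataset() code from the PR; the function name, dict keys, and trimming behavior are assumptions) could look like:

```python
import numpy as np


def shift_example(example):
  """Illustrative per-element shift for teacher-forced training.

  One common formulation: inputs drop the final token and targets drop the
  first, so targets[t] is the token that immediately follows inputs[t].
  The real MaxText helper may instead keep sequence length fixed and pad.
  """
  tokens = np.asarray(example["tokens"])
  return {
      "inputs": tokens[:-1],   # what the model reads
      "targets": tokens[1:],   # what the model is trained to predict
  }
```

Extracting this into a shared utility lets both the SFT and pretrain pipelines apply the identical shift via a single `dataset.map(shift_example)`-style call.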

Collaborator:

My concern was that the comment suggests it is distillation-only logic, but it is applied always.
What is the meaning of that shift operation? Should it only be applied in the distillation pipeline?

Collaborator (Author):

I see, yeah, that comment is a bit misleading. I just meant standard next-token prediction, not distillation. From my understanding, the 1-token shift is required for all autoregressive training, such as pretraining and SFT, to align inputs and targets, and it can be applied in distillation as well.
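To make the input/target alignment concrete, here is a tiny worked example with made-up token ids (not real tokenizer output):

```python
# The 1-token shift for next-token prediction: inputs and targets are the
# same sequence offset by one, so the label at position t is the token that
# follows inputs[t].
tokens = [101, 2009, 4937, 3323]  # illustrative ids for a short sentence
inputs = tokens[:-1]    # model reads:    [101, 2009, 4937]
targets = tokens[1:]    # model predicts: [2009, 4937, 3323]
# This alignment is what any autoregressive objective needs, whether the loss
# signal comes from pretraining text, SFT labels, or a distillation teacher.
```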

), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}"

dataset = dataset.map(
functools.partial(_format_chat_template_grain, data_columns=data_columns, tokenizer_model=tokenizer_model)
Collaborator:

The hf pipeline calls instruction_data_processing.convert_to_conversational_format, do we support the same conversion here? https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/input_pipeline/hf_data_processing.py#L261

Collaborator (Author):

Yes, we do support the conversion. Because Grain's .map() processes row-by-row (unlike HF dataset operations), I implemented the conversion inside _format_chat_template_grain() above this line. It handles both prompt/completion and question/answer pairs.
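For reference, the conversion could be sketched roughly like this (a simplified stand-in, not the actual _format_chat_template_grain(); the function name below is hypothetical and applying the tokenizer's chat template is omitted):

```python
def to_conversational_format(element, data_columns):
  # Map a flat dataset row to the conversational message format, handling
  # both of the column pairs mentioned above. Illustrative sketch only.
  cols = set(data_columns)
  if cols == {"prompt", "completion"}:
    return [
        {"role": "user", "content": element["prompt"]},
        {"role": "assistant", "content": element["completion"]},
    ]
  if cols == {"question", "answer"}:
    return [
        {"role": "user", "content": element["question"]},
        {"role": "assistant", "content": element["answer"]},
    ]
  # Mirrors the assertion on supported column names in the diff.
  raise ValueError(f"Unsupported columns: {sorted(cols)}")
```

Because Grain's `.map()` sees one element at a time, this per-row function is the natural place for the conversion that the HF pipeline does as a whole-dataset operation.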

elif tokenizer_model.unk_id is not None:
pad_id = tokenizer_model.unk_id
else:
pad_id = -1
Collaborator:

I think 0 as the default is better.

return batch_size


def pack_or_pad_and_batch_dataset(dataset, config, batch_size, pad_id, data_columns, tokenizer_model):
Collaborator:

Could this have a simpler name?
