Added SFT Pre-Processing for Grain Input Pipeline#3437
ajkv-google wants to merge 8 commits into main
Conversation
vlad-karp left a comment
It would also be great to test not only with the MaxText general SFT pipeline but with the distillation SFT pipeline as well.
```python
  messages = [{"role": "user", "content": element["prompt"]}, {"role": "assistant", "content": element["completion"]}]
elif set(data_columns) == {"question", "answer"}:
  messages = [{"role": "user", "content": element["question"]}, {"role": "assistant", "content": element["answer"]}]
else:
```
The HF pipeline asserts that SFT is running on a conversational format.
```python
dataset = dataset.batch(batch_size, batch_fn=batch_fn)

# Shift inputs for teacher-forced training
dataset = dataset.map(
```
Should it always be executed in a generic sft_preprocessing_pipeline()?
I extracted the shifting logic into a generic shift_dataset() helper in data_processing_utils.py and applied it uniformly across both the SFT and Pretrain pipelines.
My concern was that the comment suggests this is distillation-only logic, but it is applied unconditionally. What is the meaning of that shift operation? Should it only be applied in the distillation pipeline?
I see, yeah, that comment is a bit misleading. I meant standard next-token prediction, not distillation. From my understanding, the 1-token shift is required for all autoregressive training (pretraining and SFT alike) to align inputs and targets, and it can be applied in distillation as well.
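A minimal sketch of the 1-token shift being discussed (the helper name `shift_example`, the `bos_id` parameter, and the column names are illustrative assumptions, not the actual `shift_dataset()` signature):

```python
import numpy as np

def shift_example(example, bos_id=1):
  """Illustrative next-token-prediction shift: inputs are the targets
  shifted right by one position, so at step t the model sees tokens
  [0..t-1] and is trained to predict token t."""
  targets = np.asarray(example["targets"])
  inputs = np.concatenate([[bos_id], targets[:-1]])
  return {"inputs": inputs, "targets": targets}

shifted = shift_example({"targets": [5, 6, 7, 8]})
# inputs: [1, 5, 6, 7], targets: [5, 6, 7, 8]
```

This is why the shift applies to pretraining, SFT, and distillation alike: all three train autoregressively on aligned input/target pairs.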
```python
), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}"

dataset = dataset.map(
    functools.partial(_format_chat_template_grain, data_columns=data_columns, tokenizer_model=tokenizer_model)
```
The HF pipeline calls instruction_data_processing.convert_to_conversational_format; do we support the same conversion here? https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/input_pipeline/hf_data_processing.py#L261
Yes, we do support the conversion. Because Grain's .map() processes row-by-row (unlike HF dataset operations), I implemented the conversion inside _format_chat_template_grain() above this line. It handles both prompt/completion and question/answer pairs.
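Row by row, that conversion can be sketched like this (a simplified stand-in for `_format_chat_template_grain()`; the function name `to_conversational` and its signature are illustrative assumptions):

```python
def to_conversational(element, data_columns):
  """Sketch of per-row conversion to the conversational "messages"
  structure used by chat templates. Handles both prompt/completion
  and question/answer column pairs."""
  if set(data_columns) == {"prompt", "completion"}:
    user, assistant = element["prompt"], element["completion"]
  elif set(data_columns) == {"question", "answer"}:
    user, assistant = element["question"], element["answer"]
  else:
    raise ValueError(f"Unsupported columns: {data_columns}")
  return [
      {"role": "user", "content": user},
      {"role": "assistant", "content": assistant},
  ]

msgs = to_conversational({"question": "2+2?", "answer": "4"}, ["question", "answer"])
# [{'role': 'user', 'content': '2+2?'}, {'role': 'assistant', 'content': '4'}]
```

Because Grain's `.map()` sees one element at a time, the whole branch fits naturally inside a single per-row function, unlike the HF pipeline's dataset-level conversion.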
```python
elif tokenizer_model.unk_id is not None:
  pad_id = tokenizer_model.unk_id
else:
  pad_id = -1
```
I think 0 as the default is better.
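The fallback chain under discussion, sketched (`resolve_pad_id` and `FakeTokenizer` are illustrative names; the final default is the open question, since the PR currently uses -1 and the review suggests 0):

```python
def resolve_pad_id(tokenizer, default=0):
  """Prefer the tokenizer's pad id, fall back to its unk id, then to a
  fixed default (0 per the review suggestion; the PR uses -1)."""
  if tokenizer.pad_id is not None:
    return tokenizer.pad_id
  if tokenizer.unk_id is not None:
    return tokenizer.unk_id
  return default

class FakeTokenizer:
  pad_id = None
  unk_id = None

fallback = resolve_pad_id(FakeTokenizer())
# fallback == 0
```

A non-negative default matters in practice because -1 is not a valid vocabulary index, so any op that embeds or one-hots the padded ids would need special casing.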
```python
  return batch_size


def pack_or_pad_and_batch_dataset(dataset, config, batch_size, pad_id, data_columns, tokenizer_model):
```
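For context, the padding path of such a step can be sketched as follows (the real `pack_or_pad_and_batch_dataset` also supports packing and reads its settings from `config`; this illustration covers only padding and batching, and the helper name `pad_and_batch` is made up):

```python
import numpy as np

def pad_and_batch(rows, batch_size, max_len, pad_id):
  """Pad each variable-length row to max_len with pad_id, then group
  the padded rows into fixed-size batches."""
  padded = [
      np.pad(np.asarray(r), (0, max_len - len(r)), constant_values=pad_id)
      for r in rows
  ]
  return [
      np.stack(padded[i:i + batch_size])
      for i in range(0, len(padded), batch_size)
  ]

batches = pad_and_batch([[1, 2], [3, 4, 5]], batch_size=2, max_len=4, pad_id=0)
# one batch of shape (2, 4)
```

Packing, by contrast, concatenates multiple short rows into one sequence to reduce wasted pad tokens; which path runs is a config choice.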
Description
This PR introduces SFT support to the Grain input pipeline by adding a separate
sft_preprocessing_pipelinefunction. Rather than cluttering the existing pretrain code, it uses simple conditionals inside the train and eval iterators to route to this new SFT logic. I followed the existing Hugging Face SFT implementation and adapted its logic to be compatible with Grain's element-wise datasets.Tests
I added a unit test verifying end-to-end functionality, making sure the Grain SFT pipeline formats the data and produces outputs correctly. Ran this command to execute the unit test:
```shell
pytest tests/unit/grain_data_processing_test.py::GrainSFTParquetProcessingTest -v
```

This is the output of the test: Test Passed Output
Also, ran the training pipeline in Maxtext with sft enabled using a grain dataset with this command:
```shell
python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=test_grain_sft dataset_type=grain grain_file_type=parquet grain_train_files=gs://maxtext-dataset/hf/ultrachat_200k/train_sft-*.parquet steps=10 tokenizer_type=huggingface tokenizer_path=HuggingFaceH4/zephyr-7b-beta
```

Verified that the SFT processing changes worked and trained successfully: Logs
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.