
Enable Qwen3-Omni SFT on ChartQA #3863

Draft

hengtaoguo wants to merge 2 commits into main from hengtaoguo-omni-sft

Conversation

hengtaoguo (Collaborator) commented May 10, 2026

Description

  • Uniform Grid Resizing: Enforces a fixed 768 x 768 resize resolution in preprocess_mm_data_qwen3_omni_for_training to guarantee identical patch counts and uniform tensor shapes across diverse batch examples, resolving the grain.Batch stacking mismatches (first sketch below).
  • Architecture Alignment: Corrects the hardcoded preprocessing fallback to the true model default patch_size_for_vit = 16 and reshapes flattened patch arrays back into 5D tensors (batch, channels, temporal, height, width), eliminating the ValueError: not enough values to unpack at the vision encoder input (second sketch below).
  • Robust Padding & Offsets: Updates the single_image_offset calculation in input_pipeline_utils.py to safely use the num_images attribute, preventing ZeroDivisionError crashes on flattened patch counts, and adds an explicit padding bypass for Qwen3-Omni patch tensors to avoid spurious mask-limit exceptions (third sketch below).
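
A minimal sketch of the uniform-resize idea, assuming PIL-style image inputs; the helper name and resampling filter are illustrative, while the 768 x 768 resolution, the 16-pixel patch size, and preprocess_mm_data_qwen3_omni_for_training are from this PR:

```python
# Sketch 1: fixed-resolution resize so every image yields the same patch count.
# Hypothetical helper; only the 768 x 768 / patch-16 values come from this PR.
import numpy as np
from PIL import Image

QWEN3_OMNI_RESIZE = (768, 768)  # fixed (width, height) enforced before patching
PATCH_SIZE_FOR_VIT = 16         # true model-default patch size per this PR

def resize_for_uniform_patches(image: Image.Image) -> np.ndarray:
  """Resize to 768 x 768 so each image yields (768 // 16) ** 2 = 2304 patches,
  letting grain.Batch stack examples without shape mismatches."""
  resized = image.convert("RGB").resize(QWEN3_OMNI_RESIZE, Image.Resampling.BICUBIC)
  return np.asarray(resized, dtype=np.float32) / 255.0  # shape (768, 768, 3)
```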
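A sketch of the 5D reshape at the vision-encoder boundary. It assumes the flattened array is row-major over the same dimensions; the real inverse of the processor's patch ordering may differ, so treat this as illustrative:

```python
# Sketch 2: restore the 5D layout the vision encoder unpacks as
# `b, c, t, h, w = pixels.shape`; feeding it the flat 2-D array raised
# `ValueError: not enough values to unpack`. Hypothetical helper.
import numpy as np

def unflatten_patches(flat: np.ndarray, *, channels: int = 3, temporal: int = 1,
                      height: int = 768, width: int = 768) -> np.ndarray:
  """(num_patches, patch_dim) -> (batch, channels, temporal, height, width).

  Assumes row-major element order; if the processor interleaves patches
  differently, the actual fix must invert that ordering first.
  """
  batch = flat.size // (channels * temporal * height * width)
  return flat.reshape(batch, channels, temporal, height, width)
```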
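Finally, a sketch of the guarded offset computation; the signature is illustrative, while single_image_offset, the num_images attribute, and the ZeroDivisionError symptom are from this PR:

```python
# Sketch 3: derive the per-image offset from an explicit num_images attribute
# instead of dividing by a count inferred from flattened patches, which can
# be zero for text-only rows.
def single_image_offset(flattened_patch_count: int, num_images: int) -> int:
  """Patches contributed per image; text-only rows (num_images == 0)
  return 0 rather than raising ZeroDivisionError."""
  if num_images <= 0:
    return 0
  return flattened_patch_count // num_images
```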

Tests

10-step E2E test run:

python -m maxtext.trainers.post_train.sft.train_sft_deprecated src/maxtext/configs/post_train/sft-vision-chartqa.yml model_name=qwen3-omni-30b-a3b tokenizer_path=Qwen/Qwen3-Omni-30B-A3B-Instruct base_output_directory=gs://hengtaoguo-maxtext-logs/sft/qwen3-omni-30b-a3b per_device_batch_size=2 dataset_type=hf steps=10 max_target_length=1024 checkpoint_period=100 attention=dot_product enable_checkpointing=false hf_access_token=<your_token>
I0510 21:11:20.578492 140261636574336 metric_logger.py:196] completed step: 0, seconds: 35.570, TFLOP/s/device: 0.129, Tokens/s/device: 57.577, total_weights: 36, loss: 12.508, lm_loss: 12.508, perplexity: 270415.000, moe_lb_loss: 0.000
I0510 21:11:20.581142 140261636574336 metric_logger.py:281] To see full metrics 'tensorboard --logdir=gs://hengtaoguo-maxtext-logs/sft/qwen3-omni-30b-a3b/qwen3-omni-30b-a3b_2026-05-10-21-09/tensorboard/'
I0510 21:11:21.169645 140261636574336 metric_logger.py:196] completed step: 1, seconds: 0.373, TFLOP/s/device: 12.277, Tokens/s/device: 5492.840, total_weights: 43, loss: 12.423, lm_loss: 12.423, perplexity: 248505.000, moe_lb_loss: 0.000
I0510 21:11:21.254970 140261636574336 metric_logger.py:196] completed step: 2, seconds: 0.591, TFLOP/s/device: 7.741, Tokens/s/device: 3463.584, total_weights: 39, loss: 10.753, lm_loss: 10.753, perplexity: 46787.500, moe_lb_loss: 0.000
I0510 21:11:21.364491 140261636574336 metric_logger.py:196] completed step: 3, seconds: 0.042, TFLOP/s/device: 108.389, Tokens/s/device: 48495.181, total_weights: 37, loss: 9.743, lm_loss: 9.743, perplexity: 17027.156, moe_lb_loss: 0.000
I0510 21:11:21.473644 140261636574336 metric_logger.py:196] completed step: 4, seconds: 0.082, TFLOP/s/device: 55.869, Tokens/s/device: 24996.949, total_weights: 40, loss: 8.718, lm_loss: 8.718, perplexity: 6109.984, moe_lb_loss: 0.000
I0510 21:11:21.582713 140261636574336 metric_logger.py:196] completed step: 5, seconds: 0.110, TFLOP/s/device: 41.799, Tokens/s/device: 18701.659, total_weights: 40, loss: 8.500, lm_loss: 8.500, perplexity: 4913.453, moe_lb_loss: 0.000
I0510 21:11:21.691225 140261636574336 metric_logger.py:196] completed step: 6, seconds: 0.108, TFLOP/s/device: 42.365, Tokens/s/device: 18955.065, total_weights: 44, loss: 7.646, lm_loss: 7.646, perplexity: 2091.277, moe_lb_loss: 0.000
I0510 21:11:21.799779 140261636574336 metric_logger.py:196] completed step: 7, seconds: 0.109, TFLOP/s/device: 41.890, Tokens/s/device: 18742.221, total_weights: 41, loss: 7.797, lm_loss: 7.797, perplexity: 2434.350, moe_lb_loss: 0.000
I0510 21:11:21.908020 140261636574336 metric_logger.py:196] completed step: 8, seconds: 0.102, TFLOP/s/device: 45.077, Tokens/s/device: 20168.200, total_weights: 42, loss: 7.082, lm_loss: 7.082, perplexity: 1190.472, moe_lb_loss: 0.000
I0510 21:11:22.016094 140261636574336 metric_logger.py:196] completed step: 9, seconds: 0.109, TFLOP/s/device: 42.172, Tokens/s/device: 18868.620, total_weights: 40, loss: 7.796, lm_loss: 7.796, perplexity: 2430.939, moe_lb_loss: 0.000

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented May 10, 2026

Codecov Report

❌ Patch coverage is 0% with 26 lines in your changes missing coverage. Please review.

Files with missing lines                            Patch %   Lines
src/maxtext/multimodal/processor_qwen3_omni.py      0.00%     17 Missing ⚠️
src/maxtext/input_pipeline/input_pipeline_utils.py  0.00%     6 Missing ⚠️
src/maxtext/multimodal/processor.py                 0.00%     3 Missing ⚠️


