Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/maxtext/configs/models/gemma4-31b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# model config for gemma4-31b Dense

base_num_decoder_layers: 60
base_num_decoder_layers: 1
base_emb_dim: 5376
base_num_query_heads: 32
base_num_kv_heads: 16
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/configs/models/qwen3-omni-30b-a3b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ base_emb_dim: 2048
base_mlp_dim: 768
base_num_query_heads: 32
base_num_kv_heads: 4
base_num_decoder_layers: 48
base_num_decoder_layers: 1
head_dim: 128
mlp_activations: ["silu", "linear"]
vocab_size: 152064
Expand Down
8 changes: 7 additions & 1 deletion src/maxtext/input_pipeline/input_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,9 +714,15 @@ def _pad_image_and_mask(self, preprocessed_image: mm_utils.PreprocessorOutput) -
if preprocessed_image.pixel_values is None:
raise ValueError("Input preprocessed_image must have pixel_values to pad images.")

if self.config.model_name and self.config.model_name.startswith("qwen3-omni"):
return preprocessed_image

# Determine the maximum number of images/masks allowed.
image_offsets = mm_processor.get_image_offsets(self.config, preprocessed_image)
single_image_offset = image_offsets // preprocessed_image.pixel_values.shape[0]
num_images = getattr(preprocessed_image, "num_images", 0)
if num_images <= 0:
num_images = preprocessed_image.pixel_values.shape[0]
single_image_offset = image_offsets // num_images

# Reserve space for at least one text token.
max_num_items = (self.max_length - 1) // single_image_offset
Expand Down
4 changes: 4 additions & 0 deletions src/maxtext/multimodal/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def preprocess_image_for_training(image, model_name):
from maxtext.multimodal.processor_llama4 import preprocess_mm_data_llama4 # pylint: disable=import-outside-toplevel

return preprocess_mm_data_llama4(image)
elif model_name in ["qwen3-omni-30b-a3b"]:
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training # pylint: disable=import-outside-toplevel

return preprocess_mm_data_qwen3_omni_for_training(image)
else:
raise ValueError(f"Model {model_name} not supported for image preprocessing.")

Expand Down
67 changes: 49 additions & 18 deletions src/maxtext/multimodal/processor_qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def smart_resize(
return h_bar, w_bar


def pre_process_qwen3_image(image: np.ndarray | list[np.ndarray], config):
def pre_process_qwen3_image(image: np.ndarray | list[np.ndarray], config, force_resize=None):
"""Performs a bi-linear resize (with anti-aliasing) and normalizes the image."""
patch_size = config.patch_size_for_vit
merge_size = config.spatial_merge_size_for_vit
Expand All @@ -135,23 +135,27 @@ def pre_process_qwen3_image(image: np.ndarray | list[np.ndarray], config):

for img in images_in:
pil_img = Image.fromarray(img)
# Qwen3-Omni performs one resize during fetch_image and another resize before patchify.
resized_height_1, resized_width_1 = smart_resize(
height=img.shape[0],
width=img.shape[1],
factor=IMAGE_FACTOR,
min_pixels=MIN_PIXELS,
max_pixels=MAX_PIXELS,
)
pil_img = pil_img.resize((resized_width_1, resized_height_1))
resized_height_2, resized_width_2 = smart_resize(
height=resized_height_1,
width=resized_width_1,
factor=patch_size * merge_size,
min_pixels=MIN_PIXELS,
max_pixels=MAX_PIXELS,
)
resized_img_pil = pil_img.resize((resized_width_2, resized_height_2), resample=resample_method)
if force_resize is not None:
resized_height_2, resized_width_2 = force_resize
resized_img_pil = pil_img.resize((resized_width_2, resized_height_2), resample=resample_method)
else:
# Qwen3-Omni performs one resize during fetch_image and another resize before patchify.
resized_height_1, resized_width_1 = smart_resize(
height=img.shape[0],
width=img.shape[1],
factor=IMAGE_FACTOR,
min_pixels=MIN_PIXELS,
max_pixels=MAX_PIXELS,
)
pil_img = pil_img.resize((resized_width_1, resized_height_1))
resized_height_2, resized_width_2 = smart_resize(
height=resized_height_1,
width=resized_width_1,
factor=patch_size * merge_size,
min_pixels=MIN_PIXELS,
max_pixels=MAX_PIXELS,
)
resized_img_pil = pil_img.resize((resized_width_2, resized_height_2), resample=resample_method)
resized_img_np = np.array(resized_img_pil).astype(np.float32)

img_np = mm_utils.normalize_images(resized_img_np, mean=IMAGE_MEAN, std=IMAGE_STD)
Expand Down Expand Up @@ -474,6 +478,33 @@ def pre_process_audio_qwen3_omni(audio_array):
return audio_features, audio_features_mask


def preprocess_mm_data_qwen3_omni_for_training(images):
    """Preprocesses image(s) for Qwen3-Omni SFT training using default model constants.

    Every image is forced through the same 768x768 resize, so all images
    produce an identically shaped patch grid.

    Args:
      images: a single image as an ``np.ndarray`` (H, W, C) or a list of such arrays.

    Returns:
      A ``Qwen3OmniPreprocessorOutput`` carrying the per-image pixel values
      reshaped to (num_images, channels, T*grid_t, P*grid_h, P*grid_w) along
      with the patch grid sizes.
    """

    class _TrainVitConfig:
        # Defaults mirroring the qwen3-omni ViT model configuration.
        patch_size_for_vit = 16
        spatial_merge_size_for_vit = 2
        temporal_patch_size_for_vit = QWEN3_TEMPORAL_PATCH_SIZE

    # Normalize the input to a list of images.
    if isinstance(images, np.ndarray):
        image_list = [images]
    else:
        image_list = images

    pixel_values, pixel_grid_thw = pre_process_qwen3_image(image_list, _TrainVitConfig(), force_resize=(768, 768))

    # The forced resize makes every image share the first entry's grid,
    # so a single target shape covers the whole batch.
    target_shape = (
        len(image_list),
        3,  # num_channels_for_vit
        _TrainVitConfig.temporal_patch_size_for_vit * pixel_grid_thw[0, 0],
        _TrainVitConfig.patch_size_for_vit * pixel_grid_thw[0, 1],
        _TrainVitConfig.patch_size_for_vit * pixel_grid_thw[0, 2],
    )
    return Qwen3OmniPreprocessorOutput(
        num_images=len(image_list),
        pixel_values=np.reshape(pixel_values, target_shape),
        pixel_grid_thw=pixel_grid_thw,
    )


def preprocess_mm_data_qwen3_omni(config):
"""Placeholder for multimodal data preprocessing."""
processor_outputs = Qwen3OmniPreprocessorOutput()
Expand Down
Loading