Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/maxtext/checkpoint_conversion/to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ def _validate_or_update_architecture(hf_config, max_config, override: bool):
if isinstance(mt_value, int):
mt_value = mt_value * 2

# Special handling for Qwen3-MoE: hf.intermediate_size is the aggregated dense MLP dimension, while mt.mlp_dim is the per-expert dimension
if "qwen3" in max_config.model_name and getattr(max_config, "num_experts", 0) > 1 and hf_attr == "intermediate_size":
mt_value = mt_value * getattr(max_config, "num_experts_per_tok", 1)

# Handle vocab size padding
if hf_attr == "vocab_size" and isinstance(mt_value, int) and isinstance(hf_value, int):
# MaxText often pads vocab size to a multiple of 128 or 256 for TPU efficiency
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/checkpoint_conversion/to_maxtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ def _eager_getter(key):
"--eager_load_method",
type=str,
required=False,
default="transformers",
default="safetensors",
choices=["transformers", "safetensors"],
help="Backend to use for eager loading: `transformers_class.from_pretrained` or `safetensors.safe_open` with pt",
)
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/checkpoint_conversion/utils/hf_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def QWEN_HF_WEIGHTS_TO_SHAPE(config):
}

# Determine if the model is MoE based on config keys
num_experts = config.get("num_experts", 0)
num_experts = config.get("num_experts", config.get("num_local_experts", 0))

for layer_idx in range(num_hidden_layers):
layer_prefix = f"model.layers.{layer_idx}"
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/checkpoint_conversion/utils/param_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
or scanned with expert stacking (nested list of strings).
"""
n_layers = config["num_hidden_layers"]
num_experts = config.get("num_experts", 0)
num_experts = config.get("num_experts", config.get("num_local_experts", 0))

mapping = {
"params-token_embedder-embedding": "model.embed_tokens.weight",
Expand Down Expand Up @@ -753,7 +753,7 @@ def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False,
transformation functions.
"""
n_layers = config["num_hidden_layers"]
num_experts = config.get("num_experts", 0)
num_experts = config.get("num_experts", config.get("num_local_experts", 0))

def pad_embedding_layer(input_tensor, target_shape):
"""Pads or truncates embedding layer to match target vocab size."""
Expand Down
Loading