Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,188 @@ def __init__(self, **kwargs):
qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)


# Config for Qwen3.5-397B-A17B: a multimodal MoE model with a vision tower and a
# 60-layer hybrid-attention text decoder (3 linear-attention layers followed by
# 1 full-attention layer, repeating — see "full_attention_interval"/"layer_types").
qwen3_5_397b_a17b_dict = {
    "architectures": ["Qwen3_5MoeForConditionalGeneration"],
    "image_token_id": 248056,
    "model_type": "qwen3_5_moe",
    "text_config": {
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_output_gate": True,
        "dtype": "bfloat16",
        "eos_token_id": 248044,
        "full_attention_interval": 4,
        "head_dim": 256,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        # Hybrid attention schedule: every 4th layer (matching
        # full_attention_interval) is full softmax attention, the other three
        # are linear attention. 15 repeats of the 4-layer pattern = 60 layers.
        "layer_types": (["linear_attention"] * 3 + ["full_attention"]) * 15,
        "linear_conv_kernel_dim": 4,
        "linear_key_head_dim": 128,
        "linear_num_key_heads": 16,
        "linear_num_value_heads": 64,
        "linear_value_head_dim": 128,
        "max_position_embeddings": 262144,
        "mlp_only_layers": [],
        "model_type": "qwen3_5_moe_text",
        "moe_intermediate_size": 1024,
        "mtp_num_hidden_layers": 1,
        "mtp_use_dedicated_embeddings": False,
        "num_attention_heads": 32,
        "num_experts": 512,
        "num_experts_per_tok": 10,
        "num_hidden_layers": 60,
        "num_key_value_heads": 2,
        "rms_norm_eps": 1e-06,
        "router_aux_loss_coef": 0.001,
        "shared_expert_intermediate_size": 1024,
        "use_cache": True,
        "vocab_size": 248320,
        "mamba_ssm_dtype": "float32",
        "rope_parameters": {
            "mrope_interleaved": True,
            "mrope_section": [11, 11, 10],
            "rope_type": "default",
            "rope_theta": 10000000,
            "partial_rotary_factor": 0.25,
        },
    },
    "tie_word_embeddings": False,
    "transformers_version": "4.57.0.dev0",
    "video_token_id": 248057,
    "vision_config": {
        "deepstack_visual_indexes": [],
        "depth": 27,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "in_channels": 3,
        "initializer_range": 0.02,
        "intermediate_size": 4304,
        "model_type": "qwen3_5_moe",
        "num_heads": 16,
        "num_position_embeddings": 2304,
        "out_hidden_size": 4096,
        "patch_size": 16,
        "spatial_merge_size": 2,
        "temporal_patch_size": 2,
    },
    "vision_end_token_id": 248054,
    "vision_start_token_id": 248053,
}
qwen3_5_397b_a17b_config = transformers.Qwen3_5MoeConfig(**qwen3_5_397b_a17b_dict)


# Config for Qwen3.5-35B-A3B: the smaller sibling of the 397B config above —
# same multimodal MoE layout with a 40-layer hybrid-attention text decoder
# (3 linear-attention layers then 1 full-attention layer, repeating).
qwen3_5_35b_a3b_dict = {
    "architectures": ["Qwen3_5MoeForConditionalGeneration"],
    "image_token_id": 248056,
    "model_type": "qwen3_5_moe",
    "text_config": {
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_output_gate": True,
        "dtype": "bfloat16",
        "eos_token_id": 248044,
        "full_attention_interval": 4,
        "head_dim": 256,
        "hidden_act": "silu",
        "hidden_size": 2048,
        "initializer_range": 0.02,
        # Hybrid attention schedule: every 4th layer (matching
        # full_attention_interval) is full softmax attention, the other three
        # are linear attention. 10 repeats of the 4-layer pattern = 40 layers.
        "layer_types": (["linear_attention"] * 3 + ["full_attention"]) * 10,
        "linear_conv_kernel_dim": 4,
        "linear_key_head_dim": 128,
        "linear_num_key_heads": 16,
        "linear_num_value_heads": 32,
        "linear_value_head_dim": 128,
        "max_position_embeddings": 262144,
        "mlp_only_layers": [],
        "model_type": "qwen3_5_moe_text",
        "moe_intermediate_size": 512,
        "mtp_num_hidden_layers": 1,
        "mtp_use_dedicated_embeddings": False,
        "num_attention_heads": 16,
        "num_experts": 256,
        "num_experts_per_tok": 8,
        "num_hidden_layers": 40,
        "num_key_value_heads": 2,
        "rms_norm_eps": 1e-06,
        "router_aux_loss_coef": 0.001,
        "shared_expert_intermediate_size": 512,
        "use_cache": True,
        "vocab_size": 248320,
        "mamba_ssm_dtype": "float32",
        "rope_parameters": {
            "mrope_interleaved": True,
            "mrope_section": [11, 11, 10],
            "rope_type": "default",
            "rope_theta": 10000000,
            "partial_rotary_factor": 0.25,
        },
    },
    "tie_word_embeddings": False,
    "transformers_version": "4.57.0.dev0",
    "video_token_id": 248057,
    "vision_config": {
        "deepstack_visual_indexes": [],
        "depth": 27,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "in_channels": 3,
        "initializer_range": 0.02,
        "intermediate_size": 4304,
        "model_type": "qwen3_5_moe",
        "num_heads": 16,
        "num_position_embeddings": 2304,
        "out_hidden_size": 2048,
        "patch_size": 16,
        "spatial_merge_size": 2,
        "temporal_patch_size": 2,
    },
    "vision_end_token_id": 248054,
    "vision_start_token_id": 248053,
}

qwen3_5_35b_a3b_config = transformers.Qwen3_5MoeConfig(**qwen3_5_35b_a3b_dict)


# from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
mixtral_8x7b_dict = {
"architectures": ["MixtralForCausalLM"],
Expand Down Expand Up @@ -1214,6 +1396,8 @@ def __init__(self, **kwargs):
"gpt-oss-120b": gpt_oss_120b_config,
"qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
"qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
"qwen3.5-397b-a17b": qwen3_5_397b_a17b_config,
"qwen3.5-35b-a3b": qwen3_5_35b_a3b_config,
"mixtral-8x7b": mixtral_8x7b_config,
"mixtral-8x22b": mixtral_8x22b_config,
"olmo3-7b": olmo3_7b_config,
Expand Down
96 changes: 96 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,102 @@ def QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config):
)


def QWEN3_5_HF_WEIGHTS_TO_SHAPE(config):
  """Returns mapping between HuggingFace Qwen3.5 weights path and their shape.

  Args:
    config: the HF text-config dict (e.g. ``qwen3_5_*_dict["text_config"]``).

  Returns:
    Dict mapping each HF parameter path to its shape as a list of ints.
  """
  # --- Core dimensions from the config ---
  d_model = config["hidden_size"]
  n_layers = config["num_hidden_layers"]
  n_vocab = config["vocab_size"]
  n_q_heads = config["num_attention_heads"]
  n_kv_heads = config["num_key_value_heads"]
  n_experts = config["num_experts"]
  d_head = config["head_dim"]
  conv_kernel = config["linear_conv_kernel_dim"]
  gdn_k_head = config["linear_key_head_dim"]
  gdn_v_head = config["linear_value_head_dim"]
  gdn_k_heads = config["linear_num_key_heads"]
  gdn_v_heads = config["linear_num_value_heads"]
  d_moe_ffn = config["moe_intermediate_size"]
  d_shared_ffn = config["shared_expert_intermediate_size"]
  interval = config["full_attention_interval"]

  # --- Derived projection widths ---
  q_width = n_q_heads * d_head
  kv_width = n_kv_heads * d_head
  gdn_k_width = gdn_k_heads * gdn_k_head
  gdn_v_width = gdn_v_heads * gdn_v_head
  # Fused q/k/v width fed through the depthwise conv (two key blocks + values).
  conv_width = 2 * gdn_k_width + gdn_v_width

  # Embedding, final norm, and (untied) output head.
  shapes = {
      "model.language_model.embed_tokens.weight": [n_vocab, d_model],
      "model.language_model.norm.weight": [d_model],
      "lm_head.weight": [n_vocab, d_model],
  }

  for idx in range(n_layers):
    prefix = f"model.language_model.layers.{idx}"

    # Pre- and post-attention RMSNorm weights, present on every layer.
    shapes[f"{prefix}.input_layernorm.weight"] = [d_model]
    shapes[f"{prefix}.post_attention_layernorm.weight"] = [d_model]

    if (idx + 1) % interval == 0:
      # Every `interval`-th layer uses full softmax attention.
      shapes[f"{prefix}.self_attn.q_proj.weight"] = [q_width, d_model]
      shapes[f"{prefix}.self_attn.k_proj.weight"] = [kv_width, d_model]
      shapes[f"{prefix}.self_attn.v_proj.weight"] = [kv_width, d_model]
      shapes[f"{prefix}.self_attn.o_proj.weight"] = [d_model, q_width]
      shapes[f"{prefix}.self_attn.q_norm.weight"] = [d_head]
      shapes[f"{prefix}.self_attn.k_norm.weight"] = [d_head]
    else:
      # Remaining layers use linear attention (GDN) with unfused projections.
      shapes[f"{prefix}.linear_attn.in_proj_qkv.weight"] = [conv_width, d_model]
      shapes[f"{prefix}.linear_attn.in_proj_z.weight"] = [gdn_v_width, d_model]
      shapes[f"{prefix}.linear_attn.in_proj_b.weight"] = [gdn_v_heads, d_model]
      shapes[f"{prefix}.linear_attn.in_proj_a.weight"] = [gdn_v_heads, d_model]
      shapes[f"{prefix}.linear_attn.conv1d.weight"] = [conv_width, 1, conv_kernel]
      shapes[f"{prefix}.linear_attn.A_log"] = [gdn_v_heads]
      shapes[f"{prefix}.linear_attn.dt_bias"] = [gdn_v_heads]
      shapes[f"{prefix}.linear_attn.norm.weight"] = [gdn_v_head]
      shapes[f"{prefix}.linear_attn.out_proj.weight"] = [d_model, gdn_v_width]

    # MoE MLP: router, shared expert (separate SwiGLU weights) plus its
    # learned scalar gate, and the fused routed-expert tensors. Note the fused
    # expert parameters carry no ".weight" suffix in the HF checkpoint.
    shapes[f"{prefix}.mlp.gate.weight"] = [n_experts, d_model]
    shapes[f"{prefix}.mlp.shared_expert.gate_proj.weight"] = [d_shared_ffn, d_model]
    shapes[f"{prefix}.mlp.shared_expert.up_proj.weight"] = [d_shared_ffn, d_model]
    shapes[f"{prefix}.mlp.shared_expert.down_proj.weight"] = [d_model, d_shared_ffn]
    shapes[f"{prefix}.mlp.shared_expert_gate.weight"] = [1, d_model]
    shapes[f"{prefix}.mlp.experts.gate_up_proj"] = [n_experts, 2 * d_moe_ffn, d_model]
    shapes[f"{prefix}.mlp.experts.down_proj"] = [n_experts, d_model, d_moe_ffn]

  return shapes


def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
"""Returns mapping between HuggingFace GptOss weights path and their shape."""
# --- Extract Core Config Values ---
Expand Down
Loading
Loading