Conversation

@a4lg (Contributor) commented Dec 9, 2025

Purpose

This is a follow-up to #30116 (which unblocks loading Qwen2/3 MoE + GGUF models).

This commit fixes two issues that prevent Qwen2 MoE + GGUF models from loading:

  1. Initialization of Qwen2MoeModel.embed_tokens is fixed to use the correct prefix and respect quantization settings.
  2. Added a GGUF-specific compatibility layer for Qwen2MoeSparseMoeBlock.shared_expert_gate (GGUF stores a 1D tensor (n); HF/vLLM expects a 2D tensor (1, n)).

The latter was implemented in HF Transformers as part of Qwen2MoeTensorProcessor, but since vLLM's weight loader uses the gguf library directly, without such a compatibility layer, we need to implement the equivalent in vLLM.
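
For illustration, here is a minimal sketch of the kind of shim the second fix adds during weight loading. The function name, signature, and variable names are hypothetical and do not mirror vLLM's actual loader code:

import torch

def _maybe_expand_shared_expert_gate(
    name: str, loaded_weight: torch.Tensor, param: torch.nn.Parameter
) -> torch.Tensor:
    # GGUF stores shared_expert_gate.weight as a 1D tensor (n,), while the
    # HF/vLLM parameter is 2D with shape (1, n); add the missing leading dim.
    if name.endswith("shared_expert_gate.weight") and loaded_weight.dim() == 1:
        loaded_weight = loaded_weight.unsqueeze(0)  # (n,) -> (1, n)
    assert loaded_weight.shape == param.shape, (
        f"Attempted to load weight ({loaded_weight.shape}) "
        f"into parameter ({param.shape})"
    )
    return loaded_weight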

Test Plan

You may download a quantized Qwen1.5-MoE-A2.7B-Chat model such as tensorblock/Qwen1.5-MoE-A2.7B-Chat-GGUF and run it like this:

vllm serve \
	Qwen1.5-MoE-A2.7B-Chat-Q4_K_S.gguf \
	--tokenizer Qwen/Qwen1.5-MoE-A2.7B-Chat \
	--max-model-len 4096 \
	--gpu-memory-utilization 0.5

WARNING: Qwen2-57B-A14B-Instruct, which uses the same architecture, won't work with this fix alone because a bug in HF Transformers prevents loading its non-default parameters.
See #30116 for details (fixed in the upcoming V5 by huggingface/transformers#42650, but it is unclear whether V4, which vLLM currently depends on, will receive the same fix).

Test Result

Before this PR, one of the following errors occurs (each corresponds to a fix above):

  1. KeyError: 'embed_tokens.qweight_type' (this is the one you will normally see), or
  2. AssertionError: Attempted to load weight (torch.Size([2048])) into parameter (torch.Size([1, 2048]))
    (2048 is the value for Qwen1.5-MoE-A2.7B{,-Chat}).

After this PR is applied, these errors will go away and you should be able to use Qwen2 MoE + GGUF models.


This commit fixes two issues that prevent Qwen2 MoE + GGUF models from loading:

1.  Initialization of `Qwen2MoeModel.embed_tokens` is fixed to use the
    correct prefix and respect quantization settings.
2.  Added a GGUF-specific compatibility layer for
    `Qwen2MoeSparseMoeBlock.shared_expert_gate`
    (GGUF: 1D tensor `(n)`, HF/vLLM: 2D tensor `(1, n)`).

The latter was implemented in HF Transformers as part of
`Qwen2MoeTensorProcessor`, but since vLLM's weight loader uses the
`gguf` library directly, without such a compatibility layer, we need to
implement the equivalent in vLLM.

cf. <https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/modeling_gguf_pytorch_utils.py#L110-L113>

Beware that, due to a bug in HF Transformers, Qwen2-57B-A14B-Instruct
(with the same architecture but some non-default parameters) will not work
with this fix alone.

Signed-off-by: Tsukasa OI <floss_llm@irq.a4lg.com>
@a4lg a4lg requested a review from sighingnow as a code owner December 9, 2025 05:39
@mergify mergify bot added the qwen Related to Qwen models label Dec 9, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces fixes to enable GGUF support for Qwen2 MoE models. The changes include correctly initializing embed_tokens to respect quantization settings and adding a compatibility layer to handle a shape mismatch for the shared_expert_gate weight in GGUF models.

The changes are correct and address the described issues. I have one suggestion to make the weight loading logic for shared_expert_gate more robust by adding more explicit checks against the model's parameter shape, which will prevent potential issues with other formats in the future.
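
As a rough illustration of that suggestion (purely a sketch, not the exact diff proposed; the variable names are hypothetical), the reshape could be gated on the target parameter's shape rather than only on the source tensor's rank:

# Only expand the GGUF 1D tensor when the target parameter really is (1, n);
# other formats and shapes pass through untouched.
if (
    name.endswith("shared_expert_gate.weight")
    and loaded_weight.dim() == 1
    and param.dim() == 2
    and param.shape == (1, loaded_weight.shape[0])
):
    loaded_weight = loaded_weight.unsqueeze(0)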

@Isotr0py Isotr0py self-assigned this Dec 9, 2025
@Isotr0py Isotr0py enabled auto-merge (squash) December 9, 2025 15:54
@Isotr0py Isotr0py disabled auto-merge December 9, 2025 15:56
@Isotr0py Isotr0py enabled auto-merge (squash) December 9, 2025 16:02
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 9, 2025
@Isotr0py Isotr0py merged commit 73a484c into vllm-project:main Dec 9, 2025
56 of 57 checks passed