[Model][Quantization] Fix / Add GGUF support for Qwen2 MoE models #30307
Purpose
This is a follow-up to #30116 (which unblocked loading Qwen2/3 MoE + GGUF models).
This commit fixes two issues that prevent Qwen2 MoE + GGUF models from loading:
- `Qwen2MoeModel.embed_tokens` is fixed to use the correct prefix and to respect quantization settings.
- `Qwen2MoeSparseMoeBlock.shared_expert_gate` has a shape mismatch between formats (GGUF: 1D tensor `(n)`, HF/vLLM: 2D tensor `(1, n)`), so the GGUF tensor is reshaped during loading.

The latter was handled in HF Transformers as part of `Qwen2MoeTensorProcessor`, but since vLLM's weight loader uses the `gguf` library directly without such a compatibility layer, we need to implement the equivalent ourselves (sketched below).
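A minimal, hypothetical sketch of the second fix follows (the helper name, call site, and tensor name are assumptions for illustration, not the actual vLLM diff); the first fix amounts to passing the module prefix and the quantization config through to the embedding layer so that GGUF metadata such as `qweight_type` is looked up under the right key.

```python
import torch


def _maybe_reshape_shared_expert_gate(name: str, loaded_weight: torch.Tensor,
                                      param_shape: torch.Size) -> torch.Tensor:
    """Hypothetical helper illustrating the shape fix, not the actual vLLM code.

    GGUF stores Qwen2 MoE's shared_expert_gate weight as a 1D tensor (n,),
    while the HF/vLLM parameter is 2D with shape (1, n). HF Transformers
    compensates for this in Qwen2MoeTensorProcessor; vLLM reads GGUF tensors
    directly, so the weight loader has to perform the equivalent reshape.
    """
    if name.endswith("shared_expert_gate.weight") and loaded_weight.dim() == 1:
        loaded_weight = loaded_weight.unsqueeze(0)  # (n,) -> (1, n)
    assert loaded_weight.shape == param_shape, (
        f"shape mismatch for {name}: {loaded_weight.shape} vs {param_shape}")
    return loaded_weight


# Example: a 1D GGUF tensor of size 2048 is reshaped to fit the (1, 2048) parameter.
w = _maybe_reshape_shared_expert_gate(
    "model.layers.0.mlp.shared_expert_gate.weight",
    torch.zeros(2048), torch.Size([1, 2048]))
assert w.shape == (1, 2048)
```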
Test Plan
You can download a quantized Qwen1.5-MoE-A2.7B-Chat model, such as tensorblock/Qwen1.5-MoE-A2.7B-Chat-GGUF, and run it, for example:
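For instance, using vLLM's offline Python API (a minimal sketch; the local quant filename, prompt, and sampling settings are assumptions — any equivalent invocation should work):

```python
from vllm import LLM, SamplingParams

# Assumed local path to a quant downloaded from
# tensorblock/Qwen1.5-MoE-A2.7B-Chat-GGUF; the exact filename is an assumption.
llm = LLM(
    model="./Qwen1.5-MoE-A2.7B-Chat-Q4_K_M.gguf",
    # GGUF files do not bundle the HF tokenizer, so point at the original repo.
    tokenizer="Qwen/Qwen1.5-MoE-A2.7B-Chat",
)

outputs = llm.generate(
    ["Explain mixture-of-experts in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```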
WARNING: Qwen2-57B-A14B-Instruct, which uses the same architecture, won't work with this PR alone, because a bug in HF Transformers prevents loading non-default parameters.
See #30116 for details (the bug is fixed in the upcoming v5 by huggingface/transformers#42650, but it is unclear whether v4, which vLLM currently depends on, will receive the same fix).
Test Result
Before this PR, one of the following errors occurs (each corresponding to one of the fixes above):
- `KeyError: 'embed_tokens.qweight_type'` (this is the error you should normally see), or
- `AssertionError: Attempted to load weight (torch.Size([2048])) into parameter (torch.Size([1, 2048]))` (2048 is the value for Qwen1.5-MoE-A2.7B{,-Chat}).

After this PR is applied, these errors go away and you should be able to use Qwen2 MoE + GGUF models.
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.