Add support for Qwen3.5 dense models #1123
Conversation
Made-with: Cursor
| f"fused_linear_cross_entropy is only applicable on Qwen3_5ForCausalLM. Got: {type(model)}" | ||
| ) | ||
| else: | ||
| modeling_qwen3_5.Qwen3_5ForCausalLM.forward = qwen3_5_lce_forward |
Since the architecture name has been changed to Qwen3_5ForConditionalGeneration (see config.json), should it also be patched?
Consider `Qwen3_5ForConditionalGeneration`:
```python
def apply_liger_kernel_to_qwen3_5(
    rope: bool = False,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace the original implementation in HuggingFace Qwen3.5 dense models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
            Not yet supported for Qwen3.5 due to hybrid attention (Gated DeltaNet + Gated Attention).
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits are never materialized,
            which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model
            has already been loaded. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from functools import partial
    from types import MethodType

    from transformers.models.qwen3_5 import modeling_qwen3_5
    from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
    from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5TextModel

    # The multimodal class only exists in newer transformers releases,
    # so fall back gracefully when it is absent.
    try:
        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration
    except ImportError:
        Qwen3_5ForConditionalGeneration = None

    from liger_kernel.transformers.model.qwen3_5 import lce_forward as qwen3_5_lce_forward
    from liger_kernel.transformers.monkey_patch import _patch_rms_norm_module, _patch_swiglu_module
    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

    if rope:
        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3_5 models.")
    if rms_norm:
        modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3Next
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        from liger_kernel.transformers.cross_entropy import liger_cross_entropy

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        valid_classes = (Qwen3_5ForCausalLM,)
        if Qwen3_5ForConditionalGeneration is not None:
            valid_classes += (Qwen3_5ForConditionalGeneration,)
        if model is not None:
            # Instance-level patch: bind the LCE forward to the loaded model.
            if isinstance(model, valid_classes):
                model.forward = MethodType(qwen3_5_lce_forward, model)
            else:
                raise TypeError(
                    f"fused_linear_cross_entropy is only applicable on Qwen3_5ForCausalLM "
                    f"or Qwen3_5ForConditionalGeneration. Got: {type(model)}"
                )
        else:
            # Class-level patch: affects models instantiated after this call.
            modeling_qwen3_5.Qwen3_5ForCausalLM.forward = qwen3_5_lce_forward
            if Qwen3_5ForConditionalGeneration is not None:
                modeling_qwen3_5.Qwen3_5ForConditionalGeneration.forward = qwen3_5_lce_forward
    if swiglu:
        modeling_qwen3_5.Qwen3_5MLP = LigerQwen3MoeSwiGLUMLP

    if model is not None:
        # The class-level swaps above only affect future instantiations, so
        # patch the already-built submodules of this instance in place too.
        if isinstance(model, (Qwen3_5ForCausalLM, Qwen3_5TextModel)):
            base_model: Qwen3_5TextModel = getattr(model, model.base_model_prefix, model)
        elif Qwen3_5ForConditionalGeneration is not None and isinstance(model, Qwen3_5ForConditionalGeneration):
            base_model = model.model.language_model
        else:
            raise TypeError(f"Unsupported qwen3_5 model type. Got: {type(model)}")

        _patch_rms_norm_module_for_qwen3_5 = partial(
            _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
        )
        if rms_norm:
            _patch_rms_norm_module_for_qwen3_5(base_model.norm)
        for decoder_layer in base_model.layers:
            if rms_norm:
                _patch_rms_norm_module_for_qwen3_5(decoder_layer.input_layernorm)
                _patch_rms_norm_module_for_qwen3_5(decoder_layer.post_attention_layernorm)
            if swiglu:
                _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
```
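For illustration, instance-level patching of an already-loaded model might look like the following; the checkpoint id is a placeholder, not a published model:

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5

# Placeholder checkpoint id, for illustration only.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)

# Instance-level patching: swaps forward, RMSNorm, and SwiGLU modules in place.
apply_liger_kernel_to_qwen3_5(model=model)
```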
`Qwen3_5ForCausalLM`'s and `Qwen3_5ForConditionalGeneration`'s forward methods are slightly different, so supporting both would require an extra `lce_forward` for multimodal.
`Qwen3_5ForCausalLM`:

```python
outputs: BaseModelOutputWithPast = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    **kwargs,
)
```

`Qwen3_5ForConditionalGeneration`:

```python
outputs = self.model(
    input_ids=input_ids,
    pixel_values=pixel_values,
    pixel_values_videos=pixel_values_videos,
    image_grid_thw=image_grid_thw,
    video_grid_thw=video_grid_thw,
    position_ids=position_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    mm_token_type_ids=mm_token_type_ids,
    **kwargs,
)
```
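A minimal sketch of what that extra multimodal forward could look like, assuming it reuses `LigerFusedLinearCrossEntropyLoss` and passes the vision inputs straight through to the backbone; the signature and return handling here are illustrative, not the actual follow-up implementation:

```python
from transformers.modeling_outputs import CausalLMOutputWithPast

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


def lce_forward_for_multimodal(self, input_ids=None, pixel_values=None, labels=None, **kwargs):
    """Hypothetical multimodal LCE forward; signature is illustrative only."""
    # Run the multimodal backbone exactly as the stock HF forward does,
    # forwarding the extra vision inputs untouched.
    outputs = self.model(input_ids=input_ids, pixel_values=pixel_values, **kwargs)
    hidden_states = outputs.last_hidden_state

    loss = None
    if labels is not None:
        # Shift so tokens < n predict token n, then fuse the lm_head matmul
        # with the cross entropy so the full logits are never materialized.
        shift_hidden = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(
            self.lm_head.weight,
            shift_hidden.view(-1, shift_hidden.size(-1)),
            shift_labels.view(-1),
        )

    # The real implementation would return logits=None when labels are given.
    return CausalLMOutputWithPast(loss=loss, past_key_values=outputs.past_key_values)
```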
Hi! I've been using this PR to train Qwen3.5-9B with fused cross-entropy on 8× A100 40GB and found two issues that may affect other users:

1. …
2. …

I have submitted a fix to the fork: ruilin-gif#1. My workaround: install official 0.7.0, overlay … Thanks for adding Qwen3.5 support — the fused cross-entropy is critical for training at 16384+ token sequences!
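To put a rough number on that claim: assuming Qwen's ~152k vocabulary carries over to Qwen3.5 (an assumption), materializing the logits for a single 16384-token sequence in bf16 costs about 16384 × 151936 × 2 bytes ≈ 5 GB, before the float32 upcast the softmax typically requires. On a 40 GB A100 that is a large fraction of the budget, which is why the fused linear cross entropy, which computes the loss in chunks without ever storing the full logits tensor, matters at these sequence lengths.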
Will the vision model be supported, like `apply_liger_kernel_to_qwen2_5_vl`? This PR only covers the dense text models.
Tcc0403 left a comment:
Thanks for the contribution, LGTM! Would you like to support multimodal patch as a follow-up?
| f"fused_linear_cross_entropy is only applicable on Qwen3_5ForCausalLM. Got: {type(model)}" | ||
| ) | ||
| else: | ||
| modeling_qwen3_5.Qwen3_5ForCausalLM.forward = qwen3_5_lce_forward |
There was a problem hiding this comment.
Qwen3_5ForCausalLM and Qwen3_5ForConditionalGeneration's forward methods are slightly different, it would require an extra lce_forward for multimodal.
Qwen3_5ForCausalLM
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
**kwargs,
)
Qwen3_5ForConditionalGeneration
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
mm_token_type_ids=mm_token_type_ids,
**kwargs,
)
@MichaelWalker-git Try the nightly build if you are in a rush.
Add support for Qwen3.5 multimodal. (#1150)

## Summary

Add support for Qwen3.5 multimodal patches as a follow-up for #1123. Implement an extra `lce_forward` for `Qwen3_5ForConditionalGeneration` and add corresponding tests.

### Details

Modified files:
- `model/qwen3_5.py`: add `lce_forward_for_multimodal` for `Qwen3_5ForConditionalGeneration`.
- `model/output_classes.py`: add `LigerQwen3_5CausalLMOutputWithPast` as the return type for qwen3_5.
- `test/utils.py`: extend `revert_liger_kernel_to_qwen3_5` to support `conditional_generation` for qwen3_5.
- `test_monkey_patch.py`: add `test_apply_liger_kernel_to_instance_for_qwen3_5_for_conditional_generation`.
- `bf16/test_mini_models_multimodal.py`: add `MINI_MODEL_SETUPS["mini_qwen3_5"]` to the convergence test.
- `fp32/test_mini_models_multimodal.py`: skipped.

## Testing Done

- Hardware Type: NVIDIA A100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
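A rough idea of how small the new output class could be; the `rope_deltas` field is an assumption modeled on Qwen2.5-VL's `Qwen2_5_VLCausalLMOutputWithPast`, not taken from the follow-up PR:

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast


@dataclass
class LigerQwen3_5CausalLMOutputWithPast(CausalLMOutputWithPast):
    # Assumption: multimodal Qwen models track rope deltas for M-RoPE,
    # mirroring Qwen2_5_VLCausalLMOutputWithPast.
    rope_deltas: Optional[torch.Tensor] = None
```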
Summary
Adds Liger Kernel support for the Qwen3.5 dense model family (0.8B / 2B / 4B / 9B / 27B).
Closes #1119
Qwen3.5 uses the same hybrid GDN + full-attention architecture as Qwen3 Next but is a distinct model type (`qwen3_5`) in transformers v5.2+. This PR adds a dedicated `apply_liger_kernel_to_qwen3_5` entry point so auto-patching works correctly.

Patched kernels:
- RMSNorm (`LigerRMSNormForQwen3Next`, `offset=1.0`, `casting_mode="gemma"`)
- SwiGLU (reuses `LigerQwen3MoeSwiGLUMLP`, matching Qwen3.5's `__init__` signature)
- Fused linear cross entropy (default on)
- RoPE not supported (raises `NotImplementedError`); only 1 in 4 layers uses RoPE via full attention

Details
New file:
- `src/liger_kernel/transformers/model/qwen3_5.py`: `lce_forward` for `Qwen3_5ForCausalLM`

Modified files:
- `monkey_patch.py`: `apply_liger_kernel_to_qwen3_5` with class-level and instance-level patching; registered as `qwen3_5` in `MODEL_TYPE_TO_APPLY_LIGER_FN`
- `__init__.py`: export the new entry point
- `test/utils.py`: `revert_liger_kernel_to_qwen3_5` (see the sketch after this list)
- `test_monkey_patch.py`: instance patching test
- `bf16/test_mini_models.py`: convergence test
- `bf16/test_mini_models_with_logits.py`: convergence with logits
- `fp32/test_mini_models.py`: convergence (skipped; `ChunkGatedDeltaRuleFunction` doesn't support float32)
- `fp32/test_mini_models_with_logits.py`: convergence with logits (skipped, same reason)
- `README.md`: added Qwen3.5 to the supported models table

Validation on real models (Qwen3.5-0.8B through 27B): …
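The revert helper itself is not shown in this thread; a minimal sketch under the assumption that it follows the reload pattern other `revert_liger_kernel_to_*` test utilities use (the signature is also an assumption):

```python
import importlib


def revert_liger_kernel_to_qwen3_5(model_config=None):
    # Reloading the HF modeling module restores Qwen3_5RMSNorm, Qwen3_5MLP,
    # and the original forward methods replaced by the class-level patch.
    from transformers.models.qwen3_5 import modeling_qwen3_5

    importlib.reload(modeling_qwen3_5)
```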
Testing Done
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
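Because the entry point is registered in `MODEL_TYPE_TO_APPLY_LIGER_FN` under `qwen3_5`, auto-patching should also work through the generic loader. A short usage sketch (the checkpoint id is a placeholder):

```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# from_pretrained reads model_type ("qwen3_5") from the config and applies
# the matching patch function before returning the model.
model = AutoLigerKernelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B")
```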