
Add support for Qwen3.5 dense models#1123

Merged
Tcc0403 merged 1 commit intolinkedin:mainfrom
ruilin-gif:add-qwen3.5-dense-support
Mar 12, 2026

Conversation

@ruilin-gif
Contributor

Summary

Adds Liger Kernel support for the Qwen3.5 dense model family (0.8B / 2B / 4B / 9B / 27B).

Closes #1119

Qwen3.5 uses the same hybrid GDN + full-attention architecture as Qwen3 Next but is a distinct model type (qwen3_5) in transformers v5.2+. This PR adds a dedicated apply_liger_kernel_to_qwen3_5 entry point so auto-patching works correctly.
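The registry-based auto-patching mentioned above can be sketched with a toy dispatch table; `MODEL_TYPE_TO_APPLY_LIGER_FN` is the real registry name, but the patch functions below are illustrative stand-ins, not the actual implementations.

```python
# Minimal sketch of model-type-based auto-patching, assuming a registry
# dict shaped like liger_kernel's MODEL_TYPE_TO_APPLY_LIGER_FN.
# The patch functions are stubs for illustration only.

def apply_liger_kernel_to_qwen3_5(**kwargs):
    return "patched qwen3_5"

def apply_liger_kernel_to_qwen3_next(**kwargs):
    return "patched qwen3_next"

MODEL_TYPE_TO_APPLY_LIGER_FN = {
    "qwen3_5": apply_liger_kernel_to_qwen3_5,
    "qwen3_next": apply_liger_kernel_to_qwen3_next,
}

def auto_patch(model_type: str, **kwargs):
    # Without a dedicated "qwen3_5" entry, a Qwen3.5 checkpoint would fall
    # through unpatched even though its architecture matches Qwen3 Next.
    fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if fn is None:
        raise ValueError(f"No Liger patch registered for model type {model_type!r}")
    return fn(**kwargs)

print(auto_patch("qwen3_5"))  # -> patched qwen3_5
```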

Patched kernels:

  • RMSNorm — Gemma-style offset RMSNorm (LigerRMSNormForQwen3Next, offset=1.0, casting_mode="gemma")
  • SwiGLU — Fused SwiGLU MLP (LigerQwen3MoeSwiGLUMLP, matching Qwen3.5's __init__ signature)
  • Fused Linear Cross Entropy — avoids materializing the full logits tensor (vocab_size=248,320)
  • RoPE — disabled (raises NotImplementedError); only 1 in 4 layers uses RoPE via full attention
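To see why the fused linear cross entropy matters at this vocabulary size, a back-of-the-envelope estimate of the materialized logits tensor; the batch and sequence length below are illustrative, only `vocab_size=248,320` comes from the PR.

```python
# Rough memory estimate for a fully materialized bf16 logits tensor.
# Batch/sequence shapes are illustrative; vocab_size is Qwen3.5's.
vocab_size = 248_320
batch, seq_len = 8, 4096
bytes_per_elem = 2  # bf16

logits_bytes = batch * seq_len * vocab_size * bytes_per_elem
print(f"{logits_bytes / 2**30:.1f} GiB")  # -> 15.2 GiB for logits alone
```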

Details

New file:

  • src/liger_kernel/transformers/model/qwen3_5.py — lce_forward for Qwen3_5ForCausalLM

Modified files:

  • monkey_patch.py — apply_liger_kernel_to_qwen3_5 with class-level and instance-level patching; registered as qwen3_5 in MODEL_TYPE_TO_APPLY_LIGER_FN
  • __init__.py — export
  • test/utils.py — revert_liger_kernel_to_qwen3_5
  • test_monkey_patch.py — instance patching test
  • bf16/test_mini_models.py — convergence test
  • bf16/test_mini_models_with_logits.py — convergence with logits
  • fp32/test_mini_models.py — convergence (skipped; ChunkGatedDeltaRuleFunction doesn't support float32)
  • fp32/test_mini_models_with_logits.py — convergence with logits (skipped, same reason)
  • README.md — added Qwen3.5 to supported models table

Validation on real models (Qwen3.5-0.8B through 27B):

| Model | Logits cosine sim | Training loss rel diff |
| ----- | ----------------- | ---------------------- |
| 0.8B  | > 0.999           | 0.32%                  |
| 2B    | > 0.999           | 0.20%                  |
| 4B    | > 0.999           | 0.28%                  |
| 9B    | > 0.999           | 0.39%                  |
| 27B   | > 0.999           | 0.41%                  |
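The logits cosine similarity reported above can be computed as a flattened dot product between reference and patched outputs; this is a generic sketch with toy values, not the PR's actual validation script.

```python
import math

def cosine_similarity(a, b):
    # Flattened cosine similarity between two logits tensors,
    # represented here as plain Python lists for illustration.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

ref = [0.12, -1.3, 2.4, 0.07]        # reference (eager) logits, toy values
liger = [0.121, -1.29, 2.41, 0.071]  # Liger-patched logits, toy values
print(cosine_similarity(ref, liger))  # close to 1.0 when outputs agree
```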

Testing Done

  • Hardware Type: NVIDIA H200
  • run `make test` to ensure correctness
  • run `make checkstyle` to ensure code style
  • run `make test-convergence` to ensure convergence
test_apply_liger_kernel_to_instance_for_qwen3_5         PASSED
test_mini_model[mini_qwen3_5 bf16]                      PASSED  (6.47s)
test_mini_model[mini_qwen3_5 bf16 with_logits]          PASSED  (7.23s)
test_mini_model[mini_qwen3_5 fp32]                      SKIPPED (ChunkGatedDeltaRuleFunction no fp32)
test_mini_model[mini_qwen3_5 fp32 with_logits]          SKIPPED (same)

            f"fused_linear_cross_entropy is only applicable on Qwen3_5ForCausalLM. Got: {type(model)}"
        )
    else:
        modeling_qwen3_5.Qwen3_5ForCausalLM.forward = qwen3_5_lce_forward
Contributor

Since the architecture name has been changed to Qwen3_5ForConditionalGeneration (see config.json), should it also be patched?


consider Qwen3_5ForConditionalGeneration

from functools import partial
from types import MethodType

from transformers import PreTrainedModel

def apply_liger_kernel_to_qwen3_5(
    rope: bool = False,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 dense models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
            Not yet supported for Qwen3.5 due to hybrid attention (Gated DeltaNet + Gated Attention).
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen3_5 import modeling_qwen3_5
    from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
    from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5TextModel
    
    try:
        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration
    except ImportError:
        Qwen3_5ForConditionalGeneration = None

    from liger_kernel.transformers.model.qwen3_5 import lce_forward as qwen3_5_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
    from liger_kernel.transformers.monkey_patch import _patch_rms_norm_module, _patch_swiglu_module

    if rope:
        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3_5 models.")
        
    if rms_norm:
        modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3Next
        
    if cross_entropy:
        from transformers.loss.loss_utils import nn
        from liger_kernel.transformers.cross_entropy import liger_cross_entropy
        nn.functional.cross_entropy = liger_cross_entropy
        
    if fused_linear_cross_entropy:
        valid_classes = (Qwen3_5ForCausalLM,)
        if Qwen3_5ForConditionalGeneration is not None:
            valid_classes += (Qwen3_5ForConditionalGeneration,)
            
        if model is not None:
            if isinstance(model, valid_classes):
                model.forward = MethodType(qwen3_5_lce_forward, model)
            else:
                raise TypeError(
                    f"fused_linear_cross_entropy is only applicable on Qwen3_5ForCausalLM or Qwen3_5ForConditionalGeneration. Got: {type(model)}"
                )
        else:
            modeling_qwen3_5.Qwen3_5ForCausalLM.forward = qwen3_5_lce_forward
            if Qwen3_5ForConditionalGeneration is not None:
                modeling_qwen3_5.Qwen3_5ForConditionalGeneration.forward = qwen3_5_lce_forward
                
    if swiglu:
        modeling_qwen3_5.Qwen3_5MLP = LigerQwen3MoeSwiGLUMLP

    if model is not None:
        if isinstance(model, (Qwen3_5ForCausalLM, Qwen3_5TextModel)):
            base_model: Qwen3_5TextModel = getattr(model, model.base_model_prefix, model)
        elif Qwen3_5ForConditionalGeneration is not None and isinstance(model, Qwen3_5ForConditionalGeneration):
            base_model = model.model.language_model
        else:
            raise TypeError(
                f"Unsupported qwen3_5 model type. Got: {type(model)}"
            )

        _patch_rms_norm_module_for_qwen3_5 = partial(
            _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
        )

        if rms_norm:
            _patch_rms_norm_module_for_qwen3_5(base_model.norm)

        for decoder_layer in base_model.layers:
            if rms_norm:
                _patch_rms_norm_module_for_qwen3_5(decoder_layer.input_layernorm)
                _patch_rms_norm_module_for_qwen3_5(decoder_layer.post_attention_layernorm)

            if swiglu:
                _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
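The suggestion above patches `forward` at both the class level and the instance level (via `MethodType`); the difference between the two can be shown with toy classes (illustrative names, not the real model classes).

```python
from types import MethodType

class ToyModel:
    def forward(self):
        return "original"

def lce_forward(self):
    return "liger"

# Instance-level patch: only this object is affected, mirroring the
# `model is not None` branch's MethodType(qwen3_5_lce_forward, model).
m = ToyModel()
m.forward = MethodType(lce_forward, m)

# Class-level patch: affects every instance, mirroring
# `modeling_qwen3_5.Qwen3_5ForCausalLM.forward = qwen3_5_lce_forward`.
class ToyModel2:
    def forward(self):
        return "original"

ToyModel2.forward = lce_forward

print(m.forward())            # -> liger
print(ToyModel().forward())   # -> original (class untouched)
print(ToyModel2().forward())  # -> liger
```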

Collaborator


Qwen3_5ForCausalLM's and Qwen3_5ForConditionalGeneration's forward methods are slightly different; supporting both would require an extra lce_forward for the multimodal model.

Qwen3_5ForCausalLM

https://github.com/huggingface/transformers/blob/adc2f16bf1824f7b57c790b4cf3bc48f95ecec69/src/transformers/models/qwen3_5/modeling_qwen3_5.py#L1823-L1831

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

Qwen3_5ForConditionalGeneration

https://github.com/huggingface/transformers/blob/adc2f16bf1824f7b57c790b4cf3bc48f95ecec69/src/transformers/models/qwen3_5/modeling_qwen3_5.py#L2000-L2012

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            mm_token_type_ids=mm_token_type_ids,
            **kwargs,
        )

@MichaelWalker-git

MichaelWalker-git commented Mar 11, 2026

Hi! I've been using this PR to train Qwen3.5-9B with fused cross-entropy on 8× A100 40GB and found two issues that may affect other users:

1. pip install git+... produces an empty wheel (5.7KB vs 276KB)

setup.py line 118 has packages=["liger_kernel"] which only includes the top-level package. All subpackages are excluded. This appears to be a pre-existing issue in the upstream setup.py as well. Fix: find_packages(where="src").
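The suggested fix swaps the hard-coded package list for setuptools discovery; a sketch of the relevant setup.py fragment (the surrounding arguments here are illustrative, only `find_packages(where="src")` comes from the comment above).

```python
# setup.py fragment (illustrative): find_packages(where="src") discovers
# liger_kernel and all of its subpackages, whereas packages=["liger_kernel"]
# ships only the top-level package and yields a near-empty wheel.
from setuptools import find_packages, setup

setup(
    name="liger_kernel",
    package_dir={"": "src"},
    packages=find_packages(where="src"),
)
```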

2. qwen3_5.py is incompatible with released liger-kernel 0.7.0

unpack_cross_entropy_result() returns 4 values in this branch (with predicted_tokens) but 3 in official 0.7.0. Users who install 0.7.0 from PyPI and overlay the Qwen3.5-specific files get ValueError: not enough values to unpack (expected 4, got 3). Similarly, LigerCausalLMOutputWithPast in 0.7.0 lacks the predicted_tokens field.

I have submitted a fix to the fork: ruilin-gif#1

My workaround: install official 0.7.0, overlay monkey_patch.py + __init__.py from this branch, and use a patched qwen3_5.py with *rest unpacking to handle both 3-value and 4-value API gracefully.
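The `*rest` unpacking trick generalizes to any function whose return arity grew between versions; a minimal sketch with toy return values (the names below are stand-ins, not the real kernel outputs).

```python
# Toy stand-ins for the two API versions of unpack_cross_entropy_result:
# released 0.7.0 returns 3 values, this branch returns 4 (adds predicted_tokens).
def unpack_v070():
    return 1.23, 0.0, None  # loss, z_loss, aux (illustrative names)

def unpack_branch():
    return 1.23, 0.0, None, [42]  # ..., predicted_tokens

def handle(result):
    # Starred unpacking tolerates both arities: `rest` is empty on 0.7.0
    # and holds predicted_tokens on the newer API.
    loss, z_loss, aux, *rest = result
    predicted_tokens = rest[0] if rest else None
    return loss, predicted_tokens

print(handle(unpack_v070()))    # -> (1.23, None)
print(handle(unpack_branch()))  # -> (1.23, [42])
```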

Thanks for adding Qwen3.5 support — the fused cross-entropy is critical for training at 16384+ token sequences!

@nideyongbao

nideyongbao commented Mar 12, 2026

Will the vision model be supported, like apply_liger_kernel_to_qwen2_5_vl?

dense:

(visual): Qwen3_5VisionModel(
        (patch_embed): Qwen3_5VisionPatchEmbed(
          (proj): Conv3d(3, 1152, kernel_size=(2, 16, 16), stride=(2, 16, 16))
        )
        (pos_embed): Embedding(2304, 1152)
        (rotary_pos_emb): Qwen3_5VisionRotaryEmbedding()
        (blocks): ModuleList(
          (0-26): 27 x Qwen3_5VisionBlock(
            (norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (attn): Qwen3_5VisionAttention(
              (qkv): Linear(in_features=1152, out_features=3456, bias=True)
              (proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (mlp): Qwen3_5VisionMLP(
              (linear_fc1): Linear(in_features=1152, out_features=4304, bias=True)
              (linear_fc2): Linear(in_features=4304, out_features=1152, bias=True)
              (act_fn): GELUTanh()
            )
          )
        )
        (merger): Qwen3_5VisionPatchMerger(
          (norm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          (linear_fc1): Linear(in_features=4608, out_features=4608, bias=True)
          (act_fn): GELU(approximate='none')
          (linear_fc2): Linear(in_features=4608, out_features=4096, bias=True)
        )
      )

Collaborator

@Tcc0403 Tcc0403 left a comment


Thanks for the contribution, LGTM! Would you like to support multimodal patch as a follow-up?


@Tcc0403
Collaborator

Tcc0403 commented Mar 12, 2026

  1. qwen3_5.py is incompatible with released liger-kernel 0.7.0

@MichaelWalker-git Try nightly build if you are in a rush

$ pip install liger-kernel-nightly

@Tcc0403 Tcc0403 enabled auto-merge March 12, 2026 22:10
@Tcc0403 Tcc0403 added this pull request to the merge queue Mar 12, 2026
Merged via the queue into linkedin:main with commit a01a002 Mar 12, 2026
5 of 7 checks passed
github-merge-queue Bot pushed a commit that referenced this pull request Mar 23, 2026
…al. (#1150)

## Summary
Add support for Qwen3.5 multimodal patches as follow-up for #1123.
Implement extra `lce_forward` for `Qwen3_5ForConditionalGeneration` and
add corresponding tests.

### Details

Modified files:
- `model/qwen3_5.py`: add `lce_forward_for_multimodal` for
`Qwen3_5ForConditionalGeneration`.
- `model/output_classes.py`: add `LigerQwen3_5CausalLMOutputWithPast` as
the return type for qwen3_5.
- `test/utils.py`: `revert_liger_kernel_to_qwen3_5` to support
`conditional_generation` for qwen3_5.
- `test_monkey_patch.py`: add
`test_apply_liger_kernel_to_instance_for_qwen3_5_for_conditional_generation`.
- `bf16/test_mini_models_multimodal.py`: add
`MINI_MODEL_SETUPS["mini_qwen3_5"]` to convergence test.
- `fp32/test_mini_models_multimodal.py`: skipped

## Testing Done

- Hardware Type: NVIDIA A100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence