58 changes: 52 additions & 6 deletions lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
@@ -7,6 +7,7 @@
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo
from lightllm.utils.log_utils import init_logger
from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor
from lightllm.common.kv_cache_mem_manager import Qwen3NextMemManager
from typing import Tuple
from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -249,16 +250,13 @@ def gdn_forward(
assert isinstance(infer_state.mem_manager, Qwen3NextMemManager)

input = input.view(-1, self.embed_dim_)
conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)

mixed_qkvzba = layer_weight.linear_in_proj.mm(input)
mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill)

if is_prefill:
core_attn_out = self._gdn_prefill_kernel(
mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight
)
core_attn_out, z = self._gdn_prefill_wrapper_run(mixed_qkvzba, infer_state, layer_weight)
else:
mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=True)
conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
core_attn_out = self._gdn_decode_kernel(
mixed_qkv,
conv_states,
@@ -277,6 +275,54 @@ def gdn_forward(
output = layer_weight.linear_out_proj.mm(core_attn_out)
return output

def _gdn_prefill_wrapper_run(
self,
mixed_qkvzba: torch.Tensor,
infer_state: Qwen3NextInferStateInfo,
layer_weight: Qwen3NextTransformerLayerWeight,
) -> Tuple[torch.Tensor, torch.Tensor]:
if torch.cuda.is_current_stream_capturing():
mixed_qkvzba = mixed_qkvzba.contiguous()
_mixed_qkvzba = tensor_to_no_ref_tensor(mixed_qkvzba)
pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
pre_capture_graph.__exit__(None, None, None)

# _gdn_prefill_kernel returns the pre-projection value stream. Its
# logical size is num_tokens * local value heads * value head dim.
# We avoid a dry-run because FlashQLA may do host-side syncs while
# preparing varlen chunk metadata, which is illegal during capture.
num_tokens = mixed_qkvzba.shape[0]
o_shape = (num_tokens, self.tp_num_v_heads, self.head_v_dim)
o_dtype = mixed_qkvzba.dtype
o_device = mixed_qkvzba.device
z_shape = o_shape

infer_state.prefill_cuda_graph_create_graph_obj()
infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
_o = tensor_to_no_ref_tensor(o)
z = torch.empty(z_shape, dtype=o_dtype, device=o_device)
_z = tensor_to_no_ref_tensor(z)

def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo):
conv_states, ssm_states = new_infer_state.req_manager.get_mamba_cache(self.layer_num_)
mixed_qkv, tmp_z, b, a = self._split_qkvzba(_mixed_qkvzba, is_decode=False)
_z.copy_(tmp_z)
tmp_o = self._gdn_prefill_kernel(
mixed_qkv, conv_states, ssm_states, a, b, new_infer_state, layer_weight
)
tmp_o = tmp_o.view(_o.shape)
_o.copy_(tmp_o)
return

infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=gdn_prefill_func, after_graph=pre_capture_graph)
return o, z

conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=False)
core_attn_out = self._gdn_prefill_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight)
return core_attn_out, z

def _split_qkvzba(self, mixed_qkvzba, is_decode=False):
qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim
z_end = qkv_dim + self.tp_value_dim
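The new _gdn_prefill_wrapper_run above implements a capture-splitting pattern: when the stream is being captured, it pins the input activation with tensor_to_no_ref_tensor, closes the graph currently under capture, pre-allocates the o/z outputs (their shapes are statically known from num_tokens, tp_num_v_heads and head_v_dim), opens a fresh capture graph, and registers gdn_prefill_func to run eagerly on the CPU between the two graph segments at replay time. The snippet below is a minimal, generic sketch of the same idea using only plain torch CUDA-graph APIs; it does not use lightllm's prefill_cuda_graph_* helpers, and host_sync_kernel is a hypothetical stand-in for any kernel (such as the FlashQLA prefill path) that performs host-side synchronization and therefore cannot be captured.

import torch

def build_split_graphs(x: torch.Tensor, host_sync_kernel):
    # x must already live on the GPU; host_sync_kernel is a hypothetical
    # callable that needs host-side syncs and therefore cannot be captured.
    graph_a, graph_b = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()

    with torch.cuda.graph(graph_a):
        pre = x * 2.0                    # capturable work before the gap

    gap_out = torch.empty_like(pre)      # static buffer bridging the two graphs

    with torch.cuda.graph(graph_b):
        post = gap_out + 1.0             # capturable work after the gap

    def run():
        graph_a.replay()
        gap_out.copy_(host_sync_kernel(pre))  # runs outside any captured graph
        graph_b.replay()
        return post

    return run

The detail that matters, mirrored in the PR, is that every tensor crossing the gap (gap_out here, the _o/_z no-ref tensors in the PR) is allocated before the second capture starts, so both graph segments replay against stable addresses.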
49 changes: 49 additions & 0 deletions lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py
@@ -9,6 +9,9 @@
# ruff: noqa: E501
import torch
from einops import rearrange
import functools
import os
from lightllm.utils.log_utils import init_logger

from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
@@ -19,6 +22,36 @@
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd

logger = init_logger(__name__)


@functools.lru_cache(maxsize=1)
def _flashqla_chunk_gated_delta_rule():
if os.environ.get("LIGHTLLM_DISABLE_FLASHQLA", "0").lower() in ["1", "true", "yes"]:
return None
try:
import flash_qla
except ImportError:
return None
if not torch.cuda.is_available():
return None
if torch.cuda.get_device_capability() < (9, 0):
return None
tv = torch.__version__.split("+")[0].split(".")
if (int(tv[0]), int(tv[1])) < (2, 8):
return None
cv = torch.version.cuda
if cv is None:
return None
cv_parts = cv.split(".")
if (int(cv_parts[0]), int(cv_parts[1])) < (12, 8):
return None
Comment on lines +40 to +48

Severity: high

The version requirements for PyTorch (2.8) and CUDA (12.8) appear to be typos, as these versions are either not yet released or do not exist. This will cause the FlashQLA backend to be disabled on all current environments. Additionally, the parsing logic is fragile and may raise IndexError or ValueError depending on the version string format (e.g., if it contains a non-numeric suffix like rc1).

    try:
        tv = torch.__version__.split("+")[0].split(".")
        if len(tv) < 2 or (int(tv[0]), int(tv[1])) < (2, 4):
            return None
        cv = torch.version.cuda
        if cv is None:
            return None
        cv_parts = cv.split(".")
        if len(cv_parts) < 2 or (int(cv_parts[0]), int(cv_parts[1])) < (12, 1):
            return None
    except (ValueError, IndexError):
        return None
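Where parsing robustness is the concern, another option (a sketch that assumes the packaging distribution is importable; it is a common but not guaranteed dependency) is to compare packaging.version.Version objects instead of hand-splitting the strings, which tolerates suffixes such as rc1 or a0. The 2.8 / 12.8 thresholds below are simply carried over from the PR and may need to change per the comment above.

import torch

def _meets_version_requirements() -> bool:
    # Assumes the packaging distribution is available; treat the requirement as
    # unmet (and fall back to the Triton path) if it is not.
    try:
        from packaging.version import InvalidVersion, Version
    except ImportError:
        return False
    try:
        if Version(torch.__version__.split("+")[0]) < Version("2.8"):
            return False
        if torch.version.cuda is None or Version(torch.version.cuda) < Version("12.8"):
            return False
    except InvalidVersion:
        return False
    return True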

logger.info(
"qwen3next chunk_gated_delta_rule: using FlashQLA backend (flash_qla.chunk_gated_delta_rule); "
"set LIGHTLLM_DISABLE_FLASHQLA=1 to fall back to the FLA Triton kernels."
)
return flash_qla.chunk_gated_delta_rule


def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
@@ -183,6 +216,22 @@ def chunk_gated_delta_rule(
cu_seqlens=cu_seqlens
)
"""
flashqla_fn = _flashqla_chunk_gated_delta_rule()
if flashqla_fn is not None and not head_first:
return flashqla_fn(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state.contiguous() if initial_state is not None else None,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
head_first=head_first,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
Comment on lines +221 to +233

Severity: medium

If scale is None, it is passed directly to flashqla_fn. The fallback Triton path explicitly calculates scale as k.shape[-1] ** -0.5. To ensure consistency and avoid potential issues if the flash_qla library does not handle None defaults, the scale should be explicitly provided.

Suggested change
return flashqla_fn(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state.contiguous() if initial_state is not None else None,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
head_first=head_first,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
return flashqla_fn(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale if scale is not None else k.shape[-1] ** -0.5,
initial_state=initial_state.contiguous() if initial_state is not None else None,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
head_first=head_first,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)


assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
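For reference, a hedged usage sketch of chunk_gated_delta_rule after this change: shapes follow the head_first=False convention asserted above (beta is [B, T, H]), bfloat16 is used because float32 is rejected, the import path is inferred from the file location in this PR, and the (output, final_state) return pairing plus the sign convention of g are assumed from the surrounding FLA-style code rather than confirmed by the diff.

import torch
from lightllm.models.qwen3next.triton_kernel.fla.ops.chunk import chunk_gated_delta_rule

B, T, H, K, V = 1, 256, 4, 64, 128
dev = "cuda"
q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=dev)
k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=dev)
v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device=dev)
g = -torch.rand(B, T, H, dtype=torch.float32, device=dev)     # assumed log-space decay gate
beta = torch.rand(B, T, H, dtype=torch.bfloat16, device=dev)  # [B, T, H] per the assert above

o, final_state = chunk_gated_delta_rule(
    q, k, v,
    g=g,
    beta=beta,
    scale=K ** -0.5,          # explicit scale, as the review comment above recommends
    initial_state=None,
    output_final_state=True,
    head_first=False,
)
print(o.shape, final_state.shape)

With LIGHTLLM_DISABLE_FLASHQLA unset, this call is routed to flash_qla.chunk_gated_delta_rule when the probe above succeeds, and falls back to the FLA Triton kernels otherwise.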