60 changes: 53 additions & 7 deletions lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
@@ -7,6 +7,7 @@
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo
from lightllm.utils.log_utils import init_logger
from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor
from lightllm.common.kv_cache_mem_manager import Qwen3NextMemManager
from typing import Tuple
from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -249,16 +250,13 @@ def gdn_forward(
assert isinstance(infer_state.mem_manager, Qwen3NextMemManager)

input = input.view(-1, self.embed_dim_)
conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)

mixed_qkvzba = layer_weight.linear_in_proj.mm(input)
mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill)

if is_prefill:
core_attn_out = self._gdn_prefill_kernel(
mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight
)
core_attn_out, z = self._gdn_prefill_wrapper_run(mixed_qkvzba, infer_state, layer_weight)
else:
mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba)
conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
core_attn_out = self._gdn_decode_kernel(
mixed_qkv,
conv_states,
@@ -277,7 +275,55 @@ def gdn_forward(
output = layer_weight.linear_out_proj.mm(core_attn_out)
return output

def _split_qkvzba(self, mixed_qkvzba, is_decode=False):
def _gdn_prefill_wrapper_run(
self,
mixed_qkvzba: torch.Tensor,
infer_state: Qwen3NextInferStateInfo,
layer_weight: Qwen3NextTransformerLayerWeight,
) -> Tuple[torch.Tensor, torch.Tensor]:
if torch.cuda.is_current_stream_capturing():
mixed_qkvzba = mixed_qkvzba.contiguous()
_mixed_qkvzba = tensor_to_no_ref_tensor(mixed_qkvzba)
pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
pre_capture_graph.__exit__(None, None, None)

# _gdn_prefill_kernel returns the pre-projection value stream. Its
# logical size is num_tokens * local value heads * value head dim.
# We avoid a dry-run because FlashQLA may do host-side syncs while
# preparing varlen chunk metadata, which is illegal during capture.
num_tokens = mixed_qkvzba.shape[0]
o_shape = (num_tokens, self.tp_num_v_heads, self.head_v_dim)
o_dtype = mixed_qkvzba.dtype
o_device = mixed_qkvzba.device
z_shape = o_shape

infer_state.prefill_cuda_graph_create_graph_obj()
infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
_o = tensor_to_no_ref_tensor(o)
z = torch.empty(z_shape, dtype=o_dtype, device=o_device)
_z = tensor_to_no_ref_tensor(z)

def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo):
conv_states, ssm_states = new_infer_state.req_manager.get_mamba_cache(self.layer_num_)
mixed_qkv, tmp_z, b, a = self._split_qkvzba(_mixed_qkvzba)
_z.copy_(tmp_z)
tmp_o = self._gdn_prefill_kernel(
mixed_qkv, conv_states, ssm_states, a, b, new_infer_state, layer_weight
)
tmp_o = tmp_o.view(_o.shape)
_o.copy_(tmp_o)
return
Reviewer comment on lines +298 to +316 (severity: medium):

The allocation and copy of the z tensor in the CUDA graph capture path are redundant. Since z is a slice of mixed_qkvzba (which is an output of the previous graph segment), you can obtain z as a view directly during capture. This avoids an unnecessary allocation and a GPU-to-GPU copy inside the host node during replay. The subsequent z.contiguous() call in gdn_forward will handle contiguity if required by the norm kernel, and that copy will be efficiently captured in the graph.

            infer_state.prefill_cuda_graph_create_graph_obj()
            infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
            o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
            _o = tensor_to_no_ref_tensor(o)
            _, z, _, _ = self._split_qkvzba(mixed_qkvzba, is_decode=False)

            def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo):
                conv_states, ssm_states = new_infer_state.req_manager.get_mamba_cache(self.layer_num_)
                mixed_qkv, _, b, a = self._split_qkvzba(_mixed_qkvzba, is_decode=False)
                tmp_o = self._gdn_prefill_kernel(
                    mixed_qkv, conv_states, ssm_states, a, b, new_infer_state, layer_weight
                )
                _o.copy_(tmp_o.view(_o.shape))
                return


infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=gdn_prefill_func, after_graph=pre_capture_graph)
Reviewer comment (severity: high):

There is a likely typo in the method name prefill_cuda_graph_add_cpu_runnning_func. It contains three 'n's in 'runnning'. Please verify if this matches the intended method name in the base class or infer_state object.
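As a quick check (a hedged sketch, assuming these hooks are defined as methods on `Qwen3NextInferStateInfo` or one of its bases; if they are attached to instances at runtime this will miss them), one can probe both spellings directly:

```python
# Sketch: report which spelling of the CPU-running-func hook is actually defined.
# Assumes lightllm is importable and the hook lives on the infer-state class.
from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo

for name in (
    "prefill_cuda_graph_add_cpu_runnning_func",  # spelling used in this diff (three 'n's)
    "prefill_cuda_graph_add_cpu_running_func",   # conventional spelling
):
    print(name, "->", "defined" if hasattr(Qwen3NextInferStateInfo, name) else "missing")
```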

return o, z

conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba)
core_attn_out = self._gdn_prefill_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight)
return core_attn_out, z

def _split_qkvzba(self, mixed_qkvzba):
qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim
z_end = qkv_dim + self.tp_value_dim
b_end = z_end + self.tp_num_v_heads
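The wrapper above splits prefill capture into "graph segment, CPU-running function, graph segment" so that work whose varlen chunk metadata preparation syncs with the host (as noted in the code comments) runs outside the captured graphs. A minimal standalone sketch of that pattern using plain PyTorch CUDA graphs (hypothetical demo code, not the lightllm `prefill_cuda_graph_*` hooks):

```python
# Sketch of the capture-split pattern: host-synchronizing work is kept out of
# the graphs and replayed as an ordinary Python callback between two captured
# segments. Requires a CUDA device.
import torch

assert torch.cuda.is_available()
x = torch.randn(4, 8, device="cuda")
y = torch.empty_like(x)
z = torch.empty_like(x)

# Warm up the ops once so capture does not hit lazy one-time initialization.
y.copy_(x * 2); z.copy_(y * 3); torch.cuda.synchronize()

g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()

with torch.cuda.graph(g1):      # first captured segment
    y.copy_(x * 2)

def host_step():                # runs eagerly between the segments, so
    torch.cuda.synchronize()    # host-side syncs here are legal
    y.add_(1.0)

with torch.cuda.graph(g2):      # second captured segment
    z.copy_(y * 3)

# Replay order mirrors "graph, CPU-running func, graph" scheduling.
g1.replay()
host_step()
g2.replay()
torch.cuda.synchronize()
print(z[0, 0].item())
```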
9 changes: 9 additions & 0 deletions test/acc/test_qwen3.5.sh
@@ -9,6 +9,15 @@ LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \
# second
export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/Qwen3.5-0.8B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code

# prefill cuda graph functional test
LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \
--model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17 \
--tp 2 \
--port 8089 \
--enable_prefill_cudagraph

export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/Qwen3.5-0.8B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code


# test
LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \