@@ -1,6 +1,5 @@
# int8kv flash decoding attention implementation tailored for diverse mode, enabling more efficient diverse sampling
import torch
from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
from lightllm.common.basemodel.infer_struct import InferStateInfo
from .int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1
from .int8kv_flash_decoding_diverse_stage2 import flash_decode_stage2

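For reference, the diverse int8kv path follows the usual two-stage flash-decoding scheme suggested by the imports above: stage 1 computes partial attention plus a log-sum-exp per KV chunk, stage 2 merges the chunks. Below is a minimal single-head PyTorch sketch of that two-stage combine; it assumes nothing about the repo's Triton kernels, and all names in it are illustrative.

```python
# Minimal single-head flash-decoding reference (illustrative only, not the
# repo's Triton kernels): stage 1 gives per-chunk partial outputs plus a
# log-sum-exp, stage 2 merges the chunks with a softmax over the lse values.
import torch


def flash_decode_reference(q, k, v, chunk_size=256):
    # q: [head_dim], k/v: [seq_len, head_dim]
    scale = q.shape[-1] ** -0.5
    partial_outs, lses = [], []
    for start in range(0, k.shape[0], chunk_size):
        ks = k[start : start + chunk_size]
        vs = v[start : start + chunk_size]
        logits = (ks @ q) * scale                                # [chunk]
        lses.append(torch.logsumexp(logits, dim=0))
        partial_outs.append(torch.softmax(logits, dim=0) @ vs)   # [head_dim]
    chunk_weights = torch.softmax(torch.stack(lses), dim=0)      # [num_chunks]
    return (torch.stack(partial_outs) * chunk_weights[:, None]).sum(dim=0)


q, k, v = torch.randn(64), torch.randn(1000, 64), torch.randn(1000, 64)
ref = torch.softmax((k @ q) * 64 ** -0.5, dim=0) @ v
assert torch.allclose(flash_decode_reference(q, k, v), ref, atol=1e-4)
```

In the diverse variant, requests that share a prefix presumably reuse the stage-1 results for the shared KV chunks; that part is not sketched here.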
This file was deleted.

Empty file.

This file was deleted.

86 changes: 16 additions & 70 deletions lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
@@ -17,16 +17,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from lightllm.utils.sgl_utils import sgl_ops
from lightllm.utils.light_utils import light_ops
from typing import Callable, List, Optional, Tuple
from lightllm.common.basemodel.triton_kernel.fused_moe.softmax_topk import softmax_topk
from lightllm.common.triton_utils.autotuner import Autotuner

use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]


def fused_topk(
hidden_states: torch.Tensor,
@@ -127,44 +123,6 @@ def biased_grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


# This is used by the Deepseek-V2 model
def cuda_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert light_ops is not None, "lightllm_kernel is not installed."

num_tokens = gating_output.shape[0]
topk_weights = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.float32)
topk_indices = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.int32)
token_expert_indices = torch.empty(num_tokens, topk_group, device=hidden_states.device, dtype=torch.int32)
group_scores = torch.empty(num_tokens, num_expert_group, device=hidden_states.device, dtype=torch.float32)
if correction_bias is None:
correction_bias = torch.zeros_like(gating_output, dtype=torch.float32)
light_ops.grouped_topk(
topk_weights,
correction_bias,
topk_indices,
token_expert_indices,
gating_output.float(),
num_expert_group,
topk_group,
topk,
renormalize,
scoring_func,
group_scores,
)

return topk_weights, topk_indices


def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -184,34 +142,22 @@ def select_experts(
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
if use_cuda_grouped_topk:
topk_weights, topk_ids = cuda_grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
correction_bias=correction_bias,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
)
else:
group_score_topk_num = 1
# for deepseek v3
if topk_group == 4 and num_expert_group == 8 and top_k == 8:
group_score_topk_num = 2

topk_weights, topk_ids = triton_grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
correction_bias=correction_bias,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
group_score_used_topk_num=group_score_topk_num,
)
group_score_topk_num = 1
# for deepseek v3
if topk_group == 4 and num_expert_group == 8 and top_k == 8:
group_score_topk_num = 2

topk_weights, topk_ids = triton_grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
correction_bias=correction_bias,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
group_score_used_topk_num=group_score_topk_num,
)

elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(
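With the CUDA path removed, the grouped-top-k branch always goes through triton_grouped_topk, using group_score_used_topk_num=2 for the DeepSeek-V3-shaped config (8 expert groups, topk_group=4, top_k=8), i.e. each group is presumably scored by the sum of its strongest two experts. A rough pure-PyTorch sketch of grouped top-k routing under that assumption (names are illustrative, not the Triton kernel):

```python
# Rough grouped top-k routing sketch (pure PyTorch, illustrative; the repo uses
# a Triton kernel). Assumes each group is scored by the sum of its strongest
# `group_topk` expert scores, as in the DeepSeek-V3-style configuration.
import torch


def grouped_topk_reference(gating_logits, num_expert_group=8, topk_group=4,
                           top_k=8, group_topk=2, renormalize=True):
    num_tokens, num_experts = gating_logits.shape
    scores = torch.softmax(gating_logits, dim=-1)
    grouped = scores.view(num_tokens, num_expert_group, -1)
    # Score each group by the sum of its strongest `group_topk` experts.
    group_scores = grouped.topk(group_topk, dim=-1).values.sum(-1)
    # Keep only the best `topk_group` groups; zero out experts elsewhere.
    keep = group_scores.topk(topk_group, dim=-1).indices
    mask = torch.zeros_like(group_scores, dtype=torch.bool).scatter(1, keep, True)
    masked = grouped.masked_fill(~mask[..., None], 0.0).reshape(num_tokens, num_experts)
    topk_weights, topk_ids = masked.topk(top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)


w, ids = grouped_topk_reference(torch.randn(4, 256))
print(w.shape, ids.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```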
11 changes: 2 additions & 9 deletions lightllm/common/quantization/w8a8.py
@@ -8,19 +8,12 @@
from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8
from lightllm.common.basemodel.triton_kernel.quantization.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul
from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm
from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops


from .quantize_method import WeightPack

if HAS_LIGHTLLM_KERNEL:

def scaled_fp8_quant(tensor, *args, **kwargs):
return light_ops.per_token_quant_bf16_fp8(tensor)

else:
if HAS_VLLM:
scaled_fp8_quant = vllm_ops.scaled_fp8_quant
if HAS_VLLM:
scaled_fp8_quant = vllm_ops.scaled_fp8_quant

LIGHTLLM_USE_TRITON_FP8_SCALED_MM = os.getenv("LIGHTLLM_USE_TRITON_FP8_SCALED_MM", "False").upper() in [
"ON",
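With the light_ops path gone, scaled_fp8_quant is taken only from vllm_ops when vLLM is installed. Conceptually, per-token FP8 quantization picks one scale per token row so that the row's max magnitude fills the FP8 e4m3 range; here is a rough sketch of that idea (illustrative, not the vLLM op's exact signature or memory layout):

```python
# Rough per-token FP8 quantization sketch (illustrative, not vllm_ops.scaled_fp8_quant).
import torch

FP8_MAX = 448.0  # finfo max for torch.float8_e4m3fn


def per_token_fp8_quant(x: torch.Tensor):
    # x: [num_tokens, hidden]; one scale per token row.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6).float()
    scale = amax / FP8_MAX
    q = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale  # dequantize with q.float() * scale


x = torch.randn(8, 4096, dtype=torch.bfloat16)
q, s = per_token_fp8_quant(x)
print(q.dtype, s.shape)  # torch.float8_e4m3fn torch.Size([8, 1])
```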
13 changes: 0 additions & 13 deletions lightllm/utils/light_utils.py

This file was deleted.

5 changes: 5 additions & 0 deletions test/acc/test_qwen3.sh
@@ -4,6 +4,11 @@ LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --mo
# second
export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code

# test quant
LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --quant_type vllm-fp8w8a8

export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code


LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --enable_tpsp_mix_mode

@@ -1,7 +1,6 @@
import pytest

import torch
from lightllm.utils.light_utils import light_ops


def alloc_tensor_func(shape, dtype, device):
@@ -41,17 +40,17 @@ def __init__(
# @pytest.mark.parametrize("shared_seq_len", [512])
@pytest.mark.parametrize("shared_seq_len", [0, 77, 256, 311, 512, 550])
@pytest.mark.parametrize("batch_size", list(range(6, 121, 6)))
def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len, batch_size):
def test_token_decode_attention_flash_decoding_diverse_matches_normal_decode(shared_seq_len, batch_size):
"""
Tests token_decode_attention_flash_decoding from int8kv_flash_decoding_diverse
against ppl_int8kv_flash_decoding (the baseline).
diverse and normal are both in-repo Triton implementations and should match numerically (no external CUDA extension).
diverse: int8kv_flash_decoding_diverse; reference: int8kv/normal token_decode_attention_flash_decoding
"""

from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse import (
token_decode_attention_flash_decoding as diverse_attention,
)
from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import (
token_decode_attention_flash_decoding as baseline_attention,
from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.normal import (
token_decode_attention_flash_decoding as normal_decode,
)

num_heads = 32
@@ -89,7 +88,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
b_mark_shared_group = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
b_mark_shared_group[mark_shared_group_size - 1 :: mark_shared_group_size] = mark_shared_group_size

# Build the baseline infer_state (no b_shared_seq_len needed)
# Standard int8 decode (single-path Triton)
baseline_infer_state = MockInferState(
batch_size=batch_size,
max_kv_seq_len=max_len_in_batch,
@@ -98,7 +97,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
b_seq_len=b_seq_len,
)

# Build the diverse infer_state
# diverse: multi-stream + shared prefix (Triton)
diverse_infer_state = MockInferState(
batch_size=batch_size,
max_kv_seq_len=max_len_in_batch,
@@ -110,7 +109,7 @@
)

# Run the baseline
baseline_out = baseline_attention(
normal_out = normal_decode(
q=q.clone(),
infer_state=baseline_infer_state,
cache_k=cache_k,
Expand All @@ -131,11 +130,10 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
)

print(f"\nshared_seq_len={shared_seq_len}\nbatch_size={batch_size}")
print(f"baseline_out: {baseline_out[0, 0, :4]}")
print(f"normal_out: {normal_out[0, 0, :4]}")
print(f"diverse_out: {diverse_out[0, 0, :4]}")
print(f"max diff: {(baseline_out - diverse_out).abs().max()}")
print(f"max diff: {(normal_out - diverse_out).abs().max()}")

# Compare against the baseline
assert torch.allclose(
baseline_out, diverse_out, atol=1e-2, rtol=1e-2
), f"Diverse attention output should match baseline for shared_seq_len={shared_seq_len}"
normal_out, diverse_out, atol=1e-2, rtol=1e-2
), f"diverse vs normal decode mismatch for shared_seq_len={shared_seq_len}"