@@ -1,4 +1,7 @@
import pytest

pytest.skip(reason="need install lightllmKernel", allow_module_level=True)

import torch
from lightllm.utils.light_utils import light_ops
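Note on the guard added above: pytest.skip(..., allow_module_level=True) aborts collection of the whole module, so every test in the file is skipped up front when lightllmKernel is not installed. A minimal sketch of the same pattern that probes for the package first; the package name "lightllm_kernel" is an assumption used only for illustration, not taken from this PR:

import importlib.util

import pytest

# Skip the whole test module up front when the optional kernel package is missing.
# "lightllm_kernel" is a placeholder module name for this sketch.
if importlib.util.find_spec("lightllm_kernel") is None:
    pytest.skip(reason="need install lightllmKernel", allow_module_level=True)

import torch  # noqa: E402  # imports below the guard only run when the kernel is available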

@@ -21,15 +24,15 @@ class MockInferState:
def __init__(
self,
batch_size,
max_len_in_batch,
max_kv_seq_len,
req_to_tokens,
b_req_idx,
b_seq_len,
b_shared_seq_len=None,
b_mark_shared_group=None,
):
self.batch_size = batch_size
self.max_len_in_batch = max_len_in_batch
self.max_kv_seq_len = max_kv_seq_len
self.req_manager = MockReqManager(req_to_tokens)
self.b_req_idx = b_req_idx
self.b_seq_len = b_seq_len
@@ -44,10 +47,11 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
测试 ppl_int8kv_flash_decoding_diverse 的 token_decode_attention_flash_decoding
与 ppl_int8kv_flash_decoding (baseline) 的对比。
"""
from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse import (

from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import (
token_decode_attention_flash_decoding as diverse_attention,
)
from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding import (
from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import (
token_decode_attention_flash_decoding as baseline_attention,
)

@@ -87,7 +91,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
# 创建 baseline 的 infer_state (不需要 b_shared_seq_len)
baseline_infer_state = MockInferState(
batch_size=batch_size,
max_len_in_batch=seq_len,
max_kv_seq_len=seq_len,
req_to_tokens=req_to_tokens,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
@@ -96,7 +100,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
# 创建 diverse 的 infer_state
diverse_infer_state = MockInferState(
batch_size=batch_size,
max_len_in_batch=seq_len,
max_kv_seq_len=seq_len,
req_to_tokens=req_to_tokens,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
@@ -108,8 +112,6 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
baseline_out = baseline_attention(
q=q.clone(),
infer_state=baseline_infer_state,
q_head_num=num_heads,
head_dim=head_dim,
cache_k=cache_k,
cache_k_scale=cache_k_scale,
cache_v=cache_v,
@@ -120,8 +122,6 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
diverse_out = diverse_attention(
q=q.clone(),
infer_state=diverse_infer_state,
q_head_num=num_heads,
head_dim=head_dim,
cache_k=cache_k,
cache_k_scale=cache_k_scale,
cache_v=cache_v,
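Note on the two call sites above: the explicit q_head_num and head_dim arguments are gone from both the baseline and diverse calls, so the relocated kernels presumably read these values from the query tensor or from infer_state themselves. Under the assumption that q is shaped (batch, q_head_num, head_dim), which the visible diff does not confirm, the derivation would be roughly:

import torch

# Illustration only: how the head count and head dim could be inferred from q's shape.
q = torch.randn(4, 32, 128)  # placeholder (batch_size, q_head_num, head_dim)
batch_size, q_head_num, head_dim = q.shape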
@@ -1,6 +1,8 @@
import pytest
import torch
from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1
from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse_stage1 import (
flash_decode_stage1,
)


@pytest.fixture
@@ -81,7 +83,7 @@ def test_flash_decode_stage1_execution(setup_tensors):
new_k = k.to(q.dtype)
new_v = v.to(q.dtype)

from lightllm.models.llama.triton_kernel.gqa_flash_decoding_stage1 import (
from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import (
flash_decode_stage1 as gqa_flash_decode_stage1,
)

@@ -1,4 +1,7 @@
import pytest

pytest.skip(reason="need install lightllmkernel", allow_module_level=True)
Contributor review comment (severity: medium):

For consistency across the test suite, it's better to use a consistent name for the kernel in the skip reason. In test_ppl_int8kv_flash_decoding_diverse.py, the reason is 'need install lightllmKernel'. Please consider using the same capitalization here.

Suggested change:
- pytest.skip(reason="need install lightllmkernel", allow_module_level=True)
+ pytest.skip(reason="need install lightllmKernel", allow_module_level=True)

import torch
from lightllm.utils.light_utils import light_ops

@@ -94,7 +97,7 @@ def test_flash_decode_stage2_execution(shared_seq_len):
b_seq_len = setup_tensors["b_seq_len"] - setup_tensors["b_shared_seq_len"]
req_to_tokens = setup_tensors["Req_to_tokens"][:, setup_tensors["b_shared_seq_len"][0].item() :]

from lightllm.models.llama.triton_kernel.gqa_flash_decoding_stage1 import (
from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import (
flash_decode_stage1 as gqa_flash_decode_stage1,
)

@@ -1,6 +1,8 @@
import pytest
import torch
from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3
from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse_stage3 import (
flash_diverse_decode_stage3,
)


@pytest.mark.parametrize(
@@ -23,7 +25,10 @@ def test_flash_diverse_decode_stage3(batch, head_num, seq_len, shared_seq_len, b
flash_diverse_decode_stage3(mid_out, mid_out_logexpsum, B_Seqlen, b_shared_seq_len, out, block_seq)

true_out = torch.zeros_like(out)
from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2

from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.flash_decoding.flash_decoding_stage2 import (
flash_decode_stage2,
)

flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, true_out, block_seq)

@@ -5,12 +5,11 @@
import torch.nn.functional as F
import flashinfer
from lightllm.utils.log_utils import init_logger
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
from lightllm.common.basemodel.triton_kernel.att.prefill_att.context_flashattention_nopad import (
context_attention_fwd,
context_attention_fwd_no_prompt_cache,
)
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.common.req_manager import ReqManager

logger = init_logger(__name__)

@@ -54,25 +53,25 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim):

infer_state = LlamaInferStateInfo()
infer_state.batch_size = Z
infer_state.max_len_in_batch = N_CTX
infer_state.max_q_seq_len = N_CTX
infer_state.total_token_num = Z * N_CTX
infer_state.req_manager = ReqManager(Z, N_CTX, None)
infer_state.req_manager = type("Object", (), {})()
infer_state.req_manager.req_to_token_indexs = req_to_token_indexs
infer_state.b_req_idx = b_req_idx
infer_state.b_seq_len = b_seq_len
infer_state.b_ready_cache_len = b_ready_cache_len
infer_state.b_start_loc = q_start_loc
infer_state.b_q_start_loc = q_start_loc

context_attention_fwd(
q,
kv[:, :KV_HEADS, :],
kv[:, KV_HEADS:, :],
o,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_q_start_loc,
infer_state.b_seq_len,
infer_state.b_ready_cache_len,
infer_state.max_len_in_batch,
infer_state.max_q_seq_len,
infer_state.req_manager.req_to_token_indexs,
)
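Note on the req_manager stub above: type("Object", (), {})() creates a one-off anonymous class and instantiates it purely as an attribute holder, which removes the test's dependency on constructing a real ReqManager. A sketch of the same idea using the standard library, shown only as an illustration with placeholder shapes:

import torch
from types import SimpleNamespace

# A plain attribute bag works just as well as the anonymous-class trick for a test stub.
req_to_token_indexs = torch.zeros((4, 128), dtype=torch.int32)  # placeholder shape
req_manager_stub = SimpleNamespace(req_to_token_indexs=req_to_token_indexs)
assert req_manager_stub.req_to_token_indexs.shape == (4, 128)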

@@ -127,7 +126,11 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim):
"batch, seqlen, q_heads, kv_heads, head_dim",
[
(a, b, c, d, e)
for a in [1, 16, 32, 128, 512]
for a in [
1,
16,
32,
]
for b in [16, 32, 512, 1024]
for c in [28]
for d in [4]
@@ -149,18 +152,18 @@ def test_context_attention_fwd_no_prompt_cache(batch, seqlen, q_heads, kv_heads,

infer_state = LlamaInferStateInfo()
infer_state.batch_size = Z
infer_state.max_len_in_batch = N_CTX
infer_state.max_q_seq_len = N_CTX
infer_state.b_seq_len = b_seq_len
infer_state.b_start_loc = b_start_loc
infer_state.b_q_start_loc = b_start_loc

context_attention_fwd_no_prompt_cache(
q,
k,
v,
o,
infer_state.b_start_loc,
infer_state.b_q_start_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
infer_state.max_q_seq_len,
)

head_dim = HEAD_DIM
@@ -1,6 +1,6 @@
import torch
import pytest
from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv
from lightllm.common.basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv
from lightllm.utils.log_utils import init_logger
import torch.nn.functional as F

@@ -5,9 +5,10 @@
import torch.nn.functional as F
import flashinfer
from lightllm.utils.log_utils import init_logger
from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
from lightllm.common.basemodel.triton_kernel.mla_att.decode_att.gqa_flash_decoding import (
gqa_token_decode_attention_flash_decoding,
)
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.common.req_manager import ReqManager

logger = init_logger(__name__)

@@ -53,7 +54,7 @@ def test_gqa_flash_decoding(batch, seqlen, heads, nope_head, rope_head):
infer_state.batch_size = Z
infer_state.max_len_in_batch = N_CTX
infer_state.total_token_num = Z * N_CTX
infer_state.req_manager = ReqManager(Z, N_CTX, None)
infer_state.req_manager = type("Object", (), {})()
infer_state.req_manager.req_to_token_indexs = req_to_token_indexs
infer_state.b_req_idx = b_req_idx
infer_state.b_seq_len = b_seq_len
@@ -67,10 +68,6 @@ def test_gqa_flash_decoding(batch, seqlen, heads, nope_head, rope_head):
kv_nope,
kv_rope,
infer_state,
H,
D_HEAD,
ROPE_HEAD,
D_HEAD,
sm_scale,
o,
)
@@ -18,10 +18,10 @@ def test_add_in_place():
assert input.item() == 3, "最终值应为 3"


@pytest.mark.timeout(2)
def test_wait_timeout():
input = torch.zeros((1,), device="cuda", dtype=torch.int32)
wait_value(input, 4)
# @pytest.mark.timeout(2)
# def test_wait_timeout():
# input = torch.zeros((1,), device="cuda", dtype=torch.int32)
# wait_value(input, 4)


if __name__ == "__main__":
@@ -25,6 +25,7 @@ def test_token_id_counter():
for _ in range(100):
token_id_counter(prompt_ids=test_prompt_ids, out_token_id_counter=test_token_id_counter)
end_event.record()
end_event.synchronize()
logger.info(f"test_token_id_count cost time: {start_event.elapsed_time(end_event)} ms")
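Note on the added end_event.synchronize(): CUDA kernel launches are asynchronous, and torch.cuda.Event.elapsed_time is only valid once the recorded event has completed, so reading the timer without synchronizing races the GPU. The canonical event-timing pattern (requires a CUDA device) looks like this:

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
# ... launch the kernels being timed ...
end_event.record()
end_event.synchronize()  # wait for the recorded work to finish before reading the timer
elapsed_ms = start_event.elapsed_time(end_event)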


@@ -1,7 +1,7 @@
import torch
import pytest
from lightllm.utils.log_utils import init_logger
from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
from lightllm.common.basemodel.triton_kernel.repack_kv_index import repack_kv_index

logger = init_logger(__name__)

9 changes: 6 additions & 3 deletions unit_tests/common/fused_moe/test_deepep.py
@@ -1,12 +1,13 @@
import pytest

pytest.skip(reason="need special env, install deep_ep and deep_gemm", allow_module_level=True)

import os
import torch
import torch.distributed as dist
import pytest
import deep_ep
import random
import numpy as np
from deep_ep import Buffer, EventOverlap
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
from lightllm.common.fused_moe.grouped_fused_moe_ep import fused_experts_impl
from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
from typing import Tuple
@@ -25,6 +26,8 @@
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
from deep_gemm import ceil_div

x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
@@ -1,6 +1,18 @@
import torch
import time
import pytest


def is_fp8_native_supported():
"""检查是否为 H100/B200 等原生支持 FP8 的硬件 (SM90+)"""
if not torch.cuda.is_available():
return False
major, _ = torch.cuda.get_device_capability()
return major >= 9


if not is_fp8_native_supported():
pytest.skip(reason="not support fp8 test in this gpu card", allow_module_level=True)

import random
from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd
from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
5 changes: 4 additions & 1 deletion unit_tests/common/fused_moe/test_softmax_topk.py
@@ -9,7 +9,10 @@


def benchmark(M, N, K, renorm, runs):
import sgl_kernel as sgl_ops
try:
import sgl_kernel as sgl_ops
except Exception as e:
pytest.skip(f"no sgl_kernel error: {str(e)}", allow_module_level=True)

gating = torch.randn(M, N, device="cuda", dtype=torch.float32)
torch.cuda.synchronize()
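Note on the try/except import guard above: pytest also ships a built-in helper for this pattern, pytest.importorskip, which returns the imported module or skips the calling test when the import fails. A minimal sketch, not part of this PR:

import pytest

# Skips the current test (or the whole module, if called at module scope) when sgl_kernel is absent.
sgl_ops = pytest.importorskip("sgl_kernel")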
12 changes: 12 additions & 0 deletions unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py
@@ -4,6 +4,18 @@
from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token


def is_fp8_native_supported():
"""检查是否为 H100/B200 等原生支持 FP8 的硬件 (SM90+)"""
if not torch.cuda.is_available():
return False
major, _ = torch.cuda.get_device_capability()
return major >= 9


if not is_fp8_native_supported():
pytest.skip("not support fp8 in this gpu card", allow_module_level=True)
Contributor review comment (severity: medium):

For better readability and consistency, it's good practice to use the reason keyword argument for pytest.skip. Also, to align with other tests checking for FP8 support (e.g., in test_moe_silu_and_mul_mix_quant_ep.py), consider using a more descriptive and consistent reason like 'not support fp8 test in this gpu card'.

Suggested change:
- pytest.skip("not support fp8 in this gpu card", allow_module_level=True)
+ pytest.skip(reason="not support fp8 test in this gpu card", allow_module_level=True)



@pytest.mark.parametrize("M", [1, 2, 4, 8, 16, 32, 64, 128])
@pytest.mark.parametrize("N,K", [(2048, 2048), (4096, 5120), (8192, 4096)])
@pytest.mark.parametrize("output_dtype", [torch.bfloat16])