fix unit test #1173
Changes from all commits
```diff
@@ -4,6 +4,18 @@
 from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token
 
 
+def is_fp8_native_supported():
+    """Check whether the hardware natively supports FP8, such as H100/B200 (SM90+)."""
+    if not torch.cuda.is_available():
+        return False
+    major, _ = torch.cuda.get_device_capability()
+    return major >= 9
+
+
+if not is_fp8_native_supported():
+    pytest.skip("not support fp8 in this gpu card", allow_module_level=True)
```
Contributor

For better readability and consistency, it's good practice to use the

Suggested change
```diff
 @pytest.mark.parametrize("M", [1, 2, 4, 8, 16, 32, 64, 128])
 @pytest.mark.parametrize("N,K", [(2048, 2048), (4096, 5120), (8192, 4096)])
 @pytest.mark.parametrize("output_dtype", [torch.bfloat16])
```
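As background on the pattern in this diff: calling `pytest.skip(..., allow_module_level=True)` at import time aborts collection of the module, so every test in the file is reported as skipped on unsupported hardware. A common per-test alternative is `pytest.mark.skipif`; the sketch below shows that idiom for reference only — the reviewer's suggestion above is truncated in this capture, and the decorator name and test body here are illustrative assumptions, not from the PR.

```python
import pytest
import torch


def is_fp8_native_supported():
    """True on SM90+ GPUs (e.g. H100/B200) that natively support FP8."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 9


# Reusable per-test gate; unlike the module-level skip in the diff, this
# lets unrelated tests in the same file still run on older GPUs.
requires_native_fp8 = pytest.mark.skipif(
    not is_fp8_native_supported(),
    reason="not support fp8 in this gpu card",
)


@requires_native_fp8
def test_fp8_path_available():
    # Illustrative body only; the real tests exercise fp8_scaled_mm_per_token.
    assert is_fp8_native_supported()
```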
For consistency across the test suite, it's better to use a consistent name for the kernel in the skip reason. In test_ppl_int8kv_flash_decoding_diverse.py, the reason is 'need install lightllmKernel'. Please consider using the same capitalization here.
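One lightweight way to get the consistency this comment asks for is to centralize the skip-reason strings in a shared helper module, so every test file uses identical wording and capitalization. This is only a sketch; the module path and constant names below are hypothetical, not from the repository.

```python
# unit_tests/skip_reasons.py (hypothetical path): single source of truth
# for skip-reason strings used across the test suite.
NEED_LIGHTLLM_KERNEL = "need install lightllmKernel"
NO_NATIVE_FP8 = "not support fp8 in this gpu card"

# An individual test module would then write, e.g.:
#
#   import pytest
#   from unit_tests.skip_reasons import NO_NATIVE_FP8
#
#   if not is_fp8_native_supported():
#       pytest.skip(NO_NATIVE_FP8, allow_module_level=True)
```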