2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -372,7 +372,7 @@ jobs:
   pypi_index: "https://download.pytorch.org/whl/cu128"
 - cuda_version: "12.9.1"
   torch_version: "2.8.0"
-  pypi_index: "https://download.pytorch.org/whl/test/cu129"
+  pypi_index: "https://download.pytorch.org/whl/cu129"


 # Linux L40S runners
2 changes: 2 additions & 0 deletions tests/conftest.py
@@ -34,6 +34,8 @@ def pytest_runtest_teardown(item, nextitem):
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        torch.mps.empty_cache()


 @pytest.fixture(scope="session")
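For context, the full teardown hook after this change reads roughly as the sketch below. It is reassembled from the visible lines plus the imports it needs, not a verbatim copy of tests/conftest.py (the rest of the file is omitted).

import gc

import torch


def pytest_runtest_teardown(item, nextitem):
    # Drop Python-level garbage first so tensor refcounts actually reach zero.
    gc.collect()
    # Then return cached allocator memory on whichever accelerator backend is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        torch.mps.empty_cache()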
28 changes: 15 additions & 13 deletions tests/test_functional.py
@@ -1,9 +1,10 @@
 import math
+import platform
 import random
 import time

 import einops
 import numpy as np
 from packaging import version
 import pytest
 import torch

@@ -101,16 +102,16 @@ class Test8BitBlockwiseQuantizeFunctional:
     def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
         iters = 100

-        if device == "cpu":
+        if device != "cuda":
             iters = 10

-        # This test is slow on CPU, so avoid atypical use cases.
+        # This test is slow in our non-CUDA implementations, so avoid atypical use cases.
         if nested:
             pytest.skip("Not a typical use case.")
         if blocksize != 256:
-            pytest.skip("Only blocksize 256 is used in CPU/XPU")
+            pytest.skip("Only blocksize 256 is used in CPU/MPS/XPU")
         if dtype != torch.float32:
-            pytest.skip("Only float32 is used in CPU/XPU")
+            pytest.skip("Only float32 is used in CPU/MPS/XPU")

         diffs = []
         reldiffs = []
@@ -239,7 +240,7 @@ def test_fp8_quant(self, device):

         abserr = []
         relerr = []
-        for i in range(100):
+        for i in range(10):
             A1 = torch.randn(1024, 1024, device=device)
             C, SC = F.quantize_blockwise(A1, code=code)
             A2 = F.dequantize_blockwise(C, SC)
@@ -253,7 +254,7 @@

         abserr = []
         relerr = []
-        for i in range(100):
+        for i in range(10):
             A1 = torch.rand(1024, 1024, device=device)
             C, SC = F.quantize_blockwise(A1, code=code)
             A2 = F.dequantize_blockwise(C, SC)
@@ -267,7 +268,7 @@

         abserr = []
         relerr = []
-        for i in range(100):
+        for i in range(10):
             A1 = torch.randn(1024, 1024, device=device)
             C, SC = F.quantize_blockwise(A1)
             A2 = F.dequantize_blockwise(C, SC)
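The three loops above share one round-trip pattern: blockwise-quantize a random 1024x1024 matrix, dequantize it, and accumulate absolute and relative error. A minimal standalone sketch of that pattern follows; the alias F for bitsandbytes.functional, the helper name, and the exact error formula are assumptions for illustration, not code from this PR.

import torch

import bitsandbytes.functional as F


def blockwise_roundtrip_error(device: str = "cpu", iters: int = 10) -> tuple[float, float]:
    """Return mean absolute and relative error of a blockwise quantize/dequantize round trip."""
    abserr, relerr = [], []
    for _ in range(iters):
        A1 = torch.randn(1024, 1024, device=device)
        C, SC = F.quantize_blockwise(A1)    # uint8 codes plus per-block scaling state
        A2 = F.dequantize_blockwise(C, SC)  # reconstruction in the original dtype
        diff = (A1 - A2).abs()
        abserr.append(diff.mean().item())
        relerr.append((diff / (A1.abs() + 1e-8)).mean().item())
    return sum(abserr) / iters, sum(relerr) / iters

Cutting the loop count from 100 to 10 keeps the same statistical check while shortening test time.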
@@ -1406,28 +1407,29 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-    @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
     @pytest.mark.skipif(
         HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
         reason="this test is not supported on ROCm with gfx90a architecture yet",
     )
-    def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
+    def test_gemv_eye_4bit(self, device, storage_type, dtype):
         if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
             pytest.skip("eye doe not support bfloat16 on CPU in torch < 2.3")

         if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
             pytest.skip("This configuration is not supported on HPU.")

-        dims = 10
-        torch.random.manual_seed(np.random.randint(0, 412424242))
+        if device == "cpu" and platform.system() == "Windows" and version.parse(torch.__version__).release == (2, 8, 0):
+            pytest.skip("Regression: CPU crash on Windows with torch 2.8.0")
+
+        dims = 4
         dims = get_test_dims(0, 8192, n=dims)
         dims = [dim + (64 - (dim % 64)) for dim in dims]
         # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
         for dim in dims:
             A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device=device)
             B = torch.eye(dim, dtype=dtype, device=device)

-            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=False)
             C3 = torch.matmul(A, B.t())
             C2 = bnb.matmul_4bit(A, qB.t(), state)
             A.requires_grad = True
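The new skip relies on version.parse(torch.__version__).release, which compares only the numeric release tuple and ignores local or pre-release suffixes such as "+cpu". A short illustration of that check; the helper name and sample version strings are made up for this note, not part of the PR.

import platform

from packaging import version


def is_torch_280_on_windows_cpu(device: str, torch_version: str) -> bool:
    """Mirror the test's guard: an exact 2.8.0 release on a Windows CPU run."""
    parsed = version.parse(torch_version)
    # .release drops suffixes, so "2.8.0+cpu" and "2.8.0" compare equal here.
    return device == "cpu" and platform.system() == "Windows" and parsed.release == (2, 8, 0)


assert version.parse("2.8.0+cpu").release == (2, 8, 0)
assert version.parse("2.8.1").release != (2, 8, 0)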
11 changes: 11 additions & 0 deletions tests/test_optim.py
@@ -172,6 +172,10 @@ def rm_path(path):
 @pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
 @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
 def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
+
+    if device not in ["cuda", "xpu"]:
+        pytest.skip("Optimizers are only supported on CUDA and XPU")
+
     if optim_name.startswith("paged_") and sys.platform == "win32":
         pytest.skip("Paged optimizers can have issues on Windows.")

@@ -253,6 +257,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
 @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
 @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
 def test_global_config(dim1, dim2, gtype, device):
+    if device not in ["cuda", "xpu"]:
+        pytest.skip("Optimizers are only supported on CUDA and XPU")
+
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
@@ -310,6 +317,10 @@ def test_global_config(dim1, dim2, gtype, device):
 @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
 @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
 def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
+
+    if device not in ["cuda", "xpu"]:
+        pytest.skip("8-bit optimizers are only supported on CUDA and XPU")
+
     torch.set_printoptions(precision=6)

     if dim1 == 1 and dim2 == 1:
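The same two-line guard now appears in all three optimizer tests. A hypothetical helper, sketched below, shows how the check could be shared; nothing like it is added by this PR.

import pytest

SUPPORTED_OPTIMIZER_DEVICES = ("cuda", "xpu")


def skip_unless_optimizer_device(device: str) -> None:
    """Skip the calling test unless bitsandbytes optimizers support this device."""
    if device not in SUPPORTED_OPTIMIZER_DEVICES:
        pytest.skip("Optimizers are only supported on CUDA and XPU")

Each test body would then call skip_unless_optimizer_device(device) in place of the repeated check.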