54 changes: 44 additions & 10 deletions .github/workflows/tests.yml
@@ -49,7 +49,7 @@ jobs:
build-cuda:
strategy:
matrix:
cuda_version: ["11.8.0", "12.8.1"]
cuda_version: ["11.8.0", "12.6.3", "12.8.1"]
os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025]
include:
- os: ubuntu-22.04
@@ -100,7 +100,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025, macos-15]
torch_version: ["2.7.0"]
torch_version: ["2.6.0", "2.7.0"]
include:
- os: ubuntu-22.04
arch: x86_64
@@ -138,9 +138,35 @@ jobs:
- name: Show installed packages
run: pip list

- name: Show environment information
run: python -m torch.utils.collect_env

- name: Run tests
run: pytest --durations=100

# cuda-aarch64-tests:
# if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
# needs: build-cuda
# strategy:
# fail-fast: false
# matrix:
# os: [ubuntu-22.04-arm]
# arch: [aarch64]
# torch_version: ["2.7.0"]
# cuda_version: ["11.8.0", "12.8.1"]

# runs-on: bandb-aws-g5g-4xlarge-plus-use1-public-80
# env:
# BNB_TEST_DEVICE: cuda
# steps:
# - name: Show GPU Information
# run: nvidia-smi

# - name: Show pip packages
# run: pip list



cuda-tests:
if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
needs: build-cuda
@@ -149,25 +175,28 @@ jobs:
matrix:
os: [ubuntu-22.04, windows-2025]
arch: [x86_64]
gpu: [T4, L4]
cuda_version: ["11.8.0", "12.8.1"]
gpu: [T4, L40S]
cuda_version: ["11.8.0", "12.6.3", "12.8.1"]
include:
- cuda_version: "11.8.0"
torch_version: "2.4.1"
pypi_index: "https://download.pytorch.org/whl/cu118"
- cuda_version: "12.6.3"
torch_version: "2.6.0"
pypi_index: "https://download.pytorch.org/whl/cu126"
- cuda_version: "12.8.1"
torch_version: "2.7.0"
pypi_index: "https://download.pytorch.org/whl/cu128"

# L4 runners
# L40S runners
- os: ubuntu-22.04
gpu: L4
runner: bandb-aws-g6-4xlarge-plus-use1-public-80
gpu: L40S
runner: bandb-aws-g6e-4xlarge-plus-use1-public-80

# T4 runners
- os: ubuntu-22.04
gpu: T4
runner: CUDA-Linux-x64
runner: bandb-aws-g4dn-4xlarge-plus-use1-public-80
- os: windows-2025
gpu: T4
runner: CUDA-Windows-x64
@@ -176,10 +205,12 @@ jobs:
# and cannot support CUDA 12+. Skip for now.
- os: windows-2025
cuda_version: "12.8.1"
- os: windows-2025
cuda_version: "12.6.3"

# No Windows L4 runners.
# No Windows L40S runners.
- os: windows-2025
gpu: L4
gpu: L40S
runs-on: ${{ matrix.runner }}
env:
BNB_TEST_DEVICE: cuda
@@ -210,5 +241,8 @@ jobs:
- name: Show installed packages
run: pip list

- name: Show environment information
run: python -m torch.utils.collect_env

- name: Run tests
run: pytest --durations=100
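
The new "12.6.3" matrix entries pair each CUDA build with a matching PyTorch version and wheel index. The install step itself is outside the hunks shown above, so the following is only a minimal sketch of how these matrix values would typically be consumed in a workflow step; the step name and exact pip invocation are assumptions, not taken from this diff.

      # Hypothetical install step (assumed, not part of this diff): install the torch
      # build selected by the matrix include entry for this job.
      - name: Install PyTorch
        run: pip install torch==${{ matrix.torch_version }} --index-url ${{ matrix.pypi_index }}

For the cuda_version: "12.6.3" include entry above, this would resolve to pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cu126.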
33 changes: 0 additions & 33 deletions tests/test_functional.py
@@ -929,39 +929,6 @@ def test_spmm_coo_very_sparse(self, dim1, dim2, dtype, out_func):
# torch.cuda.synchronize()
# print(time.time() - t0)

@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2"))
@pytest.mark.skip("No longer supported")
def test_integrated_sparse_decomp(self, dim1, dim2):
threshold = 3.0
for _ in range(k):
A = torch.randn(dim1, dim2).cuda().half()
w1 = torch.randn(dim1, dim2).cuda().half()
out1 = torch.matmul(A, w1.t())

Cw1, statsw1, _ = F.int8_vectorwise_quant(w1)
CA, statsA, _ = F.int8_vectorwise_quant(A)

out1_32 = F.int8_linear_matmul(CA, Cw1)
out2 = F.int8_mm_dequant(out1_32, statsA, statsw1)

# CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)

out1_32 = F.int8_linear_matmul(CA, Cw1)
out3 = F.int8_mm_dequant(out1_32, statsA, statsw1)

assert coo_tensor is not None

out4 = F.spmm_coo(coo_tensor, w1.t())
# idx = torch.unique(coo_tensor._indices()[1]).long()
# out4 = torch.matmul(A, w1.t())
out5 = out3 + out4

err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out5).mean().item()
assert err2 < err1

@pytest.mark.parametrize("dim1", [1 * 2048])
@pytest.mark.parametrize("dim2", [2048])
@pytest.mark.parametrize("dtype", [torch.int8])
12 changes: 6 additions & 6 deletions tests/test_modules.py
@@ -130,7 +130,7 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
assert l1.weight.dtype == torch.int8

l1.eval()
for i in range(100):
for i in range(4):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = l1(b1)
assert o1.dtype == torch.float16
@@ -139,7 +139,7 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8

for i in range(100):
for i in range(4):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
@@ -152,7 +152,7 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8

for i in range(100):
for i in range(4):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
@@ -163,7 +163,7 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):

mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device)

for i in range(100):
for i in range(4):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
@@ -185,7 +185,7 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
.to(device)
)

for i in range(100):
for i in range(4):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
@@ -207,7 +207,7 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
w1, w2 = mlp.fc1.weight.clone().to(device), mlp.fc2.weight.clone().to(device) # grab weights before quantization,
mlp = mlp.to(device).half() # and this line triggers quantization

for i in range(100):
for i in range(4):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16