30 changes: 28 additions & 2 deletions .github/workflows/tests.yml
@@ -137,6 +137,10 @@ jobs:
         with:
           python-version: 3.9
 
+      - name: Setup MSVC
+        if: startsWith(matrix.os, 'windows')
+        uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl for torch.compile
+
       - name: Install dependencies
         run: |
           pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/cpu
@@ -201,18 +205,40 @@ jobs:
             torch_version: "2.7.0"
             pypi_index: "https://download.pytorch.org/whl/cu128"
 
-          # L40S runners
+          # Linux L40S runners
           - os: ubuntu-22.04
             gpu: L40S
             runner: bandb-aws-g6e-4xlarge-plus-use1-public-80
 
-          # T4 runners
+          # Linux T4 runners
           - os: ubuntu-22.04
            gpu: T4
            runner: bandb-aws-g4dn-4xlarge-plus-use1-public-80
 
+          # Specific Windows runners using cu118
+          - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.2.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
+          - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.6.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
+          - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.7.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
+
         exclude:
           # Our current T4 Windows runner has a driver too old (471.11)
           # and cannot support CUDA 12+. Skip for now.
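Note on the `Setup MSVC` step: on Windows CPU, `torch.compile`'s inductor backend JIT-compiles the C++ it generates and expects MSVC's `cl.exe` on `PATH`, which is what `ilammy/msvc-dev-cmd` sets up. A minimal sketch of the precondition this step satisfies (the check below is illustrative, not part of the workflow):

```python
import shutil

import torch

# Inductor JIT-compiles generated C++ for CPU targets; on Windows it looks
# for MSVC's cl.exe, which the msvc-dev-cmd action puts on PATH.
if shutil.which("cl") is None:
    raise RuntimeError("cl.exe not found; torch.compile would fail on Windows CPU")

compiled = torch.compile(lambda t: t * 2)
print(compiled(torch.ones(4)))
```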
4 changes: 2 additions & 2 deletions bitsandbytes/functional.py
@@ -771,14 +771,14 @@ def quantize_blockwise(
         qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False)
         quant_state = QuantState(
             absmax=qabsmax,
-            code=code,
+            code=code.to(A.device, copy=True),
             blocksize=blocksize,
             dtype=A.dtype,
             offset=offset,
             state2=state2,
         )
     else:
-        quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)
+        quant_state = QuantState(absmax=_absmax, code=code.to(A.device, copy=True), blocksize=blocksize, dtype=A.dtype)
 
     # TODO(matthewdouglas): Deprecate out kwarg
     out = out.copy_(_out) if out is not None else _out
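Why `copy=True`: `Tensor.to` is a no-op returning `self` when the tensor is already on the target device, so without the copy every `QuantState` would alias the shared module-level codebook; forcing a copy gives each state its own buffer, presumably so tracing and any later in-place changes cannot interfere across states. A small sketch of the aliasing behavior:

```python
import torch

code = torch.linspace(-1, 1, 256)  # stand-in for the shared quantization codebook

alias = code.to(code.device)             # no-op: returns the very same tensor
owned = code.to(code.device, copy=True)  # always materializes a new tensor

assert alias.data_ptr() == code.data_ptr()
assert owned.data_ptr() != code.data_ptr()
```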
2 changes: 1 addition & 1 deletion bitsandbytes/nn/modules.py
@@ -493,7 +493,7 @@ def forward(self, x: torch.Tensor):

         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
 
-        return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
 
 
 class LinearFP4(Linear4bit):
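Context for dropping `.data`: `Parameter.data` hands back a detached alias, which hides the transpose from autograd and from `torch.compile`'s tracing; using the parameter directly keeps the op visible to both. A minimal illustration:

```python
import torch

w = torch.nn.Parameter(torch.randn(4, 4))

tracked = w.t()        # stays in the autograd graph; visible to torch.compile
detached = w.data.t()  # .data yields a detached alias; gradients won't flow

print(tracked.requires_grad, detached.requires_grad)  # True False
```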
92 changes: 91 additions & 1 deletion tests/test_linear4bit.py
@@ -1,13 +1,21 @@
 import copy
 import os
 import pickle
+import platform
 from tempfile import TemporaryDirectory
 
 import pytest
 import torch
 
 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
+from tests.helpers import (
+    TRUE_FALSE,
+    describe_dtype,
+    get_available_devices,
+    id_formatter,
+    torch_load_from_buffer,
+    torch_save_to_buffer,
+)
 
 storage = {
     "uint8": torch.uint8,
@@ -275,3 +283,85 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
     # there was a bug where deepcopy would modify the original object
     assert dict_keys_before == dict_keys_after
     assert dict_keys_before == dict_keys_deserialized
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("compute_dtype", [torch.bfloat16, torch.float32], ids=describe_dtype)
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
+@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
+@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
+@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
+    if device == "cpu" and quant_type == "fp4":
+        pytest.skip("FP4 is not supported for CPU")
+
+    if fullgraph and torch.__version__ < (2, 8):
+        pytest.skip("fullgraph mode requires torch 2.8 or higher")
+
+    if device == "cuda" and platform.system() == "Windows":
+        pytest.skip("Triton is not officially supported on Windows")
+
+    # Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False.
+    if (
+        not fullgraph
+        and device == "cpu"
+        and platform.machine() == "aarch64"
+        and platform.system() == "Linux"
+        and ((2, 7) > torch.__version__ >= (2, 6))
+    ):
+        pytest.xfail("Regression in torch==2.6.0 on Linux aarch64 CPU")
+
+    dim = 256
+    batch_size = 16
+
+    torch.compiler.reset()
+
+    # Create a small network with Linear4bit layers
+    net = torch.nn.Sequential(
+        *[
+            bnb.nn.Linear4bit(
+                dim,
+                dim,
+                bias=bias,
+                compute_dtype=compute_dtype,
+                compress_statistics=compress_statistics,
+                quant_type=quant_type,
+            )
+            for _ in range(4)
+        ]
+    ).to(device)
+
+    # Create input tensor
+    x = torch.randn(batch_size, dim, dtype=compute_dtype, device=device)
+
+    # Get reference output before compilation
+    with torch.no_grad():
+        ref_output = net(x)
+
+    # Compile the model
+    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
+
+    # Get output from compiled model
+    with torch.no_grad():
+        compiled_output = compiled_net(x)
+
+    # Check outputs match
+    assert compiled_output.shape == ref_output.shape
+    assert compiled_output.device == ref_output.device
+    assert compiled_output.dtype == ref_output.dtype
+    torch.testing.assert_close(compiled_output, ref_output)
+
+    # Test with gradients
+    x.requires_grad_(True)
+    y1 = net(x).sum()
+    y1.backward()
+    grad_ref = x.grad.clone()
+
+    x.grad = None
+    y2 = compiled_net(x).sum()
+    y2.backward()
+    grad_compiled = x.grad.clone()
+
+    torch.testing.assert_close(grad_compiled, grad_ref)
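Aside on the version gates in this test: `torch.__version__` is a `TorchVersion` (a `str` subclass that compares like a version number), so expressions such as `torch.__version__ < (2, 4)` are valid comparisons rather than a string/tuple type error. For example:

```python
import torch

# TorchVersion subclasses str but supports ordered comparison against
# tuples of ints, which is what the skipif/xfail gates rely on.
print(type(torch.__version__).__name__)      # TorchVersion
print(torch.__version__ >= (2, 0))           # True on any torch 2.x
print((2, 7) > torch.__version__ >= (2, 6))  # True only on 2.6.x
```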
66 changes: 66 additions & 0 deletions tests/test_linear8bitlt.py
@@ -2,6 +2,7 @@
 import copy
 import os
 import pickle
+import platform
 from tempfile import TemporaryDirectory
 
 import pytest
@@ -224,3 +225,68 @@ def test_linear8bit_serialization(linear8bit):
     # check for a bug where SCB and CB were not copied
     assert (linear8bit.weight.SCB == deserialized.weight.SCB).all()
     assert (linear8bit.weight.CB == deserialized.weight.CB).all()
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
+@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
+@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
+@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
+@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
+    if device == "cuda" and platform.system() == "Windows":
+        pytest.skip("Triton is not officially supported on Windows")
+
+    dim = 256
+    batch_size = 16
+
+    torch.compiler.reset()
+
+    # Create a small network with Linear8bitLt layers
+    net = torch.nn.Sequential(
+        *[bnb.nn.Linear8bitLt(dim, dim, bias=bias, has_fp16_weights=False, threshold=threshold) for _ in range(4)]
+    ).to(device)
+
+    dynamic_output_shapes = fullgraph and threshold > 0
+    with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", dynamic_output_shapes):
+        # Create input tensor
+        x = torch.randn(batch_size, dim, dtype=torch.float16, device=device)
+
+        # Get reference output before compilation
+        with torch.no_grad():
+            ref_output = net(x)
+
+        # Compile the model
+        compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
+
+        # Get output from compiled model
+        with torch.no_grad():
+            compiled_output = compiled_net(x)
+
+        # Check outputs match
+        assert compiled_output.shape == ref_output.shape
+        assert compiled_output.device == ref_output.device
+        assert compiled_output.dtype == ref_output.dtype
+        torch.testing.assert_close(compiled_output, ref_output)
+
+        # Test with gradients. Currently only works with threshold=0.
+        # Has a strange regression on Linux aarch64 CPU in torch==2.6.0.
+        is_broken_platform = (
+            device == "cpu"
+            and platform.machine() == "aarch64"
+            and platform.system() == "Linux"
+            and ((2, 7) > torch.__version__ >= (2, 6))
+        )
+
+        if threshold == 0 and not is_broken_platform:
+            x.requires_grad_(True)
+            y1 = net(x).sum()
+            y1.backward()
+            grad_ref = x.grad.clone()
+
+            x.grad = None
+            y2 = compiled_net(x).sum()
+            y2.backward()
+            grad_compiled = x.grad.clone()
+
+            torch.testing.assert_close(grad_compiled, grad_ref)
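Note on the `capture_dynamic_output_shape_ops` patch: with `threshold > 0`, LLM.int8() splits out outlier columns whose count depends on the input, so the captured graph contains ops with data-dependent output shapes; `fullgraph=True` would otherwise error on them instead of graph-breaking. A small sketch of where the dynamic shape comes from (threshold and sizes here are illustrative):

```python
import torch

x = torch.randn(16, 256, dtype=torch.float16)

# Columns whose absolute max exceeds the threshold are treated as outliers.
# How many there are depends on the data, so ops like nonzero() produce
# outputs whose shape is unknown at trace time.
outlier_cols = (x.abs().amax(dim=0) > 6.0).nonzero().flatten()
print(outlier_cols.numel())  # data-dependent
```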