127 changes: 127 additions & 0 deletions .github/workflows/tests.yml
@@ -222,6 +222,133 @@ jobs:
      # - name: Show pip packages
      #   run: pip list

  test-hpu:
    if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
    needs: build-cpu
    strategy:
      fail-fast: false
      matrix:
        torch_version: ["2.6.0"]
    runs-on:
      group: bandb-itac-bmemr-gaudi3-1gaudi
    env:
      BNB_TEST_DEVICE: hpu
    container:
      image: vault.habana.ai/gaudi-docker/1.21.1/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
      options: --runtime=habana --shm-size=64G --env HABANA_VISIBLE_DEVICES --env HABANA_VISIBLE_MODULES
      env:
        OMPI_MCA_btl_vader_single_copy_mechanism: none
        BNB_TEST_DEVICE: hpu
    steps:
      - name: Show system information
        run: |
          echo "OS: $(uname -a)"
          echo "CPU: $(lscpu | grep 'Model name')"
          echo "Memory: $(free -h)"

      - name: Show HPU Information
        run: |
          hl-smi

      - uses: actions/checkout@v4

      - name: Download build artifact
        uses: actions/download-artifact@v4
        with:
          name: lib_cpu_ubuntu-22.04_x86_64
          path: bitsandbytes/
          merge-multiple: true

      - name: Show installed packages
        run: pip list

      - name: Install dependencies
        run: |
          pip install -e ".[test]"
          pip install pytest-cov

      - name: Show installed packages
        run: pip list

      - name: Show environment information
        run: |
          python -m torch.utils.collect_env
          python -m bitsandbytes

      - name: Run tests
        run: pytest --durations=100

  test-xpu:
    if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
    needs: build-cpu
    strategy:
      fail-fast: false
      matrix:
        torch_version: ["2.7.1"] #["2.6.0", "2.7.1"]
        ipex: [false]
        # ipex: [true, false]
        # include:
        #   - torch_version: "2.6.0"
        #     ipex: true
        #     ipex_version: "2.6.10+xpu"
        #   - torch_version: "2.7.1"
        #     ipex: true
        #     ipex_version: "2.7.10+xpu"
    runs-on:
      group: bandb-itac-bmsprpvc1550-8-1gpu
    env:
      BNB_TEST_DEVICE: xpu
    steps:
      - name: Show system information
        run: |
          echo "OS: $(uname -a)"
          echo "CPU: $(lscpu | grep 'Model name')"
          echo "Memory: $(free -h)"

      - name: Show XPU Information
        run: |
          xpu-smi discovery
          sudo xpu-smi discovery
          sudo apt-get install -y hwinfo
          hwinfo --display

      - uses: actions/checkout@v4

      - name: Download build artifact
        uses: actions/download-artifact@v4
        with:
          name: lib_cpu_ubuntu-22.04_x86_64
          path: bitsandbytes/
          merge-multiple: true

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: 3.9

      - name: Install PyTorch
        run: pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/xpu

      - name: Install IPEX
        if: matrix.ipex == true
        run: pip install intel_extension_for_pytorch==${{ matrix.ipex_version }} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

      - name: Install dependencies
        run: |
          pip install -e ".[test]"
          pip install pytest-cov

      - name: Show installed packages
        run: pip list

      - name: Show environment information
        run: |
          python -m torch.utils.collect_env
          python -m bitsandbytes

      # - name: Run tests
      #   run: pytest --durations=100

  test-cuda:
    if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
    needs: build-cuda
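Both new jobs export BNB_TEST_DEVICE so the suite targets a single accelerator per runner. A minimal sketch of how a test helper could honor that variable when building the device list for parametrized tests (hypothetical function; the repo's actual tests/helpers.py may differ):

```python
# Hypothetical helper, not the repo's actual tests/helpers.py: map the
# BNB_TEST_DEVICE variable set in the workflow to the devices tests run on.
import os

import torch


def get_test_devices() -> list[str]:
    override = os.environ.get("BNB_TEST_DEVICE")
    if override:
        return [override]

    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        devices.append("xpu")
    if hasattr(torch, "hpu") and torch.hpu.is_available():
        devices.append("hpu")
    return devices
```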
10 changes: 8 additions & 2 deletions bitsandbytes/__init__.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.


import importlib
import sys

import torch
@@ -37,8 +38,13 @@
if hasattr(torch, "xpu") and torch.xpu.is_available():
from .backends.xpu import ops as xpu_ops

if hasattr(torch, "hpu") and torch.hpu.is_available():
from .backends.hpu import ops as hpu_ops

if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
# In case not automatically imported
import habana_frameworks.torch

if hasattr(torch, "hpu") and torch.hpu.is_available():
from .backends.hpu import ops as hpu_ops


def _import_backends():
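The new guard force-imports habana_frameworks.torch when it is installed, so the torch.hpu check that follows can succeed even if nothing else has imported the Gaudi bindings. A quick sanity check, as a sketch assuming a Gaudi host with habana_frameworks installed:

```python
# Sketch of a manual sanity check (not part of the PR); importing bitsandbytes
# runs the guarded habana_frameworks.torch import shown above.
import torch
import bitsandbytes  # noqa: F401  (import triggers backend registration)

print("HPU available:", hasattr(torch, "hpu") and torch.hpu.is_available())
```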
3 changes: 3 additions & 0 deletions tests/test_autograd.py
@@ -234,6 +234,9 @@ def test_matmul_4bit(
        out_bnb.data.copy_(out_torch)
        if device == "cuda":
            torch.cuda.synchronize()
        elif device == "hpu":
            torch.hpu.synchronize()

        loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
        loss_bnb.backward()
        gradA1 = A.grad
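The test now synchronizes per device before computing the loss. A hedged sketch of how the same logic could be factored into a helper instead of growing an elif chain in each test (hypothetical function, not part of the diff):

```python
import torch


def synchronize(device: str) -> None:
    """Block until pending work on the given device has finished."""
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "xpu":
        torch.xpu.synchronize()
    elif device == "hpu":
        torch.hpu.synchronize()
    # CPU ops are synchronous, so nothing to do.
```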
3 changes: 2 additions & 1 deletion tests/test_linear8bitlt.py
@@ -257,7 +257,8 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
        ref_output = net(x)

    # Compile the model
    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
    compile_backend = "hpu_backend" if device == "hpu" else "inductor"
    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode, backend=compile_backend)

    # Get output from compiled model
    with torch.no_grad():
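The change routes torch.compile through Gaudi's "hpu_backend" while every other device keeps the default "inductor" backend. A minimal sketch of the same pattern in isolation (assumes habana_frameworks.torch is installed on the HPU side, which registers that backend):

```python
import torch


def compile_for(model: torch.nn.Module, device: str) -> torch.nn.Module:
    # "hpu_backend" is registered by habana_frameworks.torch; other devices
    # fall back to the default inductor backend.
    backend = "hpu_backend" if device == "hpu" else "inductor"
    return torch.compile(model, backend=backend)
```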
25 changes: 21 additions & 4 deletions tests/test_modules.py
@@ -5,7 +5,7 @@
from torch import nn

import bitsandbytes as bnb
from tests.helpers import get_available_devices, id_formatter
from tests.helpers import get_available_devices, id_formatter, is_supported_on_hpu


class MockArgs:
@@ -276,9 +276,9 @@ def test_linear_kbit_fp32_bias(device, module):
"NF4": bnb.nn.LinearNF4,
"FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True),
"NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True),
"NF4+fp32": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32),
"NF4+fp16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16),
"NF4+bf16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16),
"NF4+fp32": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.float32),
"NF4+fp16": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.float16),
"NF4+bf16": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.bfloat16),
}


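The three compute_dtype entries previously constructed LinearFP4; they now build LinearNF4 so each key matches the layer it exercises. A usage sketch of one corrected entry (illustrative shapes, not from the diff):

```python
import torch
import bitsandbytes as bnb

# NF4-quantized linear layer with bf16 compute, matching the "NF4+bf16" entry.
layer = bnb.nn.LinearNF4(64, 128, compute_dtype=torch.bfloat16)
```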
@@ -295,7 +295,12 @@ def test_kbit_backprop(device, module):
    torch.nn.init.kaiming_normal_(ref[0].weight)
    torch.nn.init.kaiming_normal_(ref[1].weight)
    ref[1].weight.requires_grad_(False)

    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])

    if device == "hpu" and isinstance(kbit[1], bnb.nn.Linear4bit) and kbit[1].weight.quant_type == "fp4":
        pytest.skip("FP4 is not supported on HPU")

    kbit[0].weight.detach().copy_(ref[0].weight)
    kbit[1].weight.detach().copy_(ref[1].weight)
    kbit[0].bias.detach().copy_(ref[0].bias)
@@ -358,6 +363,12 @@ def test_kbit_backprop(device, module):
    ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
    if device == "hpu":
        if embedding_class is bnb.nn.EmbeddingFP4:
            pytest.skip("FP4 is not supported on HPU")
        elif embedding_class is bnb.nn.EmbeddingNF4 and not is_supported_on_hpu("nf4", torch.float32, quant_storage):
            pytest.skip("This configuration is not supported on HPU")

    num_embeddings = 128

    src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
@@ -403,6 +414,12 @@ def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim,
    ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
    if device == "hpu":
        if embedding_class is bnb.nn.EmbeddingFP4:
            pytest.skip("FP4 is not supported on HPU")
        elif embedding_class is bnb.nn.EmbeddingNF4 and not is_supported_on_hpu("nf4", torch.float32, quant_storage):
            pytest.skip("This configuration is not supported on HPU")

    is_8bit = embedding_class is bnb.nn.Embedding8bit

    num_embeddings = 128
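The skips rely on is_supported_on_hpu from tests.helpers, which this diff only shows at the call site. A hypothetical sketch of what such a check might look like, capturing only the constraints visible in this PR (FP4 unsupported on HPU; NF4 limited to certain quant_storage combinations); the repo's actual helper may differ:

```python
import torch


def is_supported_on_hpu(quant_type: str, dtype: torch.dtype, quant_storage: torch.dtype = torch.uint8) -> bool:
    # Hypothetical reconstruction from the call sites above; not the real helper.
    if quant_type == "fp4":
        return False  # FP4 kernels are not available on Gaudi
    # Assumption: the NF4 path only supports uint8 quant storage; the dtype
    # constraints are not visible in this diff, so they are not modeled here.
    return quant_storage == torch.uint8
```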