Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion bitsandbytes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"cpu",
"cuda", # NVIDIA/AMD GPU
"xpu", # Intel GPU
"hpu", # Gaudi
"hpu", # Intel Gaudi
"npu", # Ascend NPU
"mps", # Apple Silicon
}
Expand All @@ -37,6 +37,9 @@
if torch.xpu.is_available():
from .backends.xpu import ops as xpu_ops

if hasattr(torch, "hpu") and torch.hpu.is_available():

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Line-29 above, can we make it "Intel Gaudi".

from .backends.hpu import ops as hpu_ops


def _import_backends():
"""
Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def matmul_4bit(
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)

if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
Expand Down
Empty file.
53 changes: 53 additions & 0 deletions bitsandbytes/backends/hpu/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from collections.abc import Sequence
import math

import torch

from bitsandbytes.utils import _reverse_4bit_compress_format

from ..._ops import register_kernel
from ..utils import GAUDI_SW_VER


@register_kernel("bitsandbytes::dequantize_4bit", "hpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}")
torch._check(
A.dtype in [torch.bfloat16, torch.uint8],
lambda: f"quant_storage supports uint8 or bfloat16, but got {A.dtype}",
)

# Enable non uint8 dtype
if A.dtype != torch.uint8:
A = A.view(torch.uint8)

transpose = False if len(A.shape) == 2 and A.shape[0] == 1 else True

A = A.reshape(-1)

if GAUDI_SW_VER and (GAUDI_SW_VER.major < 1 or GAUDI_SW_VER.minor < 22):
A = _reverse_4bit_compress_format(A)

# HPU dequantization function for NF4 quantized tensors.
out_dq = torch.ops.hpu.dequantize_nf4(
A,
absmax.to(dtype),
blocksize,
out_shape=(math.prod(shape),),
out_dtype=dtype,
)

output = out_dq.reshape(shape)

if transpose:
output = output.t()

return output
23 changes: 23 additions & 0 deletions bitsandbytes/backends/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import subprocess

from packaging import version
import torch

try:
Expand Down Expand Up @@ -55,3 +58,23 @@
device="xpu" if torch.xpu.is_available() else "cpu", # Only cpu/xpu use this table for now.
)
CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}


def get_gaudi_sw_version():
"""
Returns the installed version of Gaudi SW.
"""
output = subprocess.run(
"pip list | grep habana-torch-plugin",
shell=True,
text=True,
capture_output=True,
)
# If grep return nothing
if not output.stdout.strip():
return None

return version.parse(output.stdout.split("\n")[0].split()[-1])


GAUDI_SW_VER = get_gaudi_sw_version()
2 changes: 1 addition & 1 deletion bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def __init__(
)
# self.persistent_buffers = [] # TODO consider as way to save quant state
self.compute_dtype = compute_dtype
self.compute_type_is_set = False
self.compute_type_is_set = False if compute_dtype is None else True
self.quant_state = None
self.quant_storage = quant_storage
self.ipex_linear_is_set = False
Expand Down