bitsandbytes/backends/triton/ops.py: 67 changes (39 additions & 28 deletions)
@@ -9,6 +9,8 @@
 # from bitsandbytes.functional import get_4bit_type
 # _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
 # _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
+device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+torch_accelerator_module = getattr(torch, device_type, torch.cuda)
 
 
 def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
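The module-level pair added at the top of the file picks the torch backend module whose device context wraps every kernel launch below. A minimal sketch of what it resolves to, assuming a PyTorch build that ships `torch.accelerator` (older builds fall back to `torch.cuda`):

```python
# Sketch, not part of the diff: how the backend module is resolved.
import torch

device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)

print(device_type)               # e.g. "cuda" on NVIDIA, "xpu" on Intel GPUs
print(torch_accelerator_module)  # e.g. <module 'torch.cuda'> or <module 'torch.xpu'>
```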
@@ -21,7 +23,9 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
     absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
     out = torch.empty_like(A.flatten(), dtype=torch.uint8)
 
-    triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
+
     out = out.reshape(A.shape)
 
     return out, absmax.float()
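Triton launches kernels on whichever device is current for the active backend, so the new `with torch_accelerator_module.device(A.device):` guard pins each launch to the device that actually holds `A`. An illustrative sketch of the failure mode it prevents, using CUDA and a hypothetical second GPU:

```python
# Illustrative: on CUDA the guard expands to torch.cuda.device(...).
import torch

A = torch.randn(4096, device="cuda:1")
with torch.cuda.device(A.device):
    # A kernel launched here targets cuda:1; without the guard it would
    # run on whatever device happened to be current (often cuda:0).
    assert torch.cuda.current_device() == 1
```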
@@ -35,13 +39,14 @@ def dequantize_blockwise(
     # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
 
     out = torch.empty_like(A, dtype=dtype, device=A.device)
-    triton_kernels.dequant_int8_blockwise(
-        A,
-        code,
-        absmax,
-        out,
-        blocksize,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.dequant_int8_blockwise(
+            A,
+            code,
+            absmax,
+            out,
+            blocksize,
+        )
 
     return out

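Together, the two ops above round-trip a tensor through int8 blockwise quantization. A hedged usage sketch, assuming this Triton backend is active and borrowing `create_dynamic_map` from `bitsandbytes.functional` for the code table; parameter names come from the function bodies in this file, but the positional order is an assumption:

```python
# Round-trip sketch; the atol is a loose bound for 8-bit dynamic quantization.
import torch
from bitsandbytes.functional import create_dynamic_map

A = torch.randn(4096, device="cuda")
code = create_dynamic_map().to(A.device)  # 256-entry float32 lookup table

q, absmax = quantize_blockwise(A, code, blocksize=256)
A_hat = dequantize_blockwise(q, absmax, code, blocksize=256, dtype=torch.float32)
assert torch.allclose(A, A_hat, atol=0.1)
```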
@@ -55,13 +60,14 @@ def dequantize_blockwise_inplace(
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
     torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
 
-    triton_kernels.dequant_int8_blockwise(
-        A,
-        code,
-        absmax,
-        out,
-        blocksize,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.dequant_int8_blockwise(
+            A,
+            code,
+            absmax,
+            out,
+            blocksize,
+        )
 
 
 def quantize_4bit(
@@ -84,9 +90,10 @@ def quantize_4bit(
     absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
     out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
 
-    triton_kernels.quantize_4bit_blockwise_triton(
-        A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.quantize_4bit_blockwise_triton(
+            A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
+        )
     packed = out
 
     if quant_storage != torch.uint8:
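The `(n // 2, 1)` output buffer above reflects that two 4-bit codes are packed into each `uint8`. A small sketch of one possible nibble layout; the real order is whatever `quantize_4bit_blockwise_triton` emits, so treat this as illustrative only:

```python
# Illustrative nibble packing: two 4-bit code indices per byte.
import torch

codes = torch.randint(0, 16, (8,), dtype=torch.uint8)  # hypothetical 4-bit codes
packed = (codes[0::2] << 4) | codes[1::2]              # n // 2 bytes
hi, lo = packed >> 4, packed & 0xF                     # unpack both halves
assert torch.equal(torch.stack([hi, lo], dim=1).flatten(), codes)
```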
@@ -119,7 +126,9 @@ def dequantize_4bit(
 
     out = torch.empty(shape, dtype=dtype, device=A.device)
 
-    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+
     return out
 
 
@@ -134,7 +143,8 @@ def dequantize_4bit_inplace(
 ) -> None:
     torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
     torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
-    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
 
 
 def gemv_4bit(
@@ -150,14 +160,15 @@
 
     B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)
 
-    triton_kernels._dequantize_4bit_impl_passing_code(
-        B,
-        absmax,
-        blocksize,
-        code,
-        dtype=A.dtype,
-        out=B_dq_triton,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl_passing_code(
+            B,
+            absmax,
+            blocksize,
+            code,
+            dtype=A.dtype,
+            out=B_dq_triton,
+        )
 
     return torch.nn.functional.linear(
         A,
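Note that `gemv_4bit` here is a dequantize-then-matmul fallback rather than a fused 4-bit GEMV: the Triton kernel (now launched inside the device guard) materializes `B` at full precision, and the product is delegated to `torch.nn.functional.linear`. In plain PyTorch the tail amounts to the following, with shapes purely illustrative:

```python
# Dequantize-then-matmul fallback, shapes illustrative.
import torch

A = torch.randn(1, 4096)         # gemv input
B_dq = torch.randn(11008, 4096)  # stands in for B_dq_triton after dequantization
y = torch.nn.functional.linear(A, B_dq)  # y = A @ B_dq.T -> shape (1, 11008)
```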