4 changes: 2 additions & 2 deletions bitsandbytes/backends/cuda/ops.py
@@ -326,7 +326,7 @@ def _(
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
ct.c_int32(n),
)

if A.dtype == torch.bfloat16:
@@ -403,7 +403,7 @@ def _dequantize_4bit_impl(
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(out.numel()),
ct.c_int32(out.numel()),
_get_tensor_stream(A),
)

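A note on the ctypes change above: CPython defines ct.c_int32 as an alias of ct.c_int on every platform where a C int is 32 bits wide, so the switch does not change the bytes passed to the kernel; it makes the intended width explicit and consistent with the c_int32 already used for blocksize. A minimal sketch in plain Python (no bitsandbytes imports) of the alias and of the silent wraparound that makes the 2**31 - 1 element limit a hard one:

    import ctypes as ct

    # On platforms where a C int is 32 bits (all targets bitsandbytes
    # supports), the two ctypes types are literally the same class.
    assert ct.sizeof(ct.c_int) == 4
    assert ct.c_int32 is ct.c_int

    # ctypes does no overflow checking: a count above INT32_MAX silently
    # wraps, so element counts must stay at or below 2**31 - 1.
    print(ct.c_int32(2**31).value)  # -2147483648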
15 changes: 9 additions & 6 deletions csrc/kernels.cu
@@ -328,14 +328,16 @@ __global__ void kQuantizeBlockwise(
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
const int rand_offset, const int n
) {
const int n_full = gridDim.x * BLOCK_SIZE;
// This can overflow, so we clamp to INT32_MAX. We won't have more elements than this.
const int n_full = min(gridDim.x * BLOCK_SIZE, INT32_MAX);

const int base_idx = blockIdx.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);

T vals[NUM_PER_TH];
float rand_vals[NUM_PER_TH];
unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH];
// float local_abs_max = -FLT_MAX;

float local_abs_max = 0.0f;
int local_rand_idx = 0;

@@ -358,8 +360,8 @@ __global__ void kQuantizeBlockwise(
for (int i = threadIdx.x; i < 256; i += blockDim.x)
smem_code[i] = code[i];

for (int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
for (int64_t i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
valid_items = min(BLOCK_SIZE, static_cast<int>(n - i));
local_abs_max = -FLT_MAX;

__syncthreads();
@@ -442,7 +444,8 @@ __global__ void

for (int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) {
if (DATA_TYPE > 0) {
valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i);
// Cast n to int64_t to avoid overflow for large n
valid_items_load = min(TILE_SIZE, static_cast<int>((static_cast<int64_t>(n) + 1) / 2) - i);
valid_items_store = min(TILE_SIZE * 2, n - i * 2);
} else {
valid_items_load = min(TILE_SIZE, n - i);
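The intent of the kernel changes above, sketched in plain Python with a helper that mimics C's 32-bit wraparound (the as_int32 helper is ours for illustration, not from the repo):

    INT32_MAX = 2**31 - 1

    def as_int32(x: int) -> int:
        # Reinterpret x modulo 2**32 as a signed 32-bit value, like a C int.
        x &= 0xFFFFFFFF
        return x - 2**32 if x >= 2**31 else x

    n = 2**31 - 1                                # largest supported element count
    BLOCK_SIZE = 256
    grid_x = (n + BLOCK_SIZE - 1) // BLOCK_SIZE  # blocks launched for n elements

    # Old bound: gridDim.x * BLOCK_SIZE wraps to a negative value in 32-bit
    # arithmetic, so the grid-stride loop would never execute.
    print(as_int32(grid_x * BLOCK_SIZE))         # -2147483648

    # New bound: clamped to INT32_MAX; with the int64_t loop counter, the
    # stride i += gridDim.x * BLOCK_SIZE can no longer wrap past it.
    print(min(grid_x * BLOCK_SIZE, INT32_MAX))   # 2147483647

    # Same pattern in kDequantizeBlockwise: n + 1 itself overflows when
    # n == INT32_MAX, so (n + 1) / 2 is now computed in 64-bit first.
    print(as_int32(n + 1), (n + 1) // 2)         # -2147483648 vs. 1073741824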
13 changes: 7 additions & 6 deletions csrc/ops.cu
@@ -61,16 +61,17 @@ template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, cudaStream_t stream
) {
// printf("stream==%d\n",stream);
int num_blocks = n / blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
constexpr int tile_size = (DATA_TYPE > 0) ? 1024 : 512;

// Upcast to int64 to avoid overflow for large n
int grid_blocks = ((int64_t)n + tile_size - 1) / tile_size;

if (DATA_TYPE > 0)
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
<<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize / 2, n);
<<<grid_blocks, 64, 0, stream>>>(code, A, absmax, out, blocksize / 2, n);
else
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
<<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);
<<<grid_blocks, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);

CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
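The launch-grid fix in dequantizeBlockwise guards the same failure mode on the host side: with n near INT32_MAX, the intermediate n + tile_size - 1 in the old <<<(n + tile_size - 1) / tile_size, ...>>> expression overflows a 32-bit int before the division, yielding a negative (invalid) grid dimension in practice. The arithmetic, reusing the as_int32 helper from the sketch above:

    n, tile_size = 2**31 - 1, 512

    print(as_int32(n + tile_size - 1))       # -2147483138: wrapped intermediate
    print((n + tile_size - 1) // tile_size)  # 4194304 blocks once upcast to 64-bit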
72 changes: 65 additions & 7 deletions tests/test_functional.py
@@ -151,6 +151,34 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
assert relerr < 0.012
assert A2.dtype == dtype

@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("blocksize", [256], ids=id_formatter("blocksize"))
def test_dynamic_blockwise_quantization_large(self, device, dtype, blocksize):
"""
Test that we can successfully quantize a large tensor. Note that the following limitations apply:
- On CUDA/XPU/ROCm, the maximum number of elements is limited to 2**31 - 1 due to int32 indexing in C++ kernels.
- On CPU, there is a significantly higher memory overhead for the quantization, so we skip this test.
- Verifying dequantization accuracy against the source tensor would require too much memory for this test.
"""
if device not in ["cuda", "xpu"]:
pytest.skip("This test is only for CUDA and XPU devices due to memory constraints.")

data = torch.randn(2**31 - 1, device=device, dtype=dtype)
q_data, q_stats = F.quantize_blockwise(data, blocksize=blocksize)

assert q_data is not None
assert q_data.dtype == torch.uint8
assert q_data.numel() == data.numel()

# Dequant
del data
dq = F.dequantize_blockwise(q_data, q_stats)

assert dq.dtype == dtype
assert dq.numel() == q_data.numel()

@pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required")
@pytest.mark.parametrize("hidden", [128])
@pytest.mark.parametrize("blocksize", [4096, 16384])
@@ -1118,18 +1146,17 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
del qa, SA

assert A2.dtype == dtype

err = (A1 - A2).abs().float()
del A2

relerr = (err / (A1.abs().float() + 1e-8)).mean()
err = err.mean()

assert A2.dtype == dtype

# With larger block sizes, we can expect this to blow up.
# At blocksize>=1024, don't even bother looking at relerr.
#
# Actually, the above is not true anymore after fixing the integer packing bug.
# The following values were taken from averaging 1k samples per test configuration after fixing the bug.
# The following values were taken from averaging 1k samples per test configuration.
error_dict = dict()
error_dict["fp4"] = dict()
error_dict["nf4"] = dict()
@@ -1213,6 +1240,37 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
assert err.item() < 0.11
assert relerr.item() < 0.28

@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize"))
def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
"""
Test that we can successfully quantize a large tensor. Note that the following limitations apply:
- On CUDA/XPU/ROCm, the maximum number of elements is limited to 2**31 - 1 due to int32 indexing in C++ kernels.
- On CUDA, this test requires ~10 GiB of memory for fp32.
- On CPU, there is a significantly higher memory overhead for the quantization, so we skip this test.
- Verifying dequantization accuracy against the source tensor would require too much memory for this test.
"""

if device not in ["cuda", "xpu"]:
pytest.skip("This test is only for CUDA and XPU devices due to memory constraints.")

A1 = torch.randn(2**31 - 1, device=device, dtype=dtype)
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)

assert qa is not None
assert qa.dtype == torch.uint8
assert qa.numel() == (2**31 - 1 + 1) // 2 # each byte holds 2 quantized values

# Dequant
del A1
dq = F.dequantize_4bit(qa, SA)

assert dq.dtype == dtype
assert dq.numel() == 2**31 - 1

# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
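For context on the docstrings' memory estimates, a back-of-the-envelope tally (ours, not from the PR) of the 4-bit large-tensor test at fp32 lands near the stated ~10 GiB once allocator and RNG overhead are included:

    n = 2**31 - 1
    GiB = 1024**3

    input_fp32 = 4 * n / GiB            # ~8.000 GiB source tensor
    packed_4bit = ((n + 1) // 2) / GiB  # ~1.000 GiB: two 4-bit codes per byte
    absmax_fp32 = 4 * (n // 64) / GiB   # ~0.125 GiB per-block scales at blocksize=64

    print(f"{input_fp32 + packed_4bit + absmax_fp32:.2f} GiB")  # ~9.12 GiB before overhead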