1 change: 0 additions & 1 deletion bitsandbytes/autograd/__init__.py
@@ -1 +0,0 @@
from ._functions import get_inverse_transform_indices, undo_layout
59 changes: 0 additions & 59 deletions bitsandbytes/autograd/_functions.py
@@ -1,12 +1,10 @@
from collections.abc import Callable
from dataclasses import dataclass
from math import prod
from typing import Optional
import warnings
from warnings import warn

import torch
from typing_extensions import deprecated

import bitsandbytes.functional as F

@@ -50,66 +48,9 @@ def get_current_outlier_idx(self):
return torch.Tensor(list(self.outliers)).to(torch.int64)


@deprecated(
"This function is deprecated and will be removed in a future release.",
category=FutureWarning,
)
def get_inverse_transform_indices(
transform_tile: Callable[[torch.Tensor], torch.Tensor],
tile_size: tuple[int, int],
):
"""
Compute a permutation of indices that invert the specified (tiled) matrix transformation

:param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
:param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
:note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
:example: transform_tile function for the turing layout (bitsandbytes.functional as F)
:returns: indices
"""
d1, d2 = tile_size
assert 0 < d1 * d2 < 2**64
tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
# encode each position in tile as a tuple of <= 8 unique bytes
permuted_tile_indices = torch.zeros_like(tile_indices)
for i in range(8):
# select i-th byte, apply transformation and trace where each index ended up
ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
permuted_tile_i = transform_tile(sample_tile_i)
ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
permuted_tile_indices += ith_permuted_indices * (256**i)
if d1 * d2 < 256**i:
break # if all indices fit in i bytes, stop early
return permuted_tile_indices
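
The helper above traces where a tiled layout transform sends each element: it encodes every flat index as a sequence of bytes, pushes each byte through the transform as an int8 tile, and reassembles the permuted indices. For a shape-preserving permutation this reduces to applying the transform directly to an index grid. A minimal sketch in plain PyTorch, using a hypothetical `reverse_tile` transform that is not part of bitsandbytes:

```python
import torch

def reverse_tile(tile: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a layout transform: reverse the tile's elements
    # while keeping its shape (any shape-preserving permutation would do).
    return tile.flatten().flip(0).reshape(tile.shape)

d1, d2 = 8, 32  # Turing-style tile size, used only as an example
tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)

# Applying the permutation to the index grid yields, at every output position,
# the original flat index that landed there -- the same result the helper above
# reconstructs byte by byte (necessary because the real transforms only accept
# int8 tiles, which cannot hold large index values directly).
inverse_indices = reverse_tile(tile_indices)
assert inverse_indices[0, 0].item() == d1 * d2 - 1
```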


_is_compiling = torch.compiler.is_compiling


@deprecated(
"This function is deprecated and will be removed in a future release.",
category=FutureWarning,
)
def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
"""
Undo a tiled permutation such as turing or ampere layout

:param permuted_tensor: torch tensor in a permuted layout
:param tile_indices: reverse transformation indices, from get_inverse_transform_indices
:return: contiguous row-major tensor
"""
(rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
outputs = torch.empty_like(tensor) # note: not using .index_copy because it was slower on cuda
outputs[tile_indices.flatten()] = tensor
outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
outputs = outputs.permute(3, 0, 2, 1) # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
return outputs.reshape(rows, cols).contiguous()
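
The scatter at the heart of `undo_layout` sends each permuted position back to its row-major slot via those inverse indices. A sketch of that core step on a single tile, continuing the hypothetical `reverse_tile` example (the removed implementation additionally handles tensors composed of many tiles):

```python
import torch

d1, d2 = 8, 32
inverse_indices = torch.arange(d1 * d2, dtype=torch.int64).flip(0).view(d1, d2)

original = torch.randn(d1, d2)
permuted = original.flatten().flip(0).reshape(d1, d2)  # same hypothetical transform

# Scatter permuted values back to their original flat positions, then restore
# the row-major shape -- the per-tile analogue of undo_layout's index scatter.
restored = torch.empty(d1 * d2)
restored[inverse_indices.flatten()] = permuted.flatten()
assert torch.equal(restored.view(d1, d2), original)
```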


@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None # TODO: remove
96 changes: 0 additions & 96 deletions bitsandbytes/functional.py
@@ -1795,102 +1795,6 @@ def int8_mm_dequant(
return result


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_colrow_absmax(
A: torch.Tensor,
row_stats: Optional[torch.Tensor] = None,
col_stats: Optional[torch.Tensor] = None,
nnz_block_ptr: Optional[torch.Tensor] = None,
threshold=0.0,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Determine the quantization statistics for input matrix `A` in accordance with the `LLM.int8()` algorithm.

The row-wise and column-wise absmax values are determined.

For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).

<Tip>
This function is useful for training, but for inference it is advised to use [`get_row_absmax`] instead.
The column-wise quantization scales are not typically needed in inference scenarios.
</Tip>

Args:
A (`torch.Tensor` with dtype `torch.float16`): Input tensor.
row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped.
col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped.
nnz_block_ptr (`torch.Tensor`, *optional*): Not used.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.

Returns:
`Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing quantization statistics.
- `torch.Tensor` with dtype `torch.float32`: The row-wise quantization statistics.
- `torch.Tensor` with dtype `torch.float32`: The column-wise quantization statistics.
- `torch.Tensor` with dtype `torch.bool`, *optional*: A mask indicating the locations of outliers in the input tensor.
"""
assert A.is_floating_point()

outlier_mask = None

if row_stats is None or col_stats is None:
absA = A.abs().view(-1, A.shape[-1])

if threshold > 0.0:
# Filter outliers from stats when enabled
outlier_mask = absA >= threshold
absA.masked_fill_(outlier_mask, 0.0)

if row_stats is None:
# shape [rows]; unsqueeze(-1) gives [rows,1]
# We have a CUDA kernel for row max, but not yet for cols.
row_stats = get_row_absmax(A, threshold)

if col_stats is None:
# shape [cols]; unsqueeze(0) gives [1,cols]
col_stats = absA.amax(dim=0, keepdim=False).float()

return row_stats, col_stats, outlier_mask
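
Both statistics above are masked absolute maxima over the flattened input. A pure-PyTorch sketch of the same computation (an illustration of the semantics, not the removed implementation, which routes the row reduction through a CUDA kernel):

```python
import torch

def colrow_absmax_reference(A: torch.Tensor, threshold: float = 0.0):
    # Flatten leading dimensions: rows = prod(A.shape[:-1]), cols = A.shape[-1].
    absA = A.abs().reshape(-1, A.shape[-1]).float()

    outlier_mask = None
    if threshold > 0.0:
        # Outliers (|a| >= threshold) are excluded from the statistics; the
        # sparse decomposition in LLM.int8() handles them separately.
        outlier_mask = absA >= threshold
        absA = absA.masked_fill(outlier_mask, 0.0)

    row_stats = absA.amax(dim=1)  # shape [rows]
    col_stats = absA.amax(dim=0)  # shape [cols]
    return row_stats, col_stats, outlier_mask

A = torch.randn(4, 16, dtype=torch.float16)
row_stats, col_stats, _ = colrow_absmax_reference(A, threshold=6.0)
```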


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_row_absmax(A: torch.Tensor, threshold=0.0):
"""Determine the quantization statistics for input matrix `A` in accordance with the `LLM.int8()` algorithm.

For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).

Args:
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.

Returns:
`torch.Tensor` with dtype `torch.float32`: The absolute maximum value for each row, with outliers ignored.
"""

assert A.dtype == torch.float16

rows = prod(A.shape[:-1])
cols = A.shape[-1]

row_stats = torch.empty((rows,), dtype=torch.float32, device=A.device)

is_on_gpu([A])

with _cuda_device_of(A):
lib.cget_row_stats(
get_ptr(A),
get_ptr(row_stats),
ct.c_float(threshold),
ct.c_int32(rows),
ct.c_int32(cols),
_get_tensor_stream(A),
)

return row_stats
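
The kernel-backed function above has a simple reference semantics: the float32 absmax of each row of the flattened `[rows, cols]` view, with values at or above the threshold ignored. A plain-PyTorch sketch of that contract (the removed code dispatches to the `cget_row_stats` kernel instead):

```python
import torch

def row_absmax_reference(A: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    # Same contract as above: fp16 input, fp32 per-row absmax, where a "row"
    # means the flattened leading dimensions (rows = prod(A.shape[:-1])).
    absA = A.abs().reshape(-1, A.shape[-1]).float()
    if threshold > 0.0:
        # Values at or above the threshold are treated as outliers and ignored
        # (approximated here by zeroing them before the reduction).
        absA = absA.masked_fill(absA >= threshold, 0.0)
    return absA.amax(dim=1)

A = torch.randn(2, 3, 16, dtype=torch.float16)
assert row_absmax_reference(A).shape == (6,)  # prod(A.shape[:-1]) rows
```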


class COOSparseTensor:
def __init__(
self, rows: int, cols: int, nnz: int, rowidx: torch.Tensor, colidx: torch.Tensor, values: torch.Tensor
45 changes: 0 additions & 45 deletions csrc/kernels.cu
@@ -1825,51 +1825,6 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
}
}

template <typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols) {
using BlockReduceT = cub::BlockReduce<float, THREADS>;

// One block per row.
// Threads load column values in a striped arrangement.
// e.g. t0 reads row[0], row[0+nthreads], ..
// and t1 reads row[1], row[1+nthreads], ..
// Each thread will determine its local absmax.
// We then do a blockwise reduction to determine the row's absmax.

__shared__ typename BlockReduceT::TempStorage temp_storage;

const int row_id = blockIdx.x;
const T* __restrict__ row_data = A + (row_id * cols);

// Threads will read the row values in a striped access pattern and find a local absmax.
float row_local_absmax = -FLT_MIN;
for (int i = threadIdx.x; i < cols; i += THREADS) {
const float absval = fabsf(row_data[i]);

// For sparse decomposition, values outside of the threshold are not to be
// included when calculating the row's absmax.
if constexpr (SPARSE_DECOMP) {
row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
} else {
row_local_absmax = fmaxf(row_local_absmax, absval);
}
}

// Reduce thread-local absmax across the block.
// TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, CUB_REDUCTIONOP_MAX, cols);
if (threadIdx.x == 0) {
// Save our block's absmax to rowStats for use in the quantization step.
rowStats[row_id] = row_absmax;
}
}

template __global__ void
kgetRowStats<half, 1024, 0>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template __global__ void
kgetRowStats<half, 1024, 1>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);

template __global__ void kInt8VectorQuant<half, 1024, 0>(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols
);
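
The removed `kgetRowStats` kernel assigns one block per row, has each thread take a striped slice of the columns, and then reduces the per-thread partial absmax values with a block-wide max. That strategy can be modelled in a few lines of PyTorch (an illustration of the reduction pattern only, not the CUDA code):

```python
import torch

def simulated_row_absmax(row: torch.Tensor, n_threads: int = 1024) -> float:
    # Each "thread" t reads elements t, t + n_threads, t + 2*n_threads, ...
    # (a striped arrangement) and keeps a thread-local absmax.
    local_absmax = torch.zeros(n_threads)
    for t in range(min(n_threads, row.numel())):
        local_absmax[t] = row[t::n_threads].abs().amax()
    # The block-wide reduction (cub::BlockReduce in the kernel) is simply a max
    # over the per-thread partial results.
    return local_absmax.max().item()

row = torch.randn(4096)
assert simulated_row_absmax(row) == row.abs().amax().item()
```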
7 changes: 0 additions & 7 deletions csrc/kernels.cuh
@@ -109,16 +109,9 @@ __global__ void kdequant_mm_int32_fp16(
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);

template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT>
__global__ void kTransformRowToFormat(
char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
);

template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template <typename T, int THREADS>
43 changes: 0 additions & 43 deletions csrc/kernels.hip
@@ -1946,49 +1946,6 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
}
}

template<typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
using BlockReduceT = hipcub::BlockReduce<float, THREADS>;

// One block per row.
// Threads load column values in a striped arrangement.
// e.g. t0 reads row[0], row[0+nthreads], ..
// and t1 reads row[1], row[1+nthreads], ..
// Each thread will determine its local absmax.
// We then do a blockwise reduction to determine the row's absmax.

__shared__ typename BlockReduceT::TempStorage temp_storage;

const int row_id = blockIdx.x;
const T* __restrict__ row_data = A + (row_id * cols);

// Threads will read the row values in a striped access pattern and find a local absmax.
float row_local_absmax = -FLT_MIN;
for (int i = threadIdx.x; i < cols; i += THREADS) {
const float absval = fabsf(row_data[i]);

// For sparse decomposition, values outside of the threshold are not to be
// included when calculating the row's absmax.
if constexpr (SPARSE_DECOMP) {
row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
} else {
row_local_absmax = fmaxf(row_local_absmax, absval);
}
}

// Reduce thread-local absmax across the block.
// TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols);
if (threadIdx.x == 0) {
// Save our block's absmax to rowStats for use in the quantization step.
rowStats[row_id] = row_absmax;
}
}

template __global__ void kgetRowStats<half, 1024, 0>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
template __global__ void kgetRowStats<half, 1024, 1>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);

template __global__ void kInt8VectorQuant<half, 1024, 0>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
template __global__ void kInt8VectorQuant<half, 1024, 1>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);

7 changes: 0 additions & 7 deletions csrc/kernels_hip.cuh
@@ -111,16 +111,9 @@ __global__ void kdequant_mm_int32_fp16(
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);

template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT>
__global__ void kTransformRowToFormat(
char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
);

template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template <typename T, int THREADS>