216 changes: 0 additions & 216 deletions bitsandbytes/functional.py
@@ -401,23 +401,6 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
return torch.tensor(data, dtype=torch.float32)


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def create_quantile_map(A, total_bits=8):
q = estimate_quantiles(A, num_quantiles=2**total_bits - 1)
q = q.tolist()
q.append(0)

gap = 256 - len(q)
for i in range(gap):
q.append(0)

q.sort()

q = Tensor(q)
q = q / q.abs().max()
return q
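For downstream users of the removed helper, an equivalent 256-entry code can be built with stock PyTorch. A minimal sketch (illustrative only, not part of this PR; `quantile_map` is a hypothetical name):

```python
import torch

def quantile_map(A: torch.Tensor, total_bits: int = 8) -> torch.Tensor:
    # Sample 2**total_bits - 1 equidistant quantiles of A's empirical distribution.
    # Note: torch.quantile may reject very large inputs; subsample A first if needed.
    probs = torch.linspace(0, 1, 2**total_bits - 1, device=A.device)
    q = torch.quantile(A.float().flatten(), probs).tolist()
    # Pad with zeros to 256 entries, matching the removed helper's behavior.
    q += [0.0] * (256 - len(q))
    q.sort()
    q = torch.tensor(q, dtype=torch.float32)
    return q / q.abs().max()  # normalize the code into [-1, 1]
```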


def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
"""Verifies that the input tensors are all on the same device.

@@ -474,74 +457,6 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
return ct.c_void_p(A.data_ptr())


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def estimate_quantiles(
A: Tensor,
out: Optional[torch.Tensor] = None,
offset: float = 1 / 512,
num_quantiles=256,
) -> Tensor:
"""
Estimates 256 equidistant quantiles on the input tensor's eCDF.

Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
and the extreme quantiles close to 0 and 1 have high variance / large estimation
errors. These large errors can be avoided by using the offset variable which trims
the distribution. The default offset value of 1/512 ensures a minimum-entropy encoding -- it
trims 1/512 ≈ 0.2% from each side of the distribution. An offset value of 0.01 to 0.02
usually has a much lower error but is not a minimum-entropy encoding. Given an offset
of 0.02, equidistant points in the range [0.02, 0.98] are used for the quantiles.

Parameters
----------
A : torch.Tensor
The input tensor. Any shape.
out : torch.Tensor
Tensor with the 256 estimated quantiles.
offset : float
The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
num_quantiles : int
The number of equally spaced quantiles.

Returns
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
"""
if A.numel() < 256:
raise NotImplementedError(
f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.",
)
if num_quantiles > 256:
raise NotImplementedError(
f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}",
)
if num_quantiles < 256 and offset == 1 / (512):
# override default arguments
offset = 1 / (2 * num_quantiles)

if out is None:
out = torch.zeros((256,), dtype=torch.float32, device=A.device)

with _cuda_device_of(A):
is_on_gpu([A, out])

if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
else:
raise NotImplementedError(f"Not supported data type {A.dtype}")

if num_quantiles < 256:
step = round(256 / num_quantiles)
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
out = out[idx]

return out
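The removed SRAM-Quantiles path can be approximated in pure PyTorch, trading the fast approximate kernel for an exact (but slower) computation. A minimal sketch under the same offset semantics (`estimate_quantiles_torch` is a hypothetical name, not a bitsandbytes API):

```python
import torch

def estimate_quantiles_torch(A: torch.Tensor, offset: float = 1 / 512, num_quantiles: int = 256) -> torch.Tensor:
    # Equidistant probabilities in [offset, 1 - offset], mirroring the trimmed eCDF.
    probs = torch.linspace(offset, 1 - offset, num_quantiles, device=A.device)
    # Exact quantiles; the removed CUDA kernel computed a fast approximation.
    return torch.quantile(A.float().flatten(), probs)
```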


class QuantState:
"""container for quantization state components to work with Params4bit and similar classes"""

@@ -1601,25 +1516,6 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
return current_gnorm, clip_value, gnorm_scale


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
assert len(histogram.shape) == 2
assert histogram.dtype == torch.float32
assert source.dtype == torch.float32
assert index1.dtype == torch.int32
assert index2.dtype == torch.int32

assert histogram.device.type == "cuda"
assert index1.device.type == "cuda"
assert index2.device.type == "cuda"
assert source.device.type == "cuda"

maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
is_on_gpu([histogram, index1, index2, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
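The same accumulation is expressible with stock PyTorch indexing, which also drops the CUDA-only and int32-index restrictions of the removed binding. A minimal sketch (illustrative replacement, not part of this PR):

```python
import torch

def histogram_scatter_add_2d(histogram: torch.Tensor, index1: torch.Tensor, index2: torch.Tensor, source: torch.Tensor) -> None:
    # In-place: histogram[index1[i], index2[i]] += source[i];
    # accumulate=True makes duplicate (index1, index2) pairs sum instead of overwrite.
    histogram.index_put_((index1.long(), index2.long()), source, accumulate=True)
```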


def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
if not torch.cuda.is_initialized():
torch.cuda.init()
@@ -2426,118 +2322,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
C = 127.0


@deprecated(
"This function is deprecated and will be removed in a future release. "
"Consider using `int8_vectorwise_quant` instead.",
category=FutureWarning,
)
def vectorwise_quant(x, dim=1, quant_type="vector"):
if quant_type == "linear":
max1 = torch.abs(x).max().float()
xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
elif quant_type in ["vector", "row"]:
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
xq = torch.round(x * (C / max1)).to(torch.int8)
return xq, max1
elif quant_type == "zeropoint":
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
if dyna == 0:
dyna = 1
qx = 255.0 / dyna
minx = x.min()
zpx = torch.round(minx * qx)
x = torch.round(qx * x - zpx) + zpx
return x, qx
elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
dtype = x.dtype
x = x.float()
dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)
dyna[dyna == 0] = 1
qx = 255.0 / dyna
minx = torch.amin(x, dim=dim, keepdim=True)
zpx = torch.round(minx * qx)
x = torch.round(qx * x - zpx) + zpx
return x, qx
elif quant_type == "truncated-vector":
with torch.no_grad():
absx = torch.abs(x)
max1 = torch.amax(absx, dim=dim, keepdim=True)
max1 = max1 * 0.7
idx = absx > max1.expand_as(absx)
sign = torch.sign(x[idx])
x[idx] = max1.expand_as(absx)[idx] * sign
xq = torch.round(x / max1 * C).to(torch.int8)
return xq, max1
else:
return None
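The deprecation notice points to `int8_vectorwise_quant` as the replacement for the common `vector`/`row` branch. For reference, that branch reduces to row-wise absmax scaling; a minimal pure-PyTorch sketch (`rowwise_absmax_quant` is a hypothetical name):

```python
import torch

def rowwise_absmax_quant(x: torch.Tensor, dim: int = 1):
    # Symmetric int8 quantization: scale each row by its absolute maximum,
    # then round into [-127, 127]. Assumes max1 > 0 (no all-zero rows).
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    xq = torch.round(x * (127.0 / max1)).to(torch.int8)
    return xq, max1
```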


@deprecated(
"This function is deprecated and will be removed in a future release.",
category=FutureWarning,
)
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
if quant_type == "linear":
norm = S1 * S2 / (C * C)
# double cast needed to prevent overflows
return (xq.float() * norm).to(dtype)
elif quant_type == "zeropoint":
norm = 1.0 / (S1 * S2)
return (xq.float() * norm).to(dtype)
elif quant_type == "row-zeropoint":
norm = 1.0 / (S1 * S2)
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= norm
else:
x *= norm
return x.to(dtype)
elif quant_type == "vector-zeropoint":
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= 1.0 / S1
else:
x *= 1.0 / S1
x *= 1.0 / S2.t()
return x.to(dtype)
elif quant_type == "row":
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= S1 * S2 / (C * C)
else:
x *= S1 * S2 / (C * C)
return x.to(dtype)
elif quant_type in ["truncated-vector", "vector"]:
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= S1 / C
else:
x *= S1 / C
x *= S2 / C
return x.to(dtype)
else:
return None
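For the common `vector`/`truncated-vector` path, dequantizing the int32 matmul output is just a rescale by both operands' statistics. A minimal sketch, assuming `S1` is the row-wise absmax of the left operand (shape `[rows, 1]`) and `S2` broadcasts over the columns (`vector_mm_dequant` is a hypothetical name):

```python
import torch

def vector_mm_dequant(xq: torch.Tensor, S1: torch.Tensor, S2: torch.Tensor, dtype=torch.half) -> torch.Tensor:
    # Undo both 127-scalings applied when A and B were quantized to int8.
    C = 127.0
    return (xq.float() * (S1 / C) * (S2 / C)).to(dtype)
```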


def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
quant_state = linear.weight.quant_state

89 changes: 0 additions & 89 deletions csrc/kernels.cu
@@ -357,92 +357,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran
}
}


__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
{
const int tid = threadIdx.x + (blockDim.x*blockIdx.x);
const int numThreads = blockDim.x*gridDim.x;

for(int i = tid; i < n; i+=numThreads)
{
int idx = (index1[i]*maxidx1) + index2[i];
atomicAdd(&histogram[idx], src[i]);
}
}

#define THREADS_ESTIMATE 512
#define NUM_ESTIMATE 8
#define BLOCK_ESTIMATE 4096

template<typename T>
__launch_bounds__(THREADS_ESTIMATE, 1)
__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n)
{
const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE);
int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE;
const int base_idx = (blockIdx.x * BLOCK_ESTIMATE);
const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE));

T vals[NUM_ESTIMATE];

typedef cub::BlockRadixSort<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::NullType, 4, true, cub::BLOCK_SCAN_RAKING> BlockRadixSort;
typedef cub::BlockLoad<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

__shared__ union {
typename LoadFloat::TempStorage loadf;
typename BlockRadixSort::TempStorage sort;
int smem_qidx[BLOCK_ESTIMATE];
} temp_storage;

for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE)
{
valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i;

// do not process half-blocks
if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; }

#pragma unroll 4
for(int j = 0; j < NUM_ESTIMATE; j++)
vals[j] = max_val;

__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items);

#pragma unroll 4
for(int j = 0; j < NUM_ESTIMATE; j++)
vals[j] = ((float)vals[j]) * reciprocal_num_blocks;


__syncthreads();
// sort into striped pattern to mitigate bank conflicts
// striped pattern index for thread 0: [0, 512, 1024, ..., 3584]
// striped pattern index for thread 1: [1, 513, 1025, ..., 3585]
BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals);

__syncthreads();
for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x)
temp_storage.smem_qidx[j] = -1;

__syncthreads();

if(threadIdx.x < 256)
{
float q_interval = (1.0f-(2.0f*offset))/255.0f;
int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1)));
temp_storage.smem_qidx[local_idx] = threadIdx.x;
}

__syncthreads();

for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x)
{
if(temp_storage.smem_qidx[i] != -1)
atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]);
}
}
}
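The kernel's sampling positions follow directly from the `q_interval` arithmetic above. A small host-side Python check of the same index formula, assuming a full block (`valid_items = 4096`):

```python
# Reproduce the removed kernel's quantile sampling positions.
offset = 1 / 512
valid_items = 4096  # BLOCK_ESTIMATE, i.e. a full block
q_interval = (1.0 - 2.0 * offset) / 255.0
positions = [round((offset + t * q_interval) * (valid_items - 1)) for t in range(256)]
print(positions[0], positions[-1])  # 8 and 4087: both tails trimmed by the offset
```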


__launch_bounds__(TH, 4)
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n)
{
Expand Down Expand Up @@ -2998,9 +2912,6 @@ template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);

template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n);
template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n);

#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float *unorm, \
6 changes: 0 additions & 6 deletions csrc/kernels.cuh
@@ -10,8 +10,6 @@
#define kernels


template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);

__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);

@@ -106,10 +104,6 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi

template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);


__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);


template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);

template <int ITEMS_PER_THREAD, int THREADS>__global__ void kdequant_mm_int32_fp16(