Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 124 additions & 78 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,6 @@ __device__ float atomicMax(float* address, float val) {
return __int_as_float(old);
}

// Decodes a 4-bit FP4 code into a float scaled by `absmax`.
//
//   val    - quantized code; only the low four bits are inspected.
//            Bit 3 selects the sign (set => negative), bits 2..0 select
//            the magnitude.
//   absmax - scale factor multiplied into the decoded magnitude.
//
// Returns magnitude * absmax * sign, evaluated in that order.
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) {
    const float sign = ((val & 0b1000) != 0) ? -1.0f : 1.0f;

    // Pick the unsigned magnitude from the three low bits.
    float magnitude;
    switch (val & 0b0111) {
    case 0b111:
        magnitude = 0.25000000f;
        break;
    case 0b110:
        magnitude = 0.16666667f;
        break;
    case 0b101:
        magnitude = 0.50000000f;
        break;
    case 0b100:
        magnitude = 0.33333333f;
        break;
    case 0b011:
        magnitude = 1.00000000f;
        break;
    case 0b010:
        magnitude = 0.66666667f;
        break;
    case 0b001:
        magnitude = 5.208333333e-03f;
        break;
    default: // 0b000
        magnitude = 0.00000000f;
        break;
    }

    return magnitude * absmax * sign;
}

__device__ unsigned char dQuantizeFP4(float x) {
// FP4 with bias of 3
// first bit is a sign
Expand Down Expand Up @@ -118,52 +95,6 @@ __device__ unsigned char dQuantizeFP4(float x) {
return 0b0000 + sign;
}

// Decodes a 4-bit NF4 code into its unscaled float value in [-1, 1].
//
//   val - quantized code; only the low four bits are inspected.
//
// The table values were generated by test_normal_map_tree
// in the file tests/test_functional.py.
__device__ __forceinline__ float dDequantizeNF4(unsigned char val) {
    switch (val & 0x0F) {
    case 0b0000:
        return -1.0f;
    case 0b0001:
        return -0.6961928009986877f;
    case 0b0010:
        return -0.5250730514526367f;
    case 0b0011:
        return -0.39491748809814453f;
    case 0b0100:
        return -0.28444138169288635f;
    case 0b0101:
        return -0.18477343022823334f;
    case 0b0110:
        return -0.09105003625154495f;
    case 0b0111:
        return 0.0f;
    case 0b1000:
        return 0.07958029955625534f;
    case 0b1001:
        return 0.16093020141124725f;
    case 0b1010:
        return 0.24611230194568634f;
    case 0b1011:
        return 0.33791524171829224f;
    case 0b1100:
        return 0.44070982933044434f;
    case 0b1101:
        return 0.5626170039176941f;
    case 0b1110:
        return 0.7229568362236023f;
    default: // 0b1111
        return 1.0f;
    }
}

__device__ unsigned char dQuantizeNF4(float x) {

// the values for this tree were generated by test_normal_map_tree
Expand Down Expand Up @@ -468,6 +399,8 @@ template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) {

const int lane_id = threadIdx.x & 31;

const int n_load = (gridDim.x * TILE_SIZE);
int valid_items_load = 0;
int valid_items_store = 0;
Expand All @@ -483,8 +416,122 @@ __global__ void
__shared__ typename LoadChar::TempStorage loadchar;
__shared__ typename StoreT::TempStorage storet;

// Each lane holds one float entry of the 16-entry 4-bit lookup table, selected by (lane_id & 0xF), for cooperative shuffling
float my_lut_val;
if constexpr (DATA_TYPE == NF4) {
// NF4 lookup table
switch (lane_id & 0xF) {
case 0:
my_lut_val = -1.0f;
break;
case 1:
my_lut_val = -0.6961928009986877f;
break;
case 2:
my_lut_val = -0.5250730514526367f;
break;
case 3:
my_lut_val = -0.39491748809814453f;
break;
case 4:
my_lut_val = -0.28444138169288635f;
break;
case 5:
my_lut_val = -0.18477343022823334f;
break;
case 6:
my_lut_val = -0.09105003625154495f;
break;
case 7:
my_lut_val = 0.0f;
break;
case 8:
my_lut_val = 0.07958029955625534f;
break;
case 9:
my_lut_val = 0.16093020141124725f;
break;
case 10:
my_lut_val = 0.24611230194568634f;
break;
case 11:
my_lut_val = 0.33791524171829224f;
break;
case 12:
my_lut_val = 0.44070982933044434f;
break;
case 13:
my_lut_val = 0.5626170039176941f;
break;
case 14:
my_lut_val = 0.7229568362236023f;
break;
case 15:
my_lut_val = 1.0f;
break;
default:
my_lut_val = 0.0f;
break;
}
} else if constexpr (DATA_TYPE == FP4) {
// FP4 lookup table
switch (lane_id & 0xF) {
case 0:
my_lut_val = 0.00000000f;
break;
case 1:
my_lut_val = 0.00520833f;
break;
case 2:
my_lut_val = 0.66666667f;
break;
case 3:
my_lut_val = 1.00000000f;
break;
case 4:
my_lut_val = 0.33333333f;
break;
case 5:
my_lut_val = 0.50000000f;
break;
case 6:
my_lut_val = 0.16666667f;
break;
case 7:
my_lut_val = 0.25000000f;
break;
case 8:
my_lut_val = 0.00000000f;
break;
case 9:
my_lut_val = -0.00520833f;
break;
case 10:
my_lut_val = -0.66666667f;
break;
case 11:
my_lut_val = -1.00000000f;
break;
case 12:
my_lut_val = -0.33333333f;
break;
case 13:
my_lut_val = -0.50000000f;
break;
case 14:
my_lut_val = -0.16666667f;
break;
case 15:
my_lut_val = -0.25000000f;
break;
default:
my_lut_val = 0.00000000f;
break;
}
}

for (int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) {
if (DATA_TYPE > 0) {
if constexpr (DATA_TYPE > 0) {
valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i);
valid_items_store = min(TILE_SIZE * 2, n - i * 2);
} else {
Expand All @@ -508,17 +555,16 @@ __global__ void
vals[j] = __ldg(&code[qvals[j]]) * local_abs_max;
break;
case FP4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
}
break;
case NF4:
// Each warp will cooperatively shuffle the LUT values
// so that each thread has access to all 16 possible values.
// This avoids the need for shared memory and branches.
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max;
vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max;
const unsigned char high_val = qvals[j] >> 4;
const unsigned char low_val = qvals[j] & 0x0F;
vals[j * 2] = __shfl_sync(0xFFFFFFFF, my_lut_val, high_val) * local_abs_max;
vals[j * 2 + 1] = __shfl_sync(0xFFFFFFFF, my_lut_val, low_val) * local_abs_max;
}
break;
}
Expand Down
Loading