28 changes: 28 additions & 0 deletions include/infiniop/ops/quant/per_channel_quant_int8.h
@@ -0,0 +1,28 @@
#ifndef __INFINIOP_PER_CHANNEL_QUANT_INT8_API_H__
#define __INFINIOP_PER_CHANNEL_QUANT_INT8_API_H__

#include "../../operator_descriptor.h"

typedef InfiniopDescriptor *infiniopPerChannelQuantI8Descriptor_t;

__C __export infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t handle,
infiniopPerChannelQuantI8Descriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc);

__C __export infiniStatus_t infiniopGetPerChannelQuantI8WorkspaceSize(infiniopPerChannelQuantI8Descriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopPerChannelQuantI8(infiniopPerChannelQuantI8Descriptor_t desc,
void *workspace,
size_t workspace_size,
void *x_packed,
void *x_scale,
void *x_zero,
const void *x,
void *stream);

__C __export infiniStatus_t infiniopDestroyPerChannelQuantI8Descriptor(infiniopPerChannelQuantI8Descriptor_t desc);

#endif
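For orientation, a minimal host-side sketch of the intended create → query-workspace → run → destroy lifecycle declared above. This is a hedged illustration, not part of the PR: `handle`, the four tensor descriptors, the device buffers `d_*`, and `stream` are assumed to already exist, and cudaMalloc/cudaFree stand in for whatever allocator the caller uses.

// Hypothetical usage sketch (names outside this header are placeholders).
infiniopPerChannelQuantI8Descriptor_t desc = nullptr;
infiniopCreatePerChannelQuantI8Descriptor(handle, &desc,
                                          x_packed_desc, x_scale_desc,
                                          x_zero_desc, x_desc);

size_t workspace_size = 0;
infiniopGetPerChannelQuantI8WorkspaceSize(desc, &workspace_size);

void *workspace = nullptr;
cudaMalloc(&workspace, workspace_size);

infiniopPerChannelQuantI8(desc, workspace, workspace_size,
                          d_x_packed, d_x_scale, d_x_zero, d_x, stream);

cudaFree(workspace);
infiniopDestroyPerChannelQuantI8Descriptor(desc);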
277 changes: 277 additions & 0 deletions src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh
@@ -0,0 +1,277 @@
#ifndef __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#define __PERCHANNEL_QUANTINT8_KERNEL_CUH__

#include <cub/block/block_reduce.cuh>
// Note: blockPerChannelQuantI8Kernel below also relies on
// op::common_cuda::reduce_op::max, which the including translation unit is
// expected to make visible.
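// Rounds to the nearest integer with halfway cases away from zero,
// effectively equivalent to (int)roundf(x): C's roundf also rounds halfway
// cases away from zero, unlike rintf under round-to-nearest-even.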
__device__ inline int round_half_away_from_zero(float x) {
float ax = fabsf(x);
float r = floorf(ax + 0.5f);
return (x >= 0.0f) ? (int)r : -(int)r;
}

template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8Kernel(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
int M, int K) {
int row = blockIdx.x;
int tid = row * K; // offset of this row's first element (not a thread id)

// ---- 1. reduce max ----
float local_max = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(
x + tid, K);

__shared__ float global_max_f;
if (threadIdx.x == 0) {
global_max_f = local_max;
}
__syncthreads();

typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

// ---- 2. reduce min ----
float thread_min = __FLT_MAX__;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
thread_min = fminf(thread_min, (float)x[tid + ind]);
}
float local_min = BlockReduce(temp_storage).Reduce(thread_min, cub::Min());

__shared__ float global_min_f;
if (threadIdx.x == 0) {
global_min_f = local_min;
}
__syncthreads();

// ---- 3. compute scale/zero in float (matching the Python reference) ----
float global_max = global_max_f;
float global_min = global_min_f;

float scale = (global_max - global_min) / 255.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}

float inv_scale = 1.0f / scale;
float zero = -global_min * inv_scale - 128.0f;
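// With zero = -min / scale - 128, the affine map q = x / scale + zero sends
// x = min to -128 and x = max to 127, spanning the full int8 range.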

// Write back scale and zero. Store them directly as float: routing them
// through Tdata (e.g. half) would discard precision, and the warp kernels
// below already store full-precision floats.
x_scale[row] = scale;
x_zero[row] = zero;

// ---- 4. quantize in float with half-away-from-zero rounding (exactly matching the Python reference) ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {

float v = (float)x[tid + ind];
float qf = v * inv_scale + zero;

int q = round_half_away_from_zero(qf);

if (q > 127) {
q = 127;
}
if (q < -128) {
q = -128;
}

x_packed[tid + ind] = (int8_t)q;
}
}

template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8SymKernel(
int8_t *x_packed, float *x_scale, const Tdata *x,
int M, int K) {
int row = blockIdx.x;
int tid = row * K;

typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

// ---- 1. reduce max of |x| ----
float thread_max = -__FLT_MAX__;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
thread_max = fmaxf(thread_max, fabsf((float)x[tid + ind]));
}
float local_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());

__shared__ float global_max_f;
if (threadIdx.x == 0) {
global_max_f = local_max;
}
__syncthreads();

// ---- 2. compute scale in float (matching the Python reference) ----
float global_max = global_max_f;

float scale = global_max / 127.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}

float inv_scale = 1.0f / scale;

// Write back scale (the symmetric variant has no zero point). Store it
// directly as float; routing it through Tdata would discard precision.
x_scale[row] = scale;

// ---- 3. quantize in float with half-away-from-zero rounding (exactly matching the Python reference) ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {

float v = (float)x[tid + ind];
float qf = v * inv_scale;

int q = round_half_away_from_zero(qf);

if (q > 127) {
q = 127;
}
if (q < -127) {
q = -127;
}

x_packed[tid + ind] = (int8_t)q;
}
}

template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return max(a, b);
}
};
template <typename T>
struct MinOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return min(a, b);
}
};
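// Butterfly (XOR-shuffle) all-reduce over a thread group of width
// thread_group_width (a power of two, at most one warp): after log2(width)
// __shfl_xor_sync steps every lane holds the reduced value, so no shared
// memory or __syncthreads() round-trip is needed. The full 0xffffffff mask
// assumes all 32 lanes of the warp are active when this is called.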
template <template <typename> class ReductionOp, typename T,
int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}

template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8Kernel(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
int M, int K) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx * K;

if (otherIdx < M) {

float max_data = -__FLT_MAX__;
float min_data = __FLT_MAX__;

// ---- reduce max/min ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
max_data = fmaxf(max_data, v);
min_data = fminf(min_data, v);
}

max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);
min_data = WarpAllReduce<MinOp, float, BLOCK_SIZE_x>(min_data);

// WarpAllReduce leaves the reduced value in every lane of the group, so a
// shared-memory broadcast is unnecessary. Dropping it also removes a
// __syncthreads() that sat inside the `otherIdx < M` branch, which is
// undefined behavior once some warps of the block have already exited.

// ---- compute scale/zero in float (matching the Python float32 reference) ----
float max_f = max_data;
float min_f = min_data;

float scale = (max_f - min_f) / 255.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}

float inv_scale = 1.0f / scale;
float zero = -min_f * inv_scale - 128.0f;

x_scale[otherIdx] = scale;
x_zero[otherIdx] = zero;

// ---- quantize in float with half-away-from-zero rounding ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
float qf = v * inv_scale + zero;

int q = round_half_away_from_zero(qf);

if (q > 127) {
q = 127;
}
if (q < -128) {
q = -128;
}

x_packed[tid + ind] = (int8_t)q;
}
}
}

template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8SymKernel(
int8_t *x_packed, float *x_scale, const Tdata *x,
int M, int K) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx * K;

if (otherIdx < M) {

float max_data = -__FLT_MAX__;

// ---- reduce max of |x| ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = fabsf((float)x[tid + ind]);
max_data = fmaxf(max_data, v);
}

// WarpAllReduce leaves the reduced value in every lane, so no shared-memory
// broadcast or __syncthreads() is needed (see the asymmetric kernel above).
max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);

// ---- compute scale in float (matching the Python float32 reference) ----
float max_f = max_data;

float scale = max_f / 127.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}

float inv_scale = 1.0f / scale;

x_scale[otherIdx] = scale;

// ---- quantize in float with half-away-from-zero rounding ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
float qf = v * inv_scale;

int q = round_half_away_from_zero(qf);

if (q > 127) {
q = 127;
}
if (q < -127) {
q = -127;
}

x_packed[tid + ind] = (int8_t)q;
}
}
}

#endif // __PERCHANNEL_QUANTINT8_KERNEL_CUH__
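For orientation, a hedged sketch of how a launcher could wrap these __device__ bodies. The __global__ wrapper names, the block shapes, and the K >= 1024 dispatch threshold are illustrative assumptions, not part of this PR; the real launcher lives elsewhere in the operator implementation.

template <typename Tdata, unsigned int BLOCK_SIZE>
__global__ void perChannelQuantI8BlockWrapper(
    int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
    blockPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x_zero, x, M, K);
}

template <typename Tdata, unsigned int BSX, unsigned int BSY>
__global__ void perChannelQuantI8WarpWrapper(
    int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
    warpPerChannelQuantI8Kernel<Tdata, BSX, BSY>(x_packed, x_scale, x_zero, x, M, K);
}

template <typename Tdata>
void launchPerChannelQuantI8(int8_t *x_packed, float *x_scale, float *x_zero,
                             const Tdata *x, int M, int K, cudaStream_t stream) {
    if (K >= 1024) {
        // one block per row; 256 threads stride across the K elements
        perChannelQuantI8BlockWrapper<Tdata, 256><<<M, 256, 0, stream>>>(
            x_packed, x_scale, x_zero, x, M, K);
    } else {
        // one warp (32 lanes) per row, 8 rows per block
        dim3 block(32, 8);
        unsigned int grid = ((unsigned int)M + 8 - 1) / 8;
        perChannelQuantI8WarpWrapper<Tdata, 32, 8><<<grid, block, 0, stream>>>(
            x_packed, x_scale, x_zero, x, M, K);
    }
}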
59 changes: 59 additions & 0 deletions src/infiniop/ops/quant/per_channel_quant_int8/info.h
@@ -0,0 +1,59 @@
#ifndef __PER_CHANNEL_QUANT_INT8_INFO_H__
#define __PER_CHANNEL_QUANT_INT8_INFO_H__

#include "../../../../utils.h"
#include "../../../operator.h"
#include "../../../tensor.h"

namespace op::per_channel_quant_int8 {

class PerChannelQuantI8Info {
private:
PerChannelQuantI8Info() = default;

public:
infiniDtype_t dtype, packed_type;
size_t M, K;

static utils::Result<PerChannelQuantI8Info> createPerChannelQuantI8Info(
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc) {

CHECK_OR_RETURN(
x_packed_desc != nullptr && x_scale_desc != nullptr && x_desc != nullptr,
INFINI_STATUS_NULL_POINTER);

const infiniDtype_t dtype = x_desc->dtype();
const infiniDtype_t packed_type = x_packed_desc->dtype();

CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_DTYPE(packed_type, INFINI_DTYPE_I8);

CHECK_OR_RETURN(x_desc->ndim() == 2
&& x_packed_desc->ndim() == 2
&& x_scale_desc->ndim() == 2,
INFINI_STATUS_BAD_TENSOR_SHAPE);

size_t M = x_desc->dim(0);
size_t K = x_desc->dim(1);

// All four shape relations must hold; with `||` a single matching
// dimension would let mismatched tensors through.
CHECK_OR_RETURN(M == x_packed_desc->dim(0)
&& K == x_packed_desc->dim(1)
&& M == x_scale_desc->dim(0)
&& 1 == x_scale_desc->dim(1),
INFINI_STATUS_BAD_TENSOR_SHAPE);

return utils::Result<PerChannelQuantI8Info>(PerChannelQuantI8Info{
dtype,
packed_type,
M,
K,
});
}
};

} // namespace op::per_channel_quant_int8

#endif // __PER_CHANNEL_QUANT_INT8_INFO_H__