Skip to content

Commit 0a10e11

Browse files
committed
issue/843: modified quant
1 parent ff8af1a commit 0a10e11

File tree

3 files changed

+0
-362
lines changed

3 files changed

+0
-362
lines changed

src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ __device__ void blockPerChannelQuantI8Kernel(
4141
}
4242
__syncthreads();
4343

44-
// ---- 3. 使用 float(匹配 python)计算 scale/zero ----
4544
float global_max = global_max_f;
4645
float global_min = global_min_f;
4746

@@ -53,11 +52,9 @@ __device__ void blockPerChannelQuantI8Kernel(
5352
float inv_scale = 1.0f / scale;
5453
float zero = -global_min * inv_scale - 128.0f;
5554

56-
// 写回 scale, zero
5755
x_scale[row] = (Tdata)scale;
5856
x_zero[row] = (Tdata)zero;
5957

60-
// ---- 4. 使用 float + half-away-from-zero(与 Python 完全一致)----
6158
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
6259

6360
float v = (float)x[tid + ind];
@@ -99,7 +96,6 @@ __device__ void blockPerChannelQuantI8SymKernel(
9996
}
10097
__syncthreads();
10198

102-
// ---- 3. 使用 float(匹配 python)计算 scale/zero ----
10399
float global_max = global_max_f;
104100

105101
float scale = global_max / 127.0f;
@@ -109,10 +105,8 @@ __device__ void blockPerChannelQuantI8SymKernel(
109105

110106
float inv_scale = 1.0f / scale;
111107

112-
// 写回 scale, zero
113108
x_scale[row] = (Tdata)scale;
114109

115-
// ---- 4. 使用 float + half-away-from-zero(与 Python 完全一致)----
116110
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
117111

118112
float v = (float)x[tid + ind];
@@ -183,7 +177,6 @@ __device__ void warpPerChannelQuantI8Kernel(
183177
}
184178
__syncthreads();
185179

186-
// ---- float scale/zero(与 Python float32 匹配)----
187180
float max_f = max_total[threadIdx.y];
188181
float min_f = min_total[threadIdx.y];
189182

@@ -198,7 +191,6 @@ __device__ void warpPerChannelQuantI8Kernel(
198191
x_scale[otherIdx] = scale;
199192
x_zero[otherIdx] = zero;
200193

201-
// ---- float + half-away-from-zero 量化 ----
202194
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
203195
float v = (float)x[tid + ind];
204196
float qf = v * inv_scale + zero;
@@ -243,7 +235,6 @@ __device__ void warpPerChannelQuantI8SymKernel(
243235
}
244236
__syncthreads();
245237

246-
// ---- float scale/zero(与 Python float32 匹配)----
247238
float max_f = max_total[threadIdx.y];
248239

249240
float scale = max_f / 127.0f;
@@ -255,7 +246,6 @@ __device__ void warpPerChannelQuantI8SymKernel(
255246

256247
x_scale[otherIdx] = scale;
257248

258-
// ---- float + half-away-from-zero 量化 ----
259249
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
260250
float v = (float)x[tid + ind];
261251
float qf = v * inv_scale;

test/infiniop/per_channel_quant_int8.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,20 +65,15 @@ def per_token_quant_int8_torch(x, symmetric):
6565
w_min = w.min(dim=-1, keepdim=True)[0]
6666
w_max = w.max(dim=-1, keepdim=True)[0]
6767

68-
# 避免除以零
6968
w_scale = (w_max - w_min) / 255.0
7069
w_scale = torch.clamp(w_scale, min=1e-8)
7170

72-
# 计算zero point
7371
w_zero = -w_min / w_scale - 128.0
7472

75-
# 计算量化值
7673
w_q = torch.round(w / w_scale + w_zero)
7774

78-
# 限制范围[-128, 127]
7975
w_q = torch.clamp(w_q, -128, 127)
8076

81-
# 转为int8
8277
w_packed = w_q.to(torch.int8)
8378

8479
return w_packed, w_scale, w_zero

0 commit comments

Comments
 (0)