@@ -41,7 +41,6 @@ __device__ void blockPerChannelQuantI8Kernel(
4141 }
4242 __syncthreads ();
4343
44- // ---- 3. 使用 float(匹配 python)计算 scale/zero ----
4544 float global_max = global_max_f;
4645 float global_min = global_min_f;
4746
@@ -53,11 +52,9 @@ __device__ void blockPerChannelQuantI8Kernel(
5352 float inv_scale = 1 .0f / scale;
5453 float zero = -global_min * inv_scale - 128 .0f ;
5554
56- // 写回 scale, zero
5755 x_scale[row] = (Tdata)scale;
5856 x_zero[row] = (Tdata)zero;
5957
60- // ---- 4. 使用 float + half-away-from-zero(与 Python 完全一致)----
6158 for (int ind = threadIdx .x ; ind < K; ind += BLOCK_SIZE) {
6259
6360 float v = (float )x[tid + ind];
@@ -99,7 +96,6 @@ __device__ void blockPerChannelQuantI8SymKernel(
9996 }
10097 __syncthreads ();
10198
102- // ---- 3. 使用 float(匹配 python)计算 scale/zero ----
10399 float global_max = global_max_f;
104100
105101 float scale = global_max / 127 .0f ;
@@ -109,10 +105,8 @@ __device__ void blockPerChannelQuantI8SymKernel(
109105
110106 float inv_scale = 1 .0f / scale;
111107
112- // 写回 scale, zero
113108 x_scale[row] = (Tdata)scale;
114109
115- // ---- 4. 使用 float + half-away-from-zero(与 Python 完全一致)----
116110 for (int ind = threadIdx .x ; ind < K; ind += BLOCK_SIZE) {
117111
118112 float v = (float )x[tid + ind];
@@ -183,7 +177,6 @@ __device__ void warpPerChannelQuantI8Kernel(
183177 }
184178 __syncthreads ();
185179
186- // ---- float scale/zero(与 Python float32 匹配)----
187180 float max_f = max_total[threadIdx .y ];
188181 float min_f = min_total[threadIdx .y ];
189182
@@ -198,7 +191,6 @@ __device__ void warpPerChannelQuantI8Kernel(
198191 x_scale[otherIdx] = scale;
199192 x_zero[otherIdx] = zero;
200193
201- // ---- float + half-away-from-zero 量化 ----
202194 for (int ind = threadIdx .x ; ind < K; ind += BLOCK_SIZE_x) {
203195 float v = (float )x[tid + ind];
204196 float qf = v * inv_scale + zero;
@@ -243,7 +235,6 @@ __device__ void warpPerChannelQuantI8SymKernel(
243235 }
244236 __syncthreads ();
245237
246- // ---- float scale/zero(与 Python float32 匹配)----
247238 float max_f = max_total[threadIdx .y ];
248239
249240 float scale = max_f / 127 .0f ;
@@ -255,7 +246,6 @@ __device__ void warpPerChannelQuantI8SymKernel(
255246
256247 x_scale[otherIdx] = scale;
257248
258- // ---- float + half-away-from-zero 量化 ----
259249 for (int ind = threadIdx .x ; ind < K; ind += BLOCK_SIZE_x) {
260250 float v = (float )x[tid + ind];
261251 float qf = v * inv_scale;
0 commit comments