@@ -1415,24 +1415,25 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
1415
1415
}
1416
1416
1417
1417
static __global__ void quantize_q8_1 (
1418
- const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, const int nchannels ) {
1418
+ const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, const int ky ) {
1419
1419
1420
1420
const int ix = blockDim .x *blockIdx .x + threadIdx .x ;
1421
1421
1422
1422
if (ix >= kx_padded) {
1423
1423
return ;
1424
1424
}
1425
1425
1426
- const int iy = blockDim .y *blockIdx .y + threadIdx .y ;
1426
+ const int iy = blockDim .y *blockIdx .y + threadIdx .y ;
1427
+ const int channel = blockDim .z *blockIdx .z + threadIdx .z ;
1427
1428
1428
- const int i_padded = iy*kx_padded + ix;
1429
+ const int i_padded = channel*ky*kx_padded + iy*kx_padded + ix;
1429
1430
1430
1431
block_q8_1 * y = (block_q8_1 *) vy;
1431
1432
1432
1433
const int ib = i_padded / QK8_1; // block index
1433
1434
const int iqs = i_padded % QK8_1; // quant index
1434
1435
1435
- const float xi = ix < kx ? x[iy*kx + ix] : 0 .0f ;
1436
+ const float xi = ix < kx ? x[channel*ky*kx + iy*kx + ix] : 0 .0f ;
1436
1437
float amax = fabsf (xi);
1437
1438
float sum = xi;
1438
1439
@@ -4298,9 +4299,9 @@ static void quantize_row_q8_1_cuda(
4298
4299
const float * x, void * vy, const int kx, const int ky, const int kx_padded, const int nchannels, cudaStream_t stream) {
4299
4300
4300
4301
const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1 ) / CUDA_QUANTIZE_BLOCK_SIZE;
4301
- const dim3 num_blocks (block_num_x, ky*nchannels, 1 );
4302
+ const dim3 num_blocks (block_num_x, ky, nchannels );
4302
4303
const dim3 block_size (CUDA_DEQUANTIZE_BLOCK_SIZE, 1 , 1 );
4303
- quantize_q8_1<<<num_blocks, block_size, 0 , stream>>> (x, vy, kx, kx_padded, nchannels );
4304
+ quantize_q8_1<<<num_blocks, block_size, 0 , stream>>> (x, vy, kx, kx_padded, ky );
4304
4305
}
4305
4306
4306
4307
static void dequantize_row_q4_0_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
0 commit comments