Skip to content

Commit 6ff96b1

Browse files
small refactor
1 parent 64e770d commit 6ff96b1

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

ggml-cuda.cu

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1415,7 +1415,8 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
14151415
}
14161416

14171417
static __global__ void quantize_q8_1(
1418-
const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, const int ky) {
1418+
const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, const int ky,
1419+
const int row_stride, const int channel_stride) {
14191420

14201421
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
14211422

@@ -1433,7 +1434,7 @@ static __global__ void quantize_q8_1(
14331434
const int ib = i_padded / QK8_1; // block index
14341435
const int iqs = i_padded % QK8_1; // quant index
14351436

1436-
const float xi = ix < kx ? x[channel*ky*kx + iy*kx + ix] : 0.0f;
1437+
const float xi = ix < kx ? x[channel*channel_stride + iy*row_stride + ix] : 0.0f;
14371438
float amax = fabsf(xi);
14381439
float sum = xi;
14391440

@@ -4301,7 +4302,7 @@ static void quantize_row_q8_1_cuda(
43014302
const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
43024303
const dim3 num_blocks(block_num_x, ky, nchannels);
43034304
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
4304-
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded, ky);
4305+
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded, ky, kx, ky*kx);
43054306
}
43064307

43074308
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {

0 commit comments

Comments (0)