Skip to content

Commit 64e770d

Browse files
small refactor
1 parent ac857c7 commit 64e770d

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

ggml-cuda.cu

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1415,24 +1415,25 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
14151415
}
14161416

14171417
static __global__ void quantize_q8_1(
1418-
const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, const int nchannels) {
1418+
const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, const int ky) {
14191419

14201420
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
14211421

14221422
if (ix >= kx_padded) {
14231423
return;
14241424
}
14251425

1426-
const int iy = blockDim.y*blockIdx.y + threadIdx.y;
1426+
const int iy = blockDim.y*blockIdx.y + threadIdx.y;
1427+
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
14271428

1428-
const int i_padded = iy*kx_padded + ix;
1429+
const int i_padded = channel*ky*kx_padded + iy*kx_padded + ix;
14291430

14301431
block_q8_1 * y = (block_q8_1 *) vy;
14311432

14321433
const int ib = i_padded / QK8_1; // block index
14331434
const int iqs = i_padded % QK8_1; // quant index
14341435

1435-
const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
1436+
const float xi = ix < kx ? x[channel*ky*kx + iy*kx + ix] : 0.0f;
14361437
float amax = fabsf(xi);
14371438
float sum = xi;
14381439

@@ -4298,9 +4299,9 @@ static void quantize_row_q8_1_cuda(
42984299
const float * x, void * vy, const int kx, const int ky, const int kx_padded, const int nchannels, cudaStream_t stream) {
42994300

43004301
const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
4301-
const dim3 num_blocks(block_num_x, ky*nchannels, 1);
4302+
const dim3 num_blocks(block_num_x, ky, nchannels);
43024303
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
4303-
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded, nchannels);
4304+
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded, ky);
43044305
}
43054306

43064307
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {

0 commit comments

Comments
 (0)