Skip to content

Commit ac857c7

Browse files
small refactor
1 parent 45e482d commit ac857c7

File tree

1 file changed: ggml-cuda.cu (+10 additions, -6 deletions)

ggml-cuda.cu

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,7 +1414,9 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
-static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
+static __global__ void quantize_q8_1(
+    const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, const int nchannels) {
+
     const int ix = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (ix >= kx_padded) {
@@ -4292,11 +4294,13 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
     rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
 }
 
-static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
+static void quantize_row_q8_1_cuda(
+    const float * x, void * vy, const int kx, const int ky, const int kx_padded, const int nchannels, cudaStream_t stream) {
+
     const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    const dim3 num_blocks(block_num_x, ky, 1);
+    const dim3 num_blocks(block_num_x, ky*nchannels, 1);
     const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
-    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded, nchannels);
 }
 
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -5552,7 +5556,7 @@ inline void ggml_cuda_op_mul_mat_q(
         ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
     size_t as;
     void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*nchannels*sizeof(block_q8_1)/QK8_1, &as);
-    quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11*nchannels, padded_row_size, cudaStream_main);
+    quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, nchannels, cudaStream_main);
 
     // const int row_stride = nb01 / ggml_type_size(src0->type);
     const int row_stride = src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU &&
@@ -5706,7 +5710,7 @@ inline void ggml_cuda_op_mul_mat_vec(
         ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
     size_t as;
     void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne02*sizeof(block_q8_1)/QK8_1, &as);
-    quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne02, padded_row_size, cudaStream_main);
+    quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, 1, padded_row_size, ne02, cudaStream_main);
 
     const int row_delta = nb01 / ggml_type_size(src0->type);
     const int channel_delta = nb02 / ggml_type_size(src0->type);

0 commit comments

Comments
 (0)