Commit 1b51c73

Add support for quantized models
1 parent c47c88f

2 files changed: 199 additions & 27 deletions


ggml.c

Lines changed: 188 additions & 25 deletions
@@ -2318,6 +2318,28 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     *s = sumf;
 }
 
+// TODO: move this to a more sensible place
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q = dequantize_row_q4_0,
+        .quantize_row_q = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .vec_dot_q = ggml_vec_dot_q4_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q = dequantize_row_q4_1,
+        .quantize_row_q = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .vec_dot_q = ggml_vec_dot_q4_1,
+    },
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
 // compute GGML_VEC_DOT_UNROLL dot products at once
 // xs - x row stride in bytes
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
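
The table above gives every quantized type a uniform row-level interface (quantize, dequantize, reference quantize, dot product), so kernels can dispatch on a tensor's type instead of branching on Q4_0 vs Q4_1. A minimal round-trip sketch of how a test might use the accessor, assuming quantize_fns_t and ggml_internal_get_quantize_fn are exposed through ggml.h as the "For internal test use" comment suggests; the buffer size and sinf test data are illustrative, not from this commit:

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include "ggml.h"

// Quantize one row to Q4_0 and back, then report the worst-case error.
static void q4_0_round_trip(void) {
    const int n = 64;         // row length; must be a multiple of the 32-wide block
    float   src[64], dst[64];
    uint8_t buf[64];          // 2 Q4_0 blocks need ~40 bytes; 64 leaves headroom

    for (int i = 0; i < n; i++) {
        src[i] = sinf(0.1f*i);                  // arbitrary smooth test data
    }

    quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
    fns.quantize_row_q(src, buf, n);            // f32 row -> 4-bit blocks
    fns.dequantize_row_q(buf, dst, n);          // 4-bit blocks -> f32 row

    float max_err = 0.0f;
    for (int i = 0; i < n; i++) {
        float err = fabsf(src[i] - dst[i]);
        if (err > max_err) max_err = err;
    }
    printf("Q4_0 round-trip max error: %f\n", max_err);
}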
@@ -5315,13 +5337,13 @@ static void ggml_compute_forward_add_f16_f32(
     const int n = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    const size_t nb00 = src0->nb[0];
+    //const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
 
-    const size_t nb0 = dst->nb[0];
+    //const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -5333,12 +5355,163 @@ static void ggml_compute_forward_add_f16_f32(
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
         for (int i = 0; i < nc; i++) {
             float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-
             dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
         }
     }
 }
 
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    //const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    //const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+        }
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        // dequantize row from src0 to temp buffer
+        float tmp[ne00];
+        dequantize_row_q(src0_row, tmp, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, tmp, src1_row);
+        // quantize row to dst
+        quantize_row_q(tmp, dst_row, ne00);
+    }
+}
+
 static void ggml_compute_forward_add(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
@@ -5351,10 +5524,21 @@ static void ggml_compute_forward_add(
             } break;
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
             } break;
         case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
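
ggml_compute_forward_add_q_f32 above is the heart of the change: rows are split across threads, and each row is processed as dequantize, accumulate, requantize, which is what lets ggml_add accept Q4_0/Q4_1 for src0 (the result is requantized, so the add is lossy by one quantization step). A condensed sketch of that per-row pattern, meant to sit alongside the kernel inside ggml.c where the dequantize_row_q_t/quantize_row_q_t typedefs live; the helper name add_row_q_f32 is illustrative, and a plain loop stands in for ggml_vec_acc_f32:

// One row of "quantized dst = quantized src0 + f32 src1" -- the same three
// steps the kernel runs inside its row loop. ne00 is the row width and must
// be a multiple of the 32-element block size.
static void add_row_q_f32(dequantize_row_q_t dequantize_row_q,
                          quantize_row_q_t   quantize_row_q,
                          const void * src0_row, const float * src1_row,
                          void * dst_row, int ne00) {
    float tmp[ne00];                        // scratch row, like the VLA in the kernel

    dequantize_row_q(src0_row, tmp, ne00);  // 1. expand the quantized row to f32
    for (int i = 0; i < ne00; i++) {
        tmp[i] += src1_row[i];              // 2. accumulate the f32 addend
    }
    quantize_row_q(tmp, dst_row, ne00);     // 3. requantize the sum into dst
}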
@@ -6739,27 +6923,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q = dequantize_row_q4_0,
-        .quantize_row_q = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .vec_dot_q = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q = dequantize_row_q4_1,
-        .quantize_row_q = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .vec_dot_q = ggml_vec_dot_q4_1,
-    },
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,

llama.cpp

Lines changed: 11 additions & 2 deletions
@@ -1895,14 +1895,23 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
             return 1;
         }
 
-        // w = w + BA
+        // w = w + BA*s
         ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA);
-        ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA);
+
+        //if (true) {
+        //    ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, 1.0f);
+        //    BA = ggml_scale(lora_ctx, BA, scale_tensor);
+        //}
+        ggml_tensor * r = ggml_add(lora_ctx, tensor, BA);
+        //r = ggml_cpy(lora_ctx, r, tensor);
 
         struct ggml_cgraph gf = ggml_build_forward(r);
         gf.n_threads = n_threads;
        ggml_graph_compute(lora_ctx, &gf);
 
+        // hack until ggml_cpy supports quantized tensors
+        memcpy(tensor->data, r->data, ggml_nbytes(tensor));
+
        // we won't need these tensors again, reset the context to save memory
        ggml_free(lora_ctx);
        lora_ctx = ggml_init(params);
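
The llama.cpp side uses the new kernel to merge a LoRA delta directly into quantized weights: the comment becomes w = w + BA*s, ggml_add_inplace becomes a plain ggml_add (in-place writes into quantized storage are not supported), and, because ggml_cpy cannot yet copy into quantized tensors, the result r is copied back over the original weight with a raw memcpy; ggml_nbytes(tensor) sizes the copy correctly since r and tensor share type and shape. The commented-out scaffolding marks where the scale s would be applied once it is read from the adapter. A hedged sketch of that step when enabled, where lora_alpha and lora_r are illustrative names for adapter metadata, not part of this commit, and s = alpha/r is the usual LoRA convention:

// Hypothetical: scale the delta by s = alpha/r before merging.
// lora_alpha and lora_r stand in for values parsed from the adapter file.
const float s = (float) lora_alpha / (float) lora_r;
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, s);     // scalar f32 tensor
BA = ggml_scale(lora_ctx, BA, scale_tensor);                // BA*s, as in the comment
ggml_tensor * r = ggml_add(lora_ctx, tensor, BA);           // w + BA*s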
