Commit 7136ada

Add support for quantized models
1 parent ac3fbe4 commit 7136ada

File tree

2 files changed: 179 additions & 6 deletions

ggml.c

Lines changed: 168 additions & 4 deletions
@@ -5830,13 +5830,13 @@ static void ggml_compute_forward_add_f16_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];

-    const size_t nb00 = src0->nb[0];
+    //const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];

     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];

-    const size_t nb0 = dst->nb[0];
+    //const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];

     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -5848,12 +5848,163 @@ static void ggml_compute_forward_add_f16_f32(
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
         for (int i = 0; i < nc; i++) {
             float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-
             dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
         }
     }
 }

+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    //const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    //const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+        }
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        float tmp[ne00];
+        dequantize_row_q(src0_row, tmp, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, tmp, src1_row);
+        // quantize row to dst
+        quantize_row_q(tmp, dst_row, ne00);
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5866,7 +6017,20 @@ static void ggml_compute_forward_add(
             } break;
         case GGML_TYPE_F16:
            {
-                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
         default:
             {
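
With the new GGML_TYPE_Q4_0 / GGML_TYPE_Q4_1 cases in ggml_compute_forward_add, adding an F32 tensor onto a quantized tensor becomes an ordinary graph op: the result keeps src0's quantized type, and each row is dequantized into a float buffer, accumulated with src1, and re-quantized by ggml_compute_forward_add_q_f32. A minimal usage sketch against the same-era ggml API follows; the context size, tensor shapes, and helper function name are illustrative assumptions, not part of this commit.

    #include "ggml.h"

    // Sketch only: add a float32 delta onto a Q4_0 weight in-graph.
    static void add_delta_to_q4_0_sketch(int n_threads) {
        struct ggml_init_params ip;
        ip.mem_size   = 16*1024*1024;   // illustrative scratch size
        ip.mem_buffer = NULL;
        ip.no_alloc   = false;

        struct ggml_context * ctx = ggml_init(ip);

        // ne[0] must be a multiple of the Q4 block size (32), per the assert above
        struct ggml_tensor * w     = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 4096, 64);
        struct ggml_tensor * delta = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  4096, 64);

        // in real use, w->data holds quantized weights and delta->data the f32 update

        // r duplicates w's type (Q4_0); the add is dispatched to
        // ggml_compute_forward_add_q_f32 during graph computation
        struct ggml_tensor * r = ggml_add(ctx, w, delta);

        struct ggml_cgraph gf = ggml_build_forward(r);
        gf.n_threads = n_threads;
        ggml_graph_compute(ctx, &gf);

        ggml_free(ctx);
    }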

llama.cpp

Lines changed: 11 additions & 2 deletions
@@ -1887,14 +1887,23 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
                 return 1;
             }

-            // w = w + BA
+            // w = w + BA*s
             ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA);
-            ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA);
+
+            //if (true) {
+            //    ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, 1.0f);
+            //    BA = ggml_scale(lora_ctx, BA, scale_tensor);
+            //}
+            ggml_tensor * r = ggml_add(lora_ctx, tensor, BA);
+            //r = ggml_cpy(lora_ctx, r, tensor);

             struct ggml_cgraph gf = ggml_build_forward(r);
             gf.n_threads = n_threads;
             ggml_graph_compute(lora_ctx, &gf);

+            // hack until ggml_cpy supports quantized tensors
+            memcpy(tensor->data, r->data, ggml_nbytes(tensor));
+
             // we won't need these tensors again, reset the context to save memory
             ggml_free(lora_ctx);
             lora_ctx = ggml_init(params);
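
On the llama.cpp side, each LoRA delta is applied as w = w + BA*s, where loraA and loraB are the low-rank factors and s is the LoRA scaling (conventionally alpha/rank); this commit applies the delta unscaled and leaves the ggml_scale hook commented out. Because ggml_cpy does not yet handle quantized tensors, the add is no longer done in place: the result r is computed into a fresh tensor and its bytes are memcpy'd back over the original weight. Below is a sketch of how the scaled path could be wired up, reusing the variables from the hunk above; the lora_alpha/lora_r inputs are an assumption of the sketch, not something this commit parses from the LoRA file.

    #include <string.h>

    #include "ggml.h"

    // Sketch only: apply  w <- w + (alpha/r) * B*A  to one model weight.
    static void apply_lora_delta_sketch(
            struct ggml_context * lora_ctx,
            struct ggml_tensor  * tensor,   // model weight, possibly Q4_0/Q4_1
            struct ggml_tensor  * loraA,
            struct ggml_tensor  * loraB,
            float lora_alpha, int lora_r, int n_threads) {
        const float scale = lora_alpha / (float) lora_r;

        struct ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA);

        if (scale != 1.0f) {
            struct ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scale);
            BA = ggml_scale(lora_ctx, BA, scale_tensor);
        }

        // not in place: compute into r, then copy the bytes back,
        // since ggml_cpy cannot write quantized tensors yet
        struct ggml_tensor * r = ggml_add(lora_ctx, tensor, BA);

        struct ggml_cgraph gf = ggml_build_forward(r);
        gf.n_threads = n_threads;
        ggml_graph_compute(lora_ctx, &gf);

        memcpy(tensor->data, r->data, ggml_nbytes(tensor));
    }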
