@@ -2318,6 +2318,28 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     *s = sumf;
 }
 
+// TODO: move this to a more sensible place
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q         = dequantize_row_q4_0,
+        .quantize_row_q           = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .vec_dot_q                = ggml_vec_dot_q4_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q         = dequantize_row_q4_1,
+        .quantize_row_q           = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .vec_dot_q                = ggml_vec_dot_q4_1,
+    },
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
 // compute GGML_VEC_DOT_UNROLL dot products at once
 // xs - x row stride in bytes
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
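Not part of the diff: a minimal sketch of how the dispatch table above can be exercised through ggml_internal_get_quantize_fn, in the spirit of the quantization tests. It assumes ggml.h declares quantize_fns_t and the getter; the buffer sizes and sample values are illustrative only.

#include <stdint.h>
#include <stdio.h>
#include "ggml.h"

int main(void) {
    const int k = 32;                  // one Q4_0 block worth of values
    float   src[32];
    float   back[32];
    uint8_t quantized[64];             // comfortably larger than one Q4_0 block (20 bytes)

    for (int i = 0; i < k; ++i) {
        src[i] = 0.1f*(float) i;
    }

    // fetch the per-type function table and round-trip one row through Q4_0
    quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
    fns.quantize_row_q  (src, quantized, k);
    fns.dequantize_row_q(quantized, back, k);

    printf("src[5] = %.3f, after Q4_0 round-trip = %.3f\n", src[5], back[5]);
    return 0;
}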
@@ -5315,13 +5337,13 @@ static void ggml_compute_forward_add_f16_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    const size_t nb00 = src0->nb[0];
+    //const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
 
-    const size_t nb0 = dst->nb[0];
+    //const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -5333,12 +5355,163 @@ static void ggml_compute_forward_add_f16_f32(
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
         for (int i = 0; i < nc; i++) {
             float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-
             dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
         }
     }
 }
 
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    //const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    //const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+        }
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        float tmp[ne00];
+        dequantize_row_q(src0_row, tmp, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, tmp, src1_row);
+        // quantize row to dst
+        quantize_row_q(tmp, dst_row, ne00);
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5351,10 +5524,21 @@ static void ggml_compute_forward_add(
             } break;
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
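Not part of the diff: a sketch of what the new quantized add path enables at the graph level, i.e. accumulating an f32 delta into a Q4_0 tensor without dequantizing it by hand, which is the pattern LoRA-style weight patching relies on. It assumes the public ggml API of this revision (ggml_init, ggml_new_tensor_2d, ggml_add, ggml_build_forward, ggml_graph_compute) and that ggml.h exposes ggml_internal_get_quantize_fn; the 64x64 size and the values are illustrative only.

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = { .mem_size = 16*1024*1024, .mem_buffer = NULL };
    struct ggml_context * ctx = ggml_init(ip);

    const int n = 64;   // row length must be a multiple of the Q4 block size (32)

    // quantized "weight" and an f32 "delta" of the same (square) shape
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, n, n);
    struct ggml_tensor * d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  n, n);

    static float w_f32[64*64];
    for (int i = 0; i < n*n; ++i) {
        w_f32[i] = 0.01f*(float) (i % n);
        ((float *) d->data)[i] = 1.0f;
    }

    quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
    fns.quantize_row_q(w_f32, w->data, n*n);   // fill the quantized tensor

    // w + d dispatches to ggml_compute_forward_add_q_f32; the result stays Q4_0
    struct ggml_tensor * sum = ggml_add(ctx, w, d);

    struct ggml_cgraph gf = ggml_build_forward(sum);
    gf.n_threads = 1;
    ggml_graph_compute(ctx, &gf);

    float out[64];
    fns.dequantize_row_q(sum->data, out, n);   // inspect the first row of the result
    printf("w[0] = %.3f, (w + 1)[0] ~= %.3f\n", w_f32[0], out[0]);

    ggml_free(ctx);
    return 0;
}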
@@ -6739,27 +6923,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q         = dequantize_row_q4_0,
-        .quantize_row_q           = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .vec_dot_q                = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q         = dequantize_row_q4_1,
-        .quantize_row_q           = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .vec_dot_q                = ggml_vec_dot_q4_1,
-    },
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,