@@ -5830,13 +5830,13 @@ static void ggml_compute_forward_add_f16_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];

-    const size_t nb00 = src0->nb[0];
+    //const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];

     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];

-    const size_t nb0 = dst->nb[0];
+    //const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];

     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -5848,12 +5848,163 @@ static void ggml_compute_forward_add_f16_f32(
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
         for (int i = 0; i < nc; i++) {
             float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-
             dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
         }
     }
 }

+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    //const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    //const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+        }
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(dst->type  == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        float tmp[ne00];
+        dequantize_row_q(src0_row, tmp, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, tmp, src1_row);
+        // quantize row to dst
+        quantize_row_q(tmp, dst_row, ne00);
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5866,7 +6017,20 @@ static void ggml_compute_forward_add(
             } break;
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
         default:
             {