@@ -1194,70 +1194,6 @@ kernel void kernel_neg(
1194
1194
dst[tpig] = -src0[tpig];
1195
1195
}
1196
1196
1197
- kernel void kernel_reglu (
1198
- device const char * src0,
1199
- device const char * src1,
1200
- device char * dst,
1201
- constant ggml_metal_kargs_glu & args,
1202
- uint tgpig[[threadgroup_position_in_grid]],
1203
- uint tpitg[[thread_position_in_threadgroup]],
1204
- uint ntg[[threads_per_threadgroup]]) {
1205
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1206
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1207
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1 );
1208
-
1209
- for (int i0 = tpitg; i0 < args.ne0 ; i0 += ntg) {
1210
- const float x0 = src0_row[i0];
1211
- const float x1 = src1_row[i0];
1212
-
1213
- dst_row[i0] = x0*x1*(x0 > 0 .0f );
1214
- }
1215
- }
1216
-
1217
- kernel void kernel_geglu (
1218
- device const char * src0,
1219
- device const char * src1,
1220
- device char * dst,
1221
- constant ggml_metal_kargs_glu & args,
1222
- uint tgpig[[threadgroup_position_in_grid]],
1223
- uint tpitg[[thread_position_in_threadgroup]],
1224
- uint ntg[[threads_per_threadgroup]]) {
1225
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1226
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1227
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1 );
1228
-
1229
- for (int i0 = tpitg; i0 < args.ne0 ; i0 += ntg) {
1230
- const float x0 = src0_row[i0];
1231
- const float x1 = src1_row[i0];
1232
-
1233
- const float gelu = 0 .5f *x0*(1 .0f + precise::tanh (SQRT_2_OVER_PI*x0*(1 .0f + GELU_COEF_A*x0*x0)));
1234
-
1235
- dst_row[i0] = gelu*x1;
1236
- }
1237
- }
1238
-
1239
- kernel void kernel_swiglu (
1240
- device const char * src0,
1241
- device const char * src1,
1242
- device char * dst,
1243
- constant ggml_metal_kargs_glu & args,
1244
- uint tgpig[[threadgroup_position_in_grid]],
1245
- uint tpitg[[thread_position_in_threadgroup]],
1246
- uint ntg[[threads_per_threadgroup]]) {
1247
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1248
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1249
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1 );
1250
-
1251
- for (int i0 = tpitg; i0 < args.ne0 ; i0 += ntg) {
1252
- const float x0 = src0_row[i0];
1253
- const float x1 = src1_row[i0];
1254
-
1255
- const float silu = x0 / (1 .0f + exp (-x0));
1256
-
1257
- dst_row[i0] = silu*x1;
1258
- }
1259
- }
1260
-
1261
1197
template <bool norm>
1262
1198
kernel void kernel_sum_rows (
1263
1199
constant ggml_metal_kargs_sum_rows & args,
@@ -1298,14 +1234,7 @@ kernel void kernel_sum_rows(
1298
1234
shmem_f32[sgitg] = sumf;
1299
1235
}
1300
1236
1301
- threadgroup_barrier (mem_flags::mem_threadgroup);
1302
-
1303
- sumf = shmem_f32[tiisg];
1304
- sumf = simd_sum (sumf);
1305
-
1306
- if (tpitg.x == 0 ) {
1307
- dst_row[0 ] = norm ? sumf / args.ne00 : sumf;
1308
- }
1237
+ dst_row[0 ] = row_sum;
1309
1238
}
1310
1239
1311
1240
typedef decltype (kernel_sum_rows<false >) kernel_sum_rows_t;
@@ -4807,10 +4736,51 @@ kernel void kernel_cpy_f32_q5_1(
4807
4736
for (int64_t i00 = tpitg.x *QK5_1; i00 < args.ne00 ; i00 += ntg.x *QK5_1) {
4808
4737
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00 );
4809
4738
4810
- quantize_q5_1 (src, dst_data[i00/QK5_1]);
4739
+ float max = src[0 ];
4740
+ float min = src[0 ];
4741
+
4742
+ for (int j = 1 ; j < QK5_1; j++) {
4743
+ const float v = src[j];
4744
+ min = v < min ? v : min;
4745
+ max = v > max ? v : max;
4746
+ }
4747
+
4748
+ const float d = (max - min) / 31 ;
4749
+ const float id = d ? 1 .0f /d : 0 .0f ;
4750
+
4751
+ dst_data[i00/QK5_1].d = d;
4752
+ dst_data[i00/QK5_1].m = min;
4753
+
4754
+ uint32_t qh = 0 ;
4755
+ for (int j = 0 ; j < QK5_1/2 ; ++j) {
4756
+ const float x0 = (src[0 + j] - min)*id;
4757
+ const float x1 = (src[QK5_1/2 + j] - min)*id;
4758
+
4759
+ const uint8_t xi0 = (uint8_t )(x0 + 0 .5f );
4760
+ const uint8_t xi1 = (uint8_t )(x1 + 0 .5f );
4761
+
4762
+ dst_data[i00/QK5_1].qs [j] = (xi0 & 0xf ) | ((xi1 & 0xf ) << 4 );
4763
+ qh |= ((xi0 & 0x10u ) >> 4 ) << (j + 0 );
4764
+ qh |= ((xi1 & 0x10u ) >> 4 ) << (j + QK5_1/2 );
4765
+ }
4766
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4767
+ for (int j = 0 ; j < 4 ; ++j) {
4768
+ dst_data[i00/QK5_1].qh [j] = qh8[j];
4769
+ }
4811
4770
}
4812
4771
}
4813
4772
4773
+ static inline int best_index_int8 (int n, constant float * val, float x) {
4774
+ if (x <= val[0 ]) return 0 ;
4775
+ if (x >= val[n-1 ]) return n-1 ;
4776
+ int ml = 0 , mu = n-1 ;
4777
+ while (mu-ml > 1 ) {
4778
+ int mav = (ml+mu)/2 ;
4779
+ if (x < val[mav]) mu = mav; else ml = mav;
4780
+ }
4781
+ return x - val[mu-1 ] < val[mu] - x ? mu-1 : mu;
4782
+ }
4783
+
4814
4784
kernel void kernel_cpy_f32_iq4_nl (
4815
4785
constant ggml_metal_kargs_cpy & args,
4816
4786
device const char * src0,
0 commit comments