Skip to content

Commit bc06756

Browse files
ggerganov authored and Minh141120 committed
metal : add mean kernel (ggml-org#14267)
* metal : add mean kernel ggml-ci
* cont : dedup implementation ggml-ci
1 parent ae9ee21 commit bc06756

File tree

2 files changed

+44
-74
lines changed

2 files changed

+44
-74
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -531,6 +531,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
531531
GGML_METAL_KERNEL_TYPE_SWIGLU,
532532
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
533533
GGML_METAL_KERNEL_TYPE_MEAN,
534+
GGML_METAL_KERNEL_TYPE_MEAN,
534535
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
535536
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
536537
GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -2570,7 +2571,6 @@ static bool ggml_metal_encode_node(
25702571
nth *= 2;
25712572
}
25722573

2573-
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
25742574
nth = MIN(nth, ne00);
25752575

25762576
ggml_metal_kargs_sum_rows args = {

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 43 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -1194,70 +1194,6 @@ kernel void kernel_neg(
11941194
dst[tpig] = -src0[tpig];
11951195
}
11961196

1197-
kernel void kernel_reglu(
1198-
device const char * src0,
1199-
device const char * src1,
1200-
device char * dst,
1201-
constant ggml_metal_kargs_glu & args,
1202-
uint tgpig[[threadgroup_position_in_grid]],
1203-
uint tpitg[[thread_position_in_threadgroup]],
1204-
uint ntg[[threads_per_threadgroup]]) {
1205-
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1206-
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1207-
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1208-
1209-
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1210-
const float x0 = src0_row[i0];
1211-
const float x1 = src1_row[i0];
1212-
1213-
dst_row[i0] = x0*x1*(x0 > 0.0f);
1214-
}
1215-
}
1216-
1217-
kernel void kernel_geglu(
1218-
device const char * src0,
1219-
device const char * src1,
1220-
device char * dst,
1221-
constant ggml_metal_kargs_glu & args,
1222-
uint tgpig[[threadgroup_position_in_grid]],
1223-
uint tpitg[[thread_position_in_threadgroup]],
1224-
uint ntg[[threads_per_threadgroup]]) {
1225-
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1226-
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1227-
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1228-
1229-
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1230-
const float x0 = src0_row[i0];
1231-
const float x1 = src1_row[i0];
1232-
1233-
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
1234-
1235-
dst_row[i0] = gelu*x1;
1236-
}
1237-
}
1238-
1239-
kernel void kernel_swiglu(
1240-
device const char * src0,
1241-
device const char * src1,
1242-
device char * dst,
1243-
constant ggml_metal_kargs_glu & args,
1244-
uint tgpig[[threadgroup_position_in_grid]],
1245-
uint tpitg[[thread_position_in_threadgroup]],
1246-
uint ntg[[threads_per_threadgroup]]) {
1247-
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1248-
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1249-
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1250-
1251-
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1252-
const float x0 = src0_row[i0];
1253-
const float x1 = src1_row[i0];
1254-
1255-
const float silu = x0 / (1.0f + exp(-x0));
1256-
1257-
dst_row[i0] = silu*x1;
1258-
}
1259-
}
1260-
12611197
template <bool norm>
12621198
kernel void kernel_sum_rows(
12631199
constant ggml_metal_kargs_sum_rows & args,
@@ -1298,14 +1234,7 @@ kernel void kernel_sum_rows(
12981234
shmem_f32[sgitg] = sumf;
12991235
}
13001236

1301-
threadgroup_barrier(mem_flags::mem_threadgroup);
1302-
1303-
sumf = shmem_f32[tiisg];
1304-
sumf = simd_sum(sumf);
1305-
1306-
if (tpitg.x == 0) {
1307-
dst_row[0] = norm ? sumf / args.ne00 : sumf;
1308-
}
1237+
dst_row[0] = row_sum;
13091238
}
13101239

13111240
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
@@ -4807,10 +4736,51 @@ kernel void kernel_cpy_f32_q5_1(
48074736
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
48084737
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
48094738

4810-
quantize_q5_1(src, dst_data[i00/QK5_1]);
4739+
float max = src[0];
4740+
float min = src[0];
4741+
4742+
for (int j = 1; j < QK5_1; j++) {
4743+
const float v = src[j];
4744+
min = v < min ? v : min;
4745+
max = v > max ? v : max;
4746+
}
4747+
4748+
const float d = (max - min) / 31;
4749+
const float id = d ? 1.0f/d : 0.0f;
4750+
4751+
dst_data[i00/QK5_1].d = d;
4752+
dst_data[i00/QK5_1].m = min;
4753+
4754+
uint32_t qh = 0;
4755+
for (int j = 0; j < QK5_1/2; ++j) {
4756+
const float x0 = (src[0 + j] - min)*id;
4757+
const float x1 = (src[QK5_1/2 + j] - min)*id;
4758+
4759+
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
4760+
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
4761+
4762+
dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
4763+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
4764+
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
4765+
}
4766+
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4767+
for (int j = 0; j < 4; ++j) {
4768+
dst_data[i00/QK5_1].qh[j] = qh8[j];
4769+
}
48114770
}
48124771
}
48134772

4773+
static inline int best_index_int8(int n, constant float * val, float x) {
4774+
if (x <= val[0]) return 0;
4775+
if (x >= val[n-1]) return n-1;
4776+
int ml = 0, mu = n-1;
4777+
while (mu-ml > 1) {
4778+
int mav = (ml+mu)/2;
4779+
if (x < val[mav]) mu = mav; else ml = mav;
4780+
}
4781+
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
4782+
}
4783+
48144784
kernel void kernel_cpy_f32_iq4_nl(
48154785
constant ggml_metal_kargs_cpy & args,
48164786
device const char * src0,

0 commit comments

Comments
 (0)