Skip to content

Commit d9ddeb9

Browse files
ggerganov authored and qnixsynapse committed
metal : add glu kernels
ggml-ci
1 parent a341aa3 commit d9ddeb9

File tree

3 files changed

+116
-0
lines changed

3 files changed

+116
-0
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,12 @@ typedef struct {
422422
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
423423
} ggml_metal_kargs_im2col;
424424

425+
// Argument block handed to the GLU kernels (kernel_reglu / kernel_geglu /
// kernel_swiglu), which receive it as `constant ggml_metal_kargs_glu & args`.
typedef struct{
426+
// The kernels iterate over ne00/2 elements per row and read the second
// operand at offset ne00/2 within the same row.
int32_t ne00;
427+
// Byte stride the kernels multiply by the threadgroup index to find the source row.
uint64_t nb01;
428+
// Byte stride the kernels multiply by the threadgroup index to find the destination row.
uint64_t nb1;
429+
} ggml_metal_kargs_glu;
430+
425431
typedef struct {
426432
int64_t ne00;
427433
int64_t ne01;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,9 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
514514
GGML_METAL_KERNEL_TYPE_SIN,
515515
GGML_METAL_KERNEL_TYPE_COS,
516516
GGML_METAL_KERNEL_TYPE_NEG,
517+
GGML_METAL_KERNEL_TYPE_REGLU,
518+
GGML_METAL_KERNEL_TYPE_GEGLU,
519+
GGML_METAL_KERNEL_TYPE_SWIGLU,
517520
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
518521
GGML_METAL_KERNEL_TYPE_MEAN,
519522
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1478,6 +1481,9 @@ @implementation GGMLMetalClass
14781481
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
14791482
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
14801483
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
1484+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
1485+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
1486+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
14811487
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
14821488
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
14831489
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@@ -1652,6 +1658,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16521658
default:
16531659
return false;
16541660
}
1661+
// GLU ops: only the REGLU / GEGLU / SWIGLU variants are reported as supported,
// and only when src[0] passes ggml_is_contiguous_1 and has type F32.
case GGML_OP_GLU:
1662+
switch (ggml_get_glu_op(op)) {
1663+
case GGML_GLU_OP_REGLU:
1664+
case GGML_GLU_OP_GEGLU:
1665+
case GGML_GLU_OP_SWIGLU:
1666+
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1667+
// Any other GLU variant is rejected.
default:
1668+
return false;
1669+
}
16551670
case GGML_OP_NONE:
16561671
case GGML_OP_RESHAPE:
16571672
case GGML_OP_VIEW:
@@ -2370,6 +2385,43 @@ static bool ggml_metal_encode_node(
23702385
GGML_ABORT("fatal error");
23712386
}
23722387
} break;
2388+
// Encode a GLU node: pick the pipeline matching the GLU variant, bind
// src0 / dst / kernel arguments, and dispatch one threadgroup per row of src0.
case GGML_OP_GLU:
2389+
{
2390+
GGML_ASSERT(ggml_is_contiguous_1(src0));
2391+
2392+
id<MTLComputePipelineState> pipeline = nil;
2393+
2394+
// Map the GLU variant to its compiled kernel; unknown variants abort.
switch (ggml_get_glu_op(node)) {
2395+
case GGML_GLU_OP_REGLU:
2396+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
2397+
break;
2398+
case GGML_GLU_OP_GEGLU:
2399+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
2400+
break;
2401+
case GGML_GLU_OP_SWIGLU:
2402+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
2403+
break;
2404+
default:
2405+
GGML_ABORT("fatal error");
2406+
}
2407+
2408+
// ne00, nb01 and nb1 are defined outside this case (not visible here).
ggml_metal_kargs_glu args = {
2409+
/*.ne00 =*/ ne00,
2410+
/*.nb01 =*/ nb01,
2411+
/*.nb1 =*/ nb1,
2412+
};
2413+
2414+
// Buffer indices: 0 = src0, 1 = dst, 2 = args (passed inline via setBytes).
[encoder setComputePipelineState:pipeline];
2415+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2416+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2417+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
2418+
2419+
const int64_t nrows = ggml_nrows(src0);
2420+
2421+
// Threads per threadgroup: capped by the pipeline limit and by ne00/2.
// NOTE(review): nth would be 0 if ne00/2 were 0 — presumably excluded before
// this point; confirm against the caller.
const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
2422+
2423+
// One threadgroup per source row, nth threads each.
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2424+
} break;
23732425
case GGML_OP_SQR:
23742426
{
23752427
GGML_ASSERT(ggml_is_contiguous(src0));

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,64 @@ kernel void kernel_neg(
993993
dst[tpig] = -src0[tpig];
994994
}
995995

996+
// ReGLU: for the row selected by the threadgroup index, reads x0 from the
// first ne00/2 floats and x1 from the following ne00/2 floats, and writes
// x0*x1*(x0 > 0) to the destination row.
//
// src0  - source buffer, addressed in bytes (row offset = tgpig*args.nb01)
// dst   - destination buffer, addressed in bytes (row offset = tgpig*args.nb1)
// args  - ne00 / nb01 / nb1, see ggml_metal_kargs_glu
// tgpig - threadgroup position in grid, used as the row index
// tpitg - thread position in threadgroup, used as the starting element
// ntg   - threads per threadgroup, used as the loop step
kernel void kernel_reglu(
997+
device const char * src0,
998+
device char * dst,
999+
constant ggml_metal_kargs_glu & args,
1000+
uint tgpig[[threadgroup_position_in_grid]],
1001+
uint tpitg[[thread_position_in_threadgroup]],
1002+
uint ntg[[threads_per_threadgroup]]) {
1003+
// Locate this threadgroup's source and destination rows via byte strides.
device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
1004+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1005+
1006+
// Each thread handles elements tpitg, tpitg + ntg, ... below ne00/2.
for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
1007+
const float x0 = src_row[i00];
1008+
const float x1 = src_row[i00 + args.ne00/2];
1009+
1010+
// The comparison contributes a factor of 1 when x0 > 0 and 0 otherwise.
dst_row[i00] = x0*x1*(x0 > 0.0f);
1011+
}
1012+
}
1013+
1014+
// GeGLU: for the row selected by the threadgroup index, reads x0 from the
// first ne00/2 floats and x1 from the following ne00/2 floats, applies a
// tanh-based GELU expression to x0 and writes gelu(x0)*x1 to the destination row.
//
// src0  - source buffer, addressed in bytes (row offset = tgpig*args.nb01)
// dst   - destination buffer, addressed in bytes (row offset = tgpig*args.nb1)
// args  - ne00 / nb01 / nb1, see ggml_metal_kargs_glu
// tgpig - threadgroup position in grid, used as the row index
// tpitg - thread position in threadgroup, used as the starting element
// ntg   - threads per threadgroup, used as the loop step
kernel void kernel_geglu(
1015+
device const char * src0,
1016+
device char * dst,
1017+
constant ggml_metal_kargs_glu & args,
1018+
uint tgpig[[threadgroup_position_in_grid]],
1019+
uint tpitg[[thread_position_in_threadgroup]],
1020+
uint ntg[[threads_per_threadgroup]]) {
1021+
// Locate this threadgroup's source and destination rows via byte strides.
device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
1022+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1023+
1024+
// Each thread handles elements tpitg, tpitg + ntg, ... below ne00/2.
for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
1025+
const float x0 = src_row[i00];
1026+
const float x1 = src_row[i00 + args.ne00/2];
1027+
1028+
// SQRT_2_OVER_PI and GELU_COEF_A are defined outside this view.
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
1029+
1030+
dst_row[i00] = gelu*x1;
1031+
}
1032+
}
1033+
1034+
// SwiGLU: for the row selected by the threadgroup index, reads x0 from the
// first ne00/2 floats and x1 from the following ne00/2 floats, computes
// silu(x0) = x0 / (1 + exp(-x0)) and writes silu(x0)*x1 to the destination row.
//
// src0  - source buffer, addressed in bytes (row offset = tgpig*args.nb01)
// dst   - destination buffer, addressed in bytes (row offset = tgpig*args.nb1)
// args  - ne00 / nb01 / nb1, see ggml_metal_kargs_glu
// tgpig - threadgroup position in grid, used as the row index
// tpitg - thread position in threadgroup, used as the starting element
// ntg   - threads per threadgroup, used as the loop step
kernel void kernel_swiglu(
1035+
device const char * src0,
1036+
device char * dst,
1037+
constant ggml_metal_kargs_glu & args,
1038+
uint tgpig[[threadgroup_position_in_grid]],
1039+
uint tpitg[[thread_position_in_threadgroup]],
1040+
uint ntg[[threads_per_threadgroup]]) {
1041+
// Locate this threadgroup's source and destination rows via byte strides.
device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
1042+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1043+
1044+
// Each thread handles elements tpitg, tpitg + ntg, ... below ne00/2.
for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
1045+
const float x0 = src_row[i00];
1046+
const float x1 = src_row[i00 + args.ne00/2];
1047+
1048+
const float silu = x0 / (1.0f + exp(-x0));
1049+
1050+
dst_row[i00] = silu*x1;
1051+
}
1052+
}
1053+
9961054
template <bool norm>
9971055
kernel void kernel_sum_rows(
9981056
constant ggml_metal_kargs_sum_rows & args,

0 commit comments

Comments
 (0)