Skip to content

Commit 564861d

Browse files
committed
metal : add glu kernels
ggml-ci
1 parent f4be71e commit 564861d

File tree

3 files changed

+116
-1
lines changed

3 files changed

+116
-1
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,12 @@ typedef struct {
422422
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
423423
} ggml_metal_kargs_im2col;
424424

425+
// Arguments passed to the GLU kernels (kernel_reglu / kernel_geglu / kernel_swiglu).
typedef struct {
    int32_t  ne00; // elements in one source row; the kernels write ne00/2 outputs per row
    uint64_t nb01; // stride in bytes between consecutive source rows
    uint64_t nb1;  // stride in bytes between consecutive destination rows
} ggml_metal_kargs_glu;
430+
425431
typedef struct {
426432
int64_t ne00;
427433
int64_t ne01;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,9 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
497497
GGML_METAL_KERNEL_TYPE_SIN,
498498
GGML_METAL_KERNEL_TYPE_COS,
499499
GGML_METAL_KERNEL_TYPE_NEG,
500+
GGML_METAL_KERNEL_TYPE_REGLU,
501+
GGML_METAL_KERNEL_TYPE_GEGLU,
502+
GGML_METAL_KERNEL_TYPE_SWIGLU,
500503
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
501504
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
502505
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1453,6 +1456,9 @@ @implementation GGMLMetalClass
14531456
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
14541457
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
14551458
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
1459+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
1460+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
1461+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
14561462
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
14571463
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
14581464
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1626,6 +1632,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16261632
default:
16271633
return false;
16281634
}
1635+
case GGML_OP_GLU:
1636+
switch (ggml_get_glu_op(op)) {
1637+
case GGML_GLU_OP_REGLU:
1638+
case GGML_GLU_OP_GEGLU:
1639+
case GGML_GLU_OP_SWIGLU:
1640+
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1641+
default:
1642+
return false;
1643+
}
16291644
case GGML_OP_NONE:
16301645
case GGML_OP_RESHAPE:
16311646
case GGML_OP_VIEW:
@@ -2343,6 +2358,43 @@ static bool ggml_metal_encode_node(
23432358
GGML_ABORT("fatal error");
23442359
}
23452360
} break;
2361+
case GGML_OP_GLU:
2362+
{
2363+
GGML_ASSERT(ggml_is_contiguous_1(src0));
2364+
2365+
id<MTLComputePipelineState> pipeline = nil;
2366+
2367+
switch (ggml_get_glu_op(node)) {
2368+
case GGML_GLU_OP_REGLU:
2369+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
2370+
break;
2371+
case GGML_GLU_OP_GEGLU:
2372+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
2373+
break;
2374+
case GGML_GLU_OP_SWIGLU:
2375+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
2376+
break;
2377+
default:
2378+
GGML_ABORT("fatal error");
2379+
}
2380+
2381+
ggml_metal_kargs_glu args = {
2382+
/*.ne00 =*/ ne00,
2383+
/*.nb01 =*/ nb01,
2384+
/*.nb1 =*/ nb1,
2385+
};
2386+
2387+
[encoder setComputePipelineState:pipeline];
2388+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2389+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2390+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
2391+
2392+
const int64_t nrows = ggml_nrows(src0);
2393+
2394+
const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
2395+
2396+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2397+
} break;
23462398
case GGML_OP_SQR:
23472399
{
23482400
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -2405,7 +2457,6 @@ static bool ggml_metal_encode_node(
24052457

24062458
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
24072459

2408-
24092460
ggml_metal_kargs_sum_rows args = {
24102461
/*.ne00 =*/ ne00,
24112462
/*.ne01 =*/ ne01,

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,64 @@ kernel void kernel_neg(
993993
dst[tpig] = -src0[tpig];
994994
}
995995

996+
// ReGLU over one row per threadgroup.
//
// The source row (ne00 floats) is read as two halves: x = first ne00/2 values,
// g = the following ne00/2 values. Each output is x*g multiplied by the
// boolean (x > 0), i.e. the product is kept only where x is positive.
//
//   src0 / dst : raw byte pointers; rows are located via the byte strides
//                args.nb01 (source) and args.nb1 (destination)
//   tgpig      : threadgroup position, used as the row index
//   tpitg, ntg : thread index within the threadgroup and threadgroup size;
//                threads walk the row with a step of ntg
kernel void kernel_reglu(
        device const char * src0,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // number of outputs per row (= size of each half of the source row)
    const int n_half = args.ne00/2;

    device const float * x_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
    device const float * g_row = x_row + n_half;
    device       float * y_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i = tpitg; i < n_half; i += ntg) {
        const float x = x_row[i];
        const float g = g_row[i];

        // bool -> float: 1.0f when x is positive, 0.0f otherwise
        const float mask = (x > 0.0f);

        y_row[i] = x*g*mask;
    }
}
1013+
1014+
// GEGLU over one row per threadgroup.
//
// The source row (ne00 floats) is read as two halves: x = first ne00/2 values,
// g = the following ne00/2 values. Each output is gelu(x)*g, where gelu is
// evaluated with the tanh-based expression below (constants SQRT_2_OVER_PI and
// GELU_COEF_A are defined elsewhere in this file).
//
//   src0 / dst : raw byte pointers; rows are located via the byte strides
//                args.nb01 (source) and args.nb1 (destination)
//   tgpig      : threadgroup position, used as the row index
//   tpitg, ntg : thread index within the threadgroup and threadgroup size;
//                threads walk the row with a step of ntg
kernel void kernel_geglu(
        device const char * src0,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // number of outputs per row (= size of each half of the source row)
    const int n_half = args.ne00/2;

    device const float * x_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
    device const float * g_row = x_row + n_half;
    device       float * y_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i = tpitg; i < n_half; i += ntg) {
        const float x = x_row[i];
        const float g = g_row[i];

        // argument of the tanh term
        const float t = SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x);

        const float gelu = 0.5f*x*(1.0f + precise::tanh(t));

        y_row[i] = gelu*g;
    }
}
1033+
1034+
// SwiGLU over one row per threadgroup.
//
// The source row (ne00 floats) is read as two halves: x = first ne00/2 values,
// g = the following ne00/2 values. Each output is silu(x)*g with
// silu(x) = x / (1 + exp(-x)).
//
//   src0 / dst : raw byte pointers; rows are located via the byte strides
//                args.nb01 (source) and args.nb1 (destination)
//   tgpig      : threadgroup position, used as the row index
//   tpitg, ntg : thread index within the threadgroup and threadgroup size;
//                threads walk the row with a step of ntg
kernel void kernel_swiglu(
        device const char * src0,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // number of outputs per row (= size of each half of the source row)
    const int n_half = args.ne00/2;

    device const float * x_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
    device const float * g_row = x_row + n_half;
    device       float * y_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i = tpitg; i < n_half; i += ntg) {
        const float x = x_row[i];
        const float g = g_row[i];

        const float silu = x / (1.0f + exp(-x));

        y_row[i] = silu*g;
    }
}
1053+
9961054
kernel void kernel_sum_rows(
9971055
device const float * src0,
9981056
device float * dst,

0 commit comments

Comments
 (0)