@@ -497,6 +497,9 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
497
497
GGML_METAL_KERNEL_TYPE_SIN,
498
498
GGML_METAL_KERNEL_TYPE_COS,
499
499
GGML_METAL_KERNEL_TYPE_NEG,
500
+ GGML_METAL_KERNEL_TYPE_REGLU,
501
+ GGML_METAL_KERNEL_TYPE_GEGLU,
502
+ GGML_METAL_KERNEL_TYPE_SWIGLU,
500
503
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
501
504
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
502
505
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1453,6 +1456,9 @@ @implementation GGMLMetalClass
1453
1456
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIN, sin, true );
1454
1457
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_COS, cos, true );
1455
1458
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NEG, neg, true );
1459
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REGLU, reglu, true );
1460
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true );
1461
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true );
1456
1462
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
1457
1463
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true );
1458
1464
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true );
@@ -1626,6 +1632,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1626
1632
default :
1627
1633
return false ;
1628
1634
}
1635
+ case GGML_OP_GLU:
1636
+ switch (ggml_get_glu_op (op)) {
1637
+ case GGML_GLU_OP_REGLU:
1638
+ case GGML_GLU_OP_GEGLU:
1639
+ case GGML_GLU_OP_SWIGLU:
1640
+ return ggml_is_contiguous_1 (op->src [0 ]) && op->src [0 ]->type == GGML_TYPE_F32;
1641
+ default :
1642
+ return false ;
1643
+ }
1629
1644
case GGML_OP_NONE:
1630
1645
case GGML_OP_RESHAPE:
1631
1646
case GGML_OP_VIEW:
@@ -2343,6 +2358,43 @@ static bool ggml_metal_encode_node(
2343
2358
GGML_ABORT (" fatal error" );
2344
2359
}
2345
2360
} break ;
2361
+ case GGML_OP_GLU:
2362
+ {
2363
+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
2364
+
2365
+ id <MTLComputePipelineState > pipeline = nil ;
2366
+
2367
+ switch (ggml_get_glu_op (node)) {
2368
+ case GGML_GLU_OP_REGLU:
2369
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REGLU].pipeline ;
2370
+ break ;
2371
+ case GGML_GLU_OP_GEGLU:
2372
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GEGLU].pipeline ;
2373
+ break ;
2374
+ case GGML_GLU_OP_SWIGLU:
2375
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline ;
2376
+ break ;
2377
+ default :
2378
+ GGML_ABORT (" fatal error" );
2379
+ }
2380
+
2381
+ ggml_metal_kargs_glu args = {
2382
+ /* .ne00 =*/ ne00,
2383
+ /* .nb01 =*/ nb01,
2384
+ /* .nb1 =*/ nb1,
2385
+ };
2386
+
2387
+ [encoder setComputePipelineState: pipeline];
2388
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2389
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2390
+ [encoder setBytes: &args length: sizeof (args) atIndex: 2 ];
2391
+
2392
+ const int64_t nrows = ggml_nrows (src0);
2393
+
2394
+ const int32_t nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , ne00/2 );
2395
+
2396
+ [encoder dispatchThreadgroups: MTLSizeMake (nrows, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
2397
+ } break ;
2346
2398
case GGML_OP_SQR:
2347
2399
{
2348
2400
GGML_ASSERT (ggml_is_contiguous (src0));
@@ -2405,7 +2457,6 @@ static bool ggml_metal_encode_node(
2405
2457
2406
2458
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline ;
2407
2459
2408
-
2409
2460
ggml_metal_kargs_sum_rows args = {
2410
2461
/* .ne00 =*/ ne00,
2411
2462
/* .ne01 =*/ ne01,
0 commit comments