@@ -173,6 +173,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
     GGML_METAL_KERNEL_TYPE_ELU,
+    GGML_METAL_KERNEL_TYPE_ABS,
+    GGML_METAL_KERNEL_TYPE_SGN,
+    GGML_METAL_KERNEL_TYPE_STEP,
+    GGML_METAL_KERNEL_TYPE_HARDSWISH,
+    GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
+    GGML_METAL_KERNEL_TYPE_EXP,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -1155,6 +1161,12 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,           silu,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,         silu_4,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU,            elu,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS,            abs,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN,            sgn,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP,           step,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH,      hardswish,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID,    hardsigmoid,    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP,            exp,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,   soft_max_f16,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,   soft_max_f32,   has_simdgroup_reduction);
@@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context
             case GGML_UNARY_OP_SILU:
             case GGML_UNARY_OP_ELU:
             case GGML_UNARY_OP_NEG:
+            case GGML_UNARY_OP_ABS:
+            case GGML_UNARY_OP_SGN:
+            case GGML_UNARY_OP_STEP:
+            case GGML_UNARY_OP_HARDSWISH:
+            case GGML_UNARY_OP_HARDSIGMOID:
+            case GGML_UNARY_OP_EXP:
                 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
             default:
                 return false;
@@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node(

                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_UNARY_OP_ABS:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_UNARY_OP_SGN:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_UNARY_OP_STEP:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_UNARY_OP_HARDSWISH:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_UNARY_OP_HARDSIGMOID:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_UNARY_OP_EXP:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             default:
                 {
                     GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
0 commit comments