31
31
enum ggml_metal_kernel_type {
32
32
GGML_METAL_KERNEL_TYPE_ADD,
33
33
GGML_METAL_KERNEL_TYPE_ADD_ROW,
34
+ GGML_METAL_KERNEL_TYPE_SUB,
35
+ GGML_METAL_KERNEL_TYPE_SUB_ROW,
34
36
GGML_METAL_KERNEL_TYPE_MUL,
35
37
GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
38
GGML_METAL_KERNEL_TYPE_DIV,
205
207
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
206
208
GGML_METAL_KERNEL_TYPE_CONCAT,
207
209
GGML_METAL_KERNEL_TYPE_SQR,
210
+ GGML_METAL_KERNEL_TYPE_SQRT,
208
211
GGML_METAL_KERNEL_TYPE_SIN,
209
212
GGML_METAL_KERNEL_TYPE_COS,
210
213
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -493,6 +496,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
493
496
494
497
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD, add, true );
495
498
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true );
499
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB, sub, true );
500
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true );
496
501
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL, mul, true );
497
502
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true );
498
503
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV, div, true );
@@ -667,6 +672,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
667
672
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true );
668
673
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CONCAT, concat, true );
669
674
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQR, sqr, true );
675
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true );
670
676
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIN, sin, true );
671
677
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_COS, cos, true );
672
678
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
@@ -769,6 +775,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
769
775
case GGML_OP_PERMUTE:
770
776
case GGML_OP_CONCAT:
771
777
case GGML_OP_ADD:
778
+ case GGML_OP_SUB:
772
779
case GGML_OP_ACC:
773
780
case GGML_OP_MUL:
774
781
case GGML_OP_DIV:
@@ -777,6 +784,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
777
784
case GGML_OP_CLAMP:
778
785
return true ;
779
786
case GGML_OP_SQR:
787
+ case GGML_OP_SQRT:
780
788
case GGML_OP_SIN:
781
789
case GGML_OP_COS:
782
790
return ggml_is_contiguous (op->src [0 ]);
@@ -1057,6 +1065,7 @@ static enum ggml_status ggml_metal_graph_compute(
1057
1065
[encoder dispatchThreadgroups: MTLSizeMake (ne1, ne2, ne3) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
1058
1066
} break ;
1059
1067
case GGML_OP_ADD:
1068
+ case GGML_OP_SUB:
1060
1069
case GGML_OP_MUL:
1061
1070
case GGML_OP_DIV:
1062
1071
{
@@ -1080,6 +1089,7 @@ static enum ggml_status ggml_metal_graph_compute(
1080
1089
nb = ne00 / 4 ;
1081
1090
switch (dst->op ) {
1082
1091
case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ; break ;
1092
+ case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline ; break ;
1083
1093
case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline ; break ;
1084
1094
case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline ; break ;
1085
1095
default : GGML_ABORT (" fatal error" );
@@ -1089,6 +1099,7 @@ static enum ggml_status ggml_metal_graph_compute(
1089
1099
} else {
1090
1100
switch (dst->op ) {
1091
1101
case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ; break ;
1102
+ case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB].pipeline ; break ;
1092
1103
case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL].pipeline ; break ;
1093
1104
case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV].pipeline ; break ;
1094
1105
default : GGML_ABORT (" fatal error" );
@@ -1416,6 +1427,20 @@ static enum ggml_status ggml_metal_graph_compute(
1416
1427
1417
1428
const int64_t n = ggml_nelements (dst);
1418
1429
1430
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1431
+ } break ;
1432
+ case GGML_OP_SQRT:
1433
+ {
1434
+ GGML_ASSERT (ggml_is_contiguous (src0));
1435
+
1436
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SQRT].pipeline ;
1437
+
1438
+ [encoder setComputePipelineState: pipeline];
1439
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1440
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1441
+
1442
+ const int64_t n = ggml_nelements (dst);
1443
+
1419
1444
[encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1420
1445
} break ;
1421
1446
case GGML_OP_SIN:
0 commit comments