@@ -481,6 +481,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
481
481
GGML_METAL_KERNEL_TYPE_SQRT,
482
482
GGML_METAL_KERNEL_TYPE_SIN,
483
483
GGML_METAL_KERNEL_TYPE_COS,
484
+ GGML_METAL_KERNEL_TYPE_NEG,
484
485
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
485
486
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
486
487
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1159,6 +1160,7 @@ @implementation GGMLMetalClass
1159
1160
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true );
1160
1161
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIN, sin, true );
1161
1162
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_COS, cos, true );
1163
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NEG, neg, true );
1162
1164
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
1163
1165
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true );
1164
1166
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true );
@@ -1320,6 +1322,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1320
1322
case GGML_UNARY_OP_GELU_QUICK:
1321
1323
case GGML_UNARY_OP_SILU:
1322
1324
case GGML_UNARY_OP_ELU:
1325
+ case GGML_UNARY_OP_NEG:
1323
1326
return ggml_is_contiguous (op->src [0 ]) && op->src [0 ]->type == GGML_TYPE_F32;
1324
1327
default :
1325
1328
return false ;
@@ -2010,6 +2013,18 @@ static void ggml_metal_encode_node(
2010
2013
2011
2014
[encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2012
2015
} break ;
2016
// Unary NEG: select the NEG compute pipeline, bind src0 at buffer index 0 and
// dst at buffer index 1, then dispatch over n = ggml_nelements(dst).
+ case GGML_UNARY_OP_NEG:
2017
+ {
2018
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_NEG].pipeline ;
2019
+
2020
+ [encoder setComputePipelineState: pipeline];
2021
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2022
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2023
+
2024
// NOTE(review): n is presumably the total element count of dst - confirm against ggml_nelements.
+ const int64_t n = ggml_nelements (dst);
2025
+
2026
// Grid is n x 1 x 1 threadgroups with a single thread in each, the same
// dispatch form used by the case immediately preceding this one.
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2027
+ } break ;
2013
2028
default :
2014
2029
{
2015
2030
GGML_LOG_WARN (" %s : node %3d , op = %8s not implemented\n " , __func__, idx, ggml_op_name (dst->op ));
0 commit comments