Skip to content

Commit f4ca3e2

Browse files
jmorganca authored and ggerganov committed
metal: add neg operator (llama/13029)
1 parent 0287a5c commit f4ca3e2

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
481481
GGML_METAL_KERNEL_TYPE_SQRT,
482482
GGML_METAL_KERNEL_TYPE_SIN,
483483
GGML_METAL_KERNEL_TYPE_COS,
484+
GGML_METAL_KERNEL_TYPE_NEG,
484485
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
485486
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
486487
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1159,6 +1160,7 @@ @implementation GGMLMetalClass
11591160
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
11601161
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
11611162
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
1163+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
11621164
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
11631165
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
11641166
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1320,6 +1322,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
13201322
case GGML_UNARY_OP_GELU_QUICK:
13211323
case GGML_UNARY_OP_SILU:
13221324
case GGML_UNARY_OP_ELU:
1325+
case GGML_UNARY_OP_NEG:
13231326
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
13241327
default:
13251328
return false;
@@ -2010,6 +2013,18 @@ static void ggml_metal_encode_node(
20102013

20112014
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
20122015
} break;
2016+
case GGML_UNARY_OP_NEG:
2017+
{
2018+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
2019+
2020+
[encoder setComputePipelineState:pipeline];
2021+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2022+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2023+
2024+
const int64_t n = ggml_nelements(dst);
2025+
2026+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2027+
} break;
20132028
default:
20142029
{
20152030
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -949,6 +949,13 @@ kernel void kernel_cos(
949949
dst[tpig] = cos(src0[tpig]);
950950
}
951951

952+
// Element-wise negation kernel.
// Each thread reads the float at its grid position (tpig) from src0 and
// writes the negated value to the same index in dst.
kernel void kernel_neg(
953+
device const float * src0,
954+
device float * dst,
955+
uint tpig[[thread_position_in_grid]]) {
956+
dst[tpig] = -src0[tpig];
957+
}
958+
952959
kernel void kernel_sum_rows(
953960
device const float * src0,
954961
device float * dst,

0 commit comments

Comments
 (0)