Skip to content

Commit dcf7f2e

Browse files
authored
metal : Add missing unary ops Metal support (#14660)
1 parent 84b396e commit dcf7f2e

File tree

2 files changed

+135
-0
lines changed

2 files changed

+135
-0
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
173173
GGML_METAL_KERNEL_TYPE_SILU,
174174
GGML_METAL_KERNEL_TYPE_SILU_4,
175175
GGML_METAL_KERNEL_TYPE_ELU,
176+
GGML_METAL_KERNEL_TYPE_ABS,
177+
GGML_METAL_KERNEL_TYPE_SGN,
178+
GGML_METAL_KERNEL_TYPE_STEP,
179+
GGML_METAL_KERNEL_TYPE_HARDSWISH,
180+
GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
181+
GGML_METAL_KERNEL_TYPE_EXP,
176182
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
177183
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
178184
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -1155,6 +1161,12 @@ @implementation GGMLMetalClass
11551161
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
11561162
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
11571163
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
1164+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
1165+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
1166+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true);
1167+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
1168+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
1169+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true);
11581170
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
11591171
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
11601172
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
@@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16881700
case GGML_UNARY_OP_SILU:
16891701
case GGML_UNARY_OP_ELU:
16901702
case GGML_UNARY_OP_NEG:
1703+
case GGML_UNARY_OP_ABS:
1704+
case GGML_UNARY_OP_SGN:
1705+
case GGML_UNARY_OP_STEP:
1706+
case GGML_UNARY_OP_HARDSWISH:
1707+
case GGML_UNARY_OP_HARDSIGMOID:
1708+
case GGML_UNARY_OP_EXP:
16911709
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
16921710
default:
16931711
return false;
@@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node(
24392457

24402458
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
24412459
} break;
2460+
case GGML_UNARY_OP_ABS:
2461+
{
2462+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
2463+
2464+
[encoder setComputePipelineState:pipeline];
2465+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2466+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2467+
2468+
const int64_t n = ggml_nelements(dst);
2469+
2470+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2471+
} break;
2472+
case GGML_UNARY_OP_SGN:
2473+
{
2474+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
2475+
2476+
[encoder setComputePipelineState:pipeline];
2477+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2478+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2479+
2480+
const int64_t n = ggml_nelements(dst);
2481+
2482+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2483+
} break;
2484+
case GGML_UNARY_OP_STEP:
2485+
{
2486+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
2487+
2488+
[encoder setComputePipelineState:pipeline];
2489+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2490+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2491+
2492+
const int64_t n = ggml_nelements(dst);
2493+
2494+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2495+
} break;
2496+
case GGML_UNARY_OP_HARDSWISH:
2497+
{
2498+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
2499+
2500+
[encoder setComputePipelineState:pipeline];
2501+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2502+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2503+
2504+
const int64_t n = ggml_nelements(dst);
2505+
2506+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2507+
} break;
2508+
case GGML_UNARY_OP_HARDSIGMOID:
2509+
{
2510+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
2511+
2512+
[encoder setComputePipelineState:pipeline];
2513+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2514+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2515+
2516+
const int64_t n = ggml_nelements(dst);
2517+
2518+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2519+
} break;
2520+
case GGML_UNARY_OP_EXP:
2521+
{
2522+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
2523+
2524+
[encoder setComputePipelineState:pipeline];
2525+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2526+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2527+
2528+
const int64_t n = ggml_nelements(dst);
2529+
2530+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2531+
} break;
24422532
default:
24432533
{
24442534
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,6 +1199,51 @@ kernel void kernel_neg(
11991199
dst[tpig] = -src0[tpig];
12001200
}
12011201

1202+
kernel void kernel_abs(
1203+
device const float * src0,
1204+
device float * dst,
1205+
uint tpig[[thread_position_in_grid]]) {
1206+
dst[tpig] = fabs(src0[tpig]);
1207+
}
1208+
1209+
kernel void kernel_sgn(
1210+
device const float * src0,
1211+
device float * dst,
1212+
uint tpig[[thread_position_in_grid]]) {
1213+
device const float & x = src0[tpig];
1214+
dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f);
1215+
}
1216+
1217+
kernel void kernel_step(
1218+
device const float * src0,
1219+
device float * dst,
1220+
uint tpig[[thread_position_in_grid]]) {
1221+
dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f;
1222+
}
1223+
1224+
kernel void kernel_hardswish(
1225+
device const float * src0,
1226+
device float * dst,
1227+
uint tpig[[thread_position_in_grid]]) {
1228+
device const float & x = src0[tpig];
1229+
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1230+
}
1231+
1232+
kernel void kernel_hardsigmoid(
1233+
device const float * src0,
1234+
device float * dst,
1235+
uint tpig[[thread_position_in_grid]]) {
1236+
device const float & x = src0[tpig];
1237+
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1238+
}
1239+
1240+
kernel void kernel_exp(
1241+
device const float * src0,
1242+
device float * dst,
1243+
uint tpig[[thread_position_in_grid]]) {
1244+
dst[tpig] = exp(src0[tpig]);
1245+
}
1246+
12021247
kernel void kernel_reglu(
12031248
device const char * src0,
12041249
device const char * src1,

0 commit comments

Comments
 (0)