
Commit a0deecd

CISC, ggerganov, 0cc4m, qnixsynapse, and jeffbolznv committed
ggml : implement REGLU/GEGLU/SWIGLU ops (ggml-org#14158)
* implement unary REGLU/GEGLU/SWIGLU cpu ops
* relax constraints
* duplicate shape of source
* fix ggml_vec_geglu_f16
* special case gated ops
* implement unary REGLU/GEGLU/SWIGLU cuda ops
* tighten constraints again
* refactor into GGML_GLU_OP
* metal : add glu kernels ggml-ci
* add CUDA_GLU_BLOCK_SIZE [no ci]
* more constraints and use 64bit ints ggml-ci
* 64bit multiplication [no ci]
* implement swapped variants (cpu/cuda)
* update comment [no ci] ggml-ci
* Vulkan: Add GLU ops and shaders
* SYCL: Implement fused kernel GEGLU, SWIGLU and REGLU for single up+gate
* ggml : implement GLU for split up/gate (ggml-org#14181)
  * implement GLU for split up/gate
  * add tests for ggml_glu_split
  * Vulkan: Implement glu_split logic and shader support
  * add split to logging [no ci]
  * SYCL: refactor element_size ops and add split up and gate support to gated kernels
  * SYCL: switch GEGLU to use tanh approximation

  ---------

  Co-authored-by: 0cc4m <picard12@live.de>
  Co-authored-by: Akarshan <akarshan@menlo.ai>
* GGML: increase OP count in assertion
* Refactor: Optimize SYCL element-wise operations with unary function inlining

  This commit refactors the SYCL element-wise operations to improve performance by:
  - Inlining unary operations (sgn, abs, elu, gelu, silu, etc.) to reduce kernel launch overhead.
  - Introducing helper functions `op_xxx` for each unary operation to encapsulate the logic.
  - Replacing direct kernel calls with calls to these inlined functions.
  - Using `__dpct_inline__` to encourage compiler inlining.
  - Minor code cleanup and consistency improvements.

  The changes aim to reduce kernel launch overhead and improve the overall efficiency of element-wise operations on SYCL devices.
* vulkan: Increase workgroup size for GLU, for performance (ggml-org#14345)
  * vulkan: Increase workgroup size for GLU, for performance
  * vulkan: change GLU shaders to do one element per invocation rather than one row per workgroup
* merge fix
* metal : add support for split and swap ggml-ci

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: 0cc4m <picard12@live.de>
Co-authored-by: Akarshan <akarshan@menlo.ai>
Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
1 parent cdaa419 commit a0deecd
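
For orientation, all three ops compute the same gated pattern, y = act(x) * g, differing only in the activation. A minimal scalar sketch in plain C — not part of the commit; the two GELU constants mirror the tanh approximation used in the Metal kernels below, everything else here is illustrative:

```c
#include <math.h>

// tanh-approximation constants, as used by kernel_geglu below
#define GELU_COEF_A    0.044715f
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f

// REGLU: y = relu(x) * g
static float reglu(float x, float g) {
    return x > 0.0f ? x*g : 0.0f;
}

// GEGLU: y = gelu(x) * g, with gelu in its tanh approximation
static float geglu(float x, float g) {
    const float gelu = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
    return gelu*g;
}

// SWIGLU: y = silu(x) * g = (x * sigmoid(x)) * g
static float swiglu(float x, float g) {
    return (x/(1.0f + expf(-x)))*g;
}
```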

File tree

10 files changed: +1202 -1071 lines


ggml/src/ggml-cpu/vec.h

Lines changed: 9 additions & 9 deletions

@@ -913,8 +913,8 @@ inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x,
 
 inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
-        float v = GGML_CPU_FP16_TO_FP32(x[i]);
-        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v * GGML_CPU_FP16_TO_FP32(g[i]) : 0.f);
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(g[i]) : 0.f);
     }
 }
 
@@ -927,9 +927,9 @@ inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, c
         } else if (x[i] >= 10.0f) {
             y[i] = x[i] * g[i];
         } else {
-            ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
+            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
             memcpy(&t, &fp16, sizeof(uint16_t));
-            y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
         }
     }
 }
@@ -944,18 +944,18 @@ inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, c
 inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     const uint16_t * i16 = (const uint16_t *) x;
     for (int i = 0; i < n; ++i) {
-        float v = GGML_CPU_FP16_TO_FP32(g[i]);
-        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
+        float v = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
     }
 }
 
 void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
 
 inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
-        float v = GGML_CPU_FP16_TO_FP32(x[i]);
-        float w = GGML_CPU_FP16_TO_FP32(g[i]);
-        y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
+        float v = GGML_FP16_TO_FP32(x[i]);
+        float w = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
     }
 }
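
A hypothetical harness for the f16 helpers above — not part of the commit; it assumes ggml.h and this vec.h are on the include path and that GGML_FP32_TO_FP16/GGML_FP16_TO_FP32 do the usual f32↔f16 conversions:

```c
#include <stdio.h>
#include "ggml.h"
#include "vec.h" // hypothetical include of the header shown above

int main(void) {
    ggml_fp16_t x[4], g[4], y[4];
    for (int i = 0; i < 4; ++i) {
        x[i] = GGML_FP32_TO_FP16((float) i - 1.5f); // values on both sides of 0
        g[i] = GGML_FP32_TO_FP16(2.0f);             // constant multiplier
    }

    ggml_vec_swiglu_f16(4, y, x, g); // y[i] = silu(x[i]) * g[i]

    for (int i = 0; i < 4; ++i) {
        printf("y[%d] = %f\n", i, GGML_FP16_TO_FP32(y[i]));
    }
    return 0;
}
```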

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 25 additions & 0 deletions

@@ -2303,6 +2303,21 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
+                    ggml_cuda_op_reglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    ggml_cuda_op_geglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    ggml_cuda_op_swiglu(ctx, dst);
+                    break;
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_NORM:
             ggml_cuda_op_norm(ctx, dst);
             break;
@@ -3096,6 +3111,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]);
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
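
These backend cases are reached through the new GGML_OP_GLU graph op. A hedged sketch of creating such nodes, assuming the ggml_swiglu/ggml_swiglu_split constructors this PR adds to the public API — the exact signatures and argument order are an assumption here:

```c
#include "ggml.h"

// Fused layout: cur holds the two halves concatenated along ne0;
// the op splits each row internally, so the result has ne0/2 columns.
struct ggml_tensor * ffn_swiglu_fused(struct ggml_context * ctx,
                                      struct ggml_tensor * cur) {
    return ggml_swiglu(ctx, cur);
}

// Split layout: separate same-shape tensors, as exercised by the
// ggml_glu_split tests mentioned in the commit message.
struct ggml_tensor * ffn_swiglu_split(struct ggml_context * ctx,
                                      struct ggml_tensor * x,
                                      struct ggml_tensor * gate) {
    return ggml_swiglu_split(ctx, x, gate);
}
```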

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 71 additions & 0 deletions

@@ -526,6 +526,9 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
+    GGML_METAL_KERNEL_TYPE_REGLU,
+    GGML_METAL_KERNEL_TYPE_GEGLU,
+    GGML_METAL_KERNEL_TYPE_SWIGLU,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1502,6 +1505,9 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,      sin,      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,      cos,      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,      neg,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU,    reglu,    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU,    geglu,    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU,   swiglu,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,     mean,     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,   argmax,   true);
@@ -1680,6 +1686,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             default:
                 return false;
         }
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(op)) {
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_SWIGLU:
+                return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+            default:
+                return false;
+        }
     case GGML_OP_NONE:
     case GGML_OP_RESHAPE:
     case GGML_OP_VIEW:
@@ -2419,6 +2434,62 @@ static bool ggml_metal_encode_node(
                 GGML_ABORT("fatal error");
             }
         } break;
+    case GGML_OP_GLU:
+        {
+            GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+            if (src1) {
+                GGML_ASSERT(ggml_are_same_shape(src0, src1));
+            }
+
+            id<MTLComputePipelineState> pipeline = nil;
+
+            switch (ggml_get_glu_op(node)) {
+                case GGML_GLU_OP_REGLU:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+
+            const int32_t swp = ((const int32_t *) dst->op_params)[1];
+
+            const int32_t i00 = swp ? ne0 : 0;
+            const int32_t i10 = swp ? 0 : ne0;
+
+            ggml_metal_kargs_glu args = {
+                /*.ne00 =*/ ne00,
+                /*.nb01 =*/ nb01,
+                /*.ne10 =*/ src1 ? ne10 : ne00,
+                /*.nb11 =*/ src1 ? nb11 : nb01,
+                /*.ne0 =*/ ne0,
+                /*.nb1 =*/ nb1,
+                /*.i00 =*/ src1 ? 0 : i00,
+                /*.i10 =*/ src1 ? 0 : i10,
+            };
+
+            [encoder setComputePipelineState:pipeline];
+            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+            if (src1) {
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+            } else {
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+            }
+            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+            [encoder setBytes:&args length:sizeof(args) atIndex:3];
+
+            const int64_t nrows = ggml_nrows(src0);
+
+            const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
+
+            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+        } break;
     case GGML_OP_SQR:
         {
             GGML_ASSERT(ggml_is_contiguous(src0));
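
One detail worth calling out in the encoder above: in the fused case (src1 == NULL) the same src0 buffer is bound to both kernel slots, and the i00/i10 column offsets carve each row into its two halves; `swp` (read from op_params[1]) exchanges them for the swapped variants. A plain-C model of that addressing, illustrative only:

```c
// row points at one row of src0 holding 2*ne0 floats: [half A | half B]
static void fused_row_halves(const float * row, int ne0, int swp,
                             const float ** x,   // half the activation applies to
                             const float ** g) { // multiplier half
    const int i00 = swp ? ne0 : 0; // matches the encoder's i00
    const int i10 = swp ? 0 : ne0; // matches the encoder's i10
    *x = row + i00;
    *g = row + i10;
}
```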

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 64 additions & 0 deletions

@@ -1191,6 +1191,70 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+kernel void kernel_reglu(
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        dst_row[i0] = x0*x1*(x0 > 0.0f);
+    }
+}
+
+kernel void kernel_geglu(
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
+
+        dst_row[i0] = gelu*x1;
+    }
+}
+
+kernel void kernel_swiglu(
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float silu = x0 / (1.0f + exp(-x0));
+
+        dst_row[i0] = silu*x1;
+    }
+}
+
 template <bool norm>
 kernel void kernel_sum_rows(
         constant ggml_metal_kargs_sum_rows & args,
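
Dispatch shape for these kernels, per the host code above: one threadgroup per row (nrows groups), min(maxTotalThreadsPerThreadgroup, ne00/2) threads each, with every thread striding across its row. A sequential C model of that access pattern — a sketch only, not the actual kernel:

```c
#include <stdint.h>

// tgpig indexes rows, tpitg strides columns within a row; act() stands
// in for relu/gelu/silu. Assumes x, g, y are contiguous nrows x ne0
// buffers with the two halves already resolved.
static void glu_rows_model(const float * x, const float * g, float * y,
                           int64_t nrows, int64_t ne0, int ntg,
                           float (*act)(float)) {
    for (int64_t tgpig = 0; tgpig < nrows; ++tgpig) {       // threadgroups
        for (int tpitg = 0; tpitg < ntg; ++tpitg) {         // threads per group
            for (int64_t i0 = tpitg; i0 < ne0; i0 += ntg) { // stride over the row
                y[tgpig*ne0 + i0] = act(x[tgpig*ne0 + i0]) * g[tgpig*ne0 + i0];
            }
        }
    }
}
```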
