Skip to content

Commit f4be71e

Browse files
authored
refactor into GGML_GLU_OP
1 parent 5c58196 commit f4be71e

File tree

7 files changed

+172
-55
lines changed

7 files changed

+172
-55
lines changed

ggml/include/ggml.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,8 @@ extern "C" {
518518
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
519519
GGML_OP_OPT_STEP_ADAMW,
520520

521+
GGML_OP_GLU,
522+
521523
GGML_OP_COUNT,
522524
};
523525

@@ -537,13 +539,18 @@ extern "C" {
537539
GGML_UNARY_OP_HARDSIGMOID,
538540
GGML_UNARY_OP_EXP,
539541
GGML_UNARY_OP_GELU_ERF,
540-
GGML_UNARY_OP_REGLU,
541-
GGML_UNARY_OP_GEGLU,
542-
GGML_UNARY_OP_SWIGLU,
543542

544543
GGML_UNARY_OP_COUNT,
545544
};
546545

546+
enum ggml_glu_op {
547+
GGML_GLU_OP_REGLU,
548+
GGML_GLU_OP_GEGLU,
549+
GGML_GLU_OP_SWIGLU,
550+
551+
GGML_GLU_OP_COUNT,
552+
};
553+
547554
enum ggml_object_type {
548555
GGML_OBJECT_TYPE_TENSOR,
549556
GGML_OBJECT_TYPE_GRAPH,
@@ -659,6 +666,7 @@ extern "C" {
659666
GGML_API const char * ggml_op_symbol(enum ggml_op op);
660667

661668
GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
669+
GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
662670
GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
663671

664672
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
@@ -760,6 +768,7 @@ extern "C" {
760768
GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
761769

762770
GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
771+
GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor);
763772

764773
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
765774
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
@@ -1088,6 +1097,14 @@ extern "C" {
10881097
struct ggml_context * ctx,
10891098
struct ggml_tensor * a);
10901099

1100+
// gated linear unit ops
1101+
// A: n columns, r rows,
1102+
// result is n / 2 columns, r rows,
1103+
GGML_API struct ggml_tensor * ggml_glu(
1104+
struct ggml_context * ctx,
1105+
struct ggml_tensor * a,
1106+
enum ggml_glu_op op);
1107+
10911108
GGML_API struct ggml_tensor * ggml_reglu(
10921109
struct ggml_context * ctx,
10931110
struct ggml_tensor * a);

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,6 +2006,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
20062006
{
20072007
ggml_compute_forward_unary(params, tensor);
20082008
} break;
2009+
case GGML_OP_GLU:
2010+
{
2011+
ggml_compute_forward_glu(params, tensor);
2012+
} break;
20092013
case GGML_OP_GET_REL_POS:
20102014
{
20112015
ggml_compute_forward_get_rel_pos(params, tensor);
@@ -2209,9 +2213,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
22092213
case GGML_UNARY_OP_GELU_ERF:
22102214
case GGML_UNARY_OP_GELU_QUICK:
22112215
case GGML_UNARY_OP_SILU:
2212-
case GGML_UNARY_OP_REGLU:
2213-
case GGML_UNARY_OP_GEGLU:
2214-
case GGML_UNARY_OP_SWIGLU:
2216+
{
2217+
n_tasks = n_threads;
2218+
} break;
2219+
default:
2220+
GGML_ABORT("fatal error");
2221+
}
2222+
break;
2223+
case GGML_OP_GLU:
2224+
switch (ggml_get_glu_op(node)) {
2225+
case GGML_GLU_OP_REGLU:
2226+
case GGML_GLU_OP_GEGLU:
2227+
case GGML_GLU_OP_SWIGLU:
22152228
{
22162229
n_tasks = n_threads;
22172230
} break;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8241,15 +8241,31 @@ void ggml_compute_forward_unary(
82418241
{
82428242
ggml_compute_forward_exp(params, dst);
82438243
} break;
8244-
case GGML_UNARY_OP_REGLU:
8244+
default:
8245+
{
8246+
GGML_ABORT("fatal error");
8247+
}
8248+
}
8249+
}
8250+
8251+
//ggml_compute_forward_glu
8252+
8253+
void ggml_compute_forward_glu(
8254+
const ggml_compute_params * params,
8255+
ggml_tensor * dst) {
8256+
8257+
const ggml_glu_op op = ggml_get_glu_op(dst);
8258+
8259+
switch (op) {
8260+
case GGML_GLU_OP_REGLU:
82458261
{
82468262
ggml_compute_forward_reglu(params, dst);
82478263
} break;
8248-
case GGML_UNARY_OP_GEGLU:
8264+
case GGML_GLU_OP_GEGLU:
82498265
{
82508266
ggml_compute_forward_geglu(params, dst);
82518267
} break;
8252-
case GGML_UNARY_OP_SWIGLU:
8268+
case GGML_GLU_OP_SWIGLU:
82538269
{
82548270
ggml_compute_forward_swiglu(params, dst);
82558271
} break;

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st
9292
void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
9393
void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
9494
void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
95+
void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
9596
void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
9697
void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
9798
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2216,13 +2216,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
22162216
case GGML_UNARY_OP_EXP:
22172217
ggml_cuda_op_exp(ctx, dst);
22182218
break;
2219-
case GGML_UNARY_OP_REGLU:
2219+
default:
2220+
return false;
2221+
}
2222+
break;
2223+
case GGML_OP_GLU:
2224+
switch (ggml_get_glu_op(dst)) {
2225+
case GGML_GLU_OP_REGLU:
22202226
ggml_cuda_op_reglu(ctx, dst);
22212227
break;
2222-
case GGML_UNARY_OP_GEGLU:
2228+
case GGML_GLU_OP_GEGLU:
22232229
ggml_cuda_op_geglu(ctx, dst);
22242230
break;
2225-
case GGML_UNARY_OP_SWIGLU:
2231+
case GGML_GLU_OP_SWIGLU:
22262232
ggml_cuda_op_swiglu(ctx, dst);
22272233
break;
22282234
default:
@@ -2996,9 +3002,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
29963002
case GGML_UNARY_OP_TANH:
29973003
case GGML_UNARY_OP_EXP:
29983004
return ggml_is_contiguous(op->src[0]);
2999-
case GGML_UNARY_OP_REGLU:
3000-
case GGML_UNARY_OP_GEGLU:
3001-
case GGML_UNARY_OP_SWIGLU:
3005+
default:
3006+
return false;
3007+
}
3008+
break;
3009+
case GGML_OP_GLU:
3010+
switch (ggml_get_glu_op(op)) {
3011+
case GGML_GLU_OP_REGLU:
3012+
case GGML_GLU_OP_GEGLU:
3013+
case GGML_GLU_OP_SWIGLU:
30023014
return ggml_is_contiguous_1(op->src[0]);
30033015
default:
30043016
return false;

ggml/src/ggml.c

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -989,9 +989,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
989989
"CROSS_ENTROPY_LOSS",
990990
"CROSS_ENTROPY_LOSS_BACK",
991991
"OPT_STEP_ADAMW",
992+
993+
"GLU",
992994
};
993995

994-
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
996+
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
995997

996998
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
997999
"none",
@@ -1084,9 +1086,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10841086
"cross_entropy_loss(x,y)",
10851087
"cross_entropy_loss_back(x,y)",
10861088
"adamw(x)",
1089+
1090+
"glu(x)",
10871091
};
10881092

1089-
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
1093+
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
10901094

10911095
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
10921096

@@ -1107,12 +1111,18 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
11071111
"HARDSIGMOID",
11081112
"EXP",
11091113
"GELU_ERF",
1114+
};
1115+
1116+
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
1117+
1118+
1119+
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
11101120
"REGLU",
11111121
"GEGLU",
11121122
"SWIGLU",
11131123
};
11141124

1115-
static_assert(GGML_UNARY_OP_COUNT == 18, "GGML_UNARY_OP_COUNT != 18");
1125+
static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
11161126

11171127

11181128
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -1217,11 +1227,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
12171227
return GGML_UNARY_OP_NAME[op];
12181228
}
12191229

1230+
const char * ggml_glu_op_name(enum ggml_glu_op op) {
1231+
return GGML_GLU_OP_NAME[op];
1232+
}
1233+
12201234
const char * ggml_op_desc(const struct ggml_tensor * t) {
12211235
if (t->op == GGML_OP_UNARY) {
12221236
enum ggml_unary_op uop = ggml_get_unary_op(t);
12231237
return ggml_unary_op_name(uop);
12241238
}
1239+
if (t->op == GGML_OP_GLU) {
1240+
enum ggml_glu_op gop = ggml_get_glu_op(t);
1241+
return ggml_glu_op_name(gop);
1242+
}
12251243
return ggml_op_name(t->op);
12261244
}
12271245

@@ -1740,6 +1758,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
17401758
return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
17411759
}
17421760

1761+
enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
1762+
GGML_ASSERT(tensor->op == GGML_OP_GLU);
1763+
return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
1764+
}
1765+
17431766
const char * ggml_get_name(const struct ggml_tensor * tensor) {
17441767
return tensor->name;
17451768
}
@@ -2619,58 +2642,47 @@ struct ggml_tensor * ggml_exp_inplace(
26192642
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
26202643
}
26212644

2622-
// ggml_reglu
2645+
// ggml_glu
26232646

2624-
struct ggml_tensor * ggml_reglu(
2647+
struct ggml_tensor * ggml_glu(
26252648
struct ggml_context * ctx,
2626-
struct ggml_tensor * a) {
2649+
struct ggml_tensor * a,
2650+
enum ggml_glu_op op) {
26272651
GGML_ASSERT(ggml_is_contiguous_1(a));
26282652

26292653
int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
26302654
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
26312655

2632-
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_REGLU);
2656+
ggml_set_op_params_i32(result, 0, (int32_t) op);
26332657

2634-
result->op = GGML_OP_UNARY;
2658+
result->op = GGML_OP_GLU;
26352659
result->src[0] = a;
26362660

26372661
return result;
26382662
}
26392663

2640-
// ggml_geglu
2664+
// ggml_reglu
26412665

2642-
struct ggml_tensor * ggml_geglu(
2666+
struct ggml_tensor * ggml_reglu(
26432667
struct ggml_context * ctx,
26442668
struct ggml_tensor * a) {
2645-
GGML_ASSERT(ggml_is_contiguous_1(a));
2646-
2647-
int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
2648-
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
2649-
2650-
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_GEGLU);
2669+
return ggml_glu(ctx, a, GGML_GLU_OP_REGLU);
2670+
}
26512671

2652-
result->op = GGML_OP_UNARY;
2653-
result->src[0] = a;
2672+
// ggml_geglu
26542673

2655-
return result;
2674+
struct ggml_tensor * ggml_geglu(
2675+
struct ggml_context * ctx,
2676+
struct ggml_tensor * a) {
2677+
return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU);
26562678
}
26572679

26582680
// ggml_swiglu
26592681

26602682
struct ggml_tensor * ggml_swiglu(
26612683
struct ggml_context * ctx,
26622684
struct ggml_tensor * a) {
2663-
GGML_ASSERT(ggml_is_contiguous_1(a));
2664-
2665-
int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
2666-
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
2667-
2668-
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU);
2669-
2670-
result->op = GGML_OP_UNARY;
2671-
result->src[0] = a;
2672-
2673-
return result;
2685+
return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU);
26742686
}
26752687

26762688
// ggml_norm

0 commit comments

Comments (0)