Skip to content

Commit a341aa3

Browse files
CISCqnixsynapse
authored and committed
refactor into GGML_GLU_OP
1 parent f8c2080 commit a341aa3

File tree

7 files changed

+170
-53
lines changed

7 files changed

+170
-53
lines changed

ggml/include/ggml.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,8 @@ extern "C" {
519519
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
520520
GGML_OP_OPT_STEP_ADAMW,
521521

522+
GGML_OP_GLU,
523+
522524
GGML_OP_COUNT,
523525
};
524526

@@ -538,13 +540,18 @@ extern "C" {
538540
GGML_UNARY_OP_HARDSIGMOID,
539541
GGML_UNARY_OP_EXP,
540542
GGML_UNARY_OP_GELU_ERF,
541-
GGML_UNARY_OP_REGLU,
542-
GGML_UNARY_OP_GEGLU,
543-
GGML_UNARY_OP_SWIGLU,
544543

545544
GGML_UNARY_OP_COUNT,
546545
};
547546

547+
enum ggml_glu_op {
548+
GGML_GLU_OP_REGLU,
549+
GGML_GLU_OP_GEGLU,
550+
GGML_GLU_OP_SWIGLU,
551+
552+
GGML_GLU_OP_COUNT,
553+
};
554+
548555
enum ggml_object_type {
549556
GGML_OBJECT_TYPE_TENSOR,
550557
GGML_OBJECT_TYPE_GRAPH,
@@ -660,6 +667,7 @@ extern "C" {
660667
GGML_API const char * ggml_op_symbol(enum ggml_op op);
661668

662669
GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
670+
GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
663671
GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
664672

665673
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
@@ -761,6 +769,7 @@ extern "C" {
761769
GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
762770

763771
GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
772+
GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor);
764773

765774
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
766775
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
@@ -1089,6 +1098,14 @@ extern "C" {
10891098
struct ggml_context * ctx,
10901099
struct ggml_tensor * a);
10911100

1101+
// gated linear unit ops
1102+
// A: n columns, r rows,
1103+
// result is n / 2 columns, r rows,
1104+
GGML_API struct ggml_tensor * ggml_glu(
1105+
struct ggml_context * ctx,
1106+
struct ggml_tensor * a,
1107+
enum ggml_glu_op op);
1108+
10921109
GGML_API struct ggml_tensor * ggml_reglu(
10931110
struct ggml_context * ctx,
10941111
struct ggml_tensor * a);

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,6 +1941,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
19411941
{
19421942
ggml_compute_forward_unary(params, tensor);
19431943
} break;
1944+
case GGML_OP_GLU:
1945+
{
1946+
ggml_compute_forward_glu(params, tensor);
1947+
} break;
19441948
case GGML_OP_GET_REL_POS:
19451949
{
19461950
ggml_compute_forward_get_rel_pos(params, tensor);
@@ -2144,9 +2148,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
21442148
case GGML_UNARY_OP_GELU_ERF:
21452149
case GGML_UNARY_OP_GELU_QUICK:
21462150
case GGML_UNARY_OP_SILU:
2147-
case GGML_UNARY_OP_REGLU:
2148-
case GGML_UNARY_OP_GEGLU:
2149-
case GGML_UNARY_OP_SWIGLU:
2151+
{
2152+
n_tasks = n_threads;
2153+
} break;
2154+
default:
2155+
GGML_ABORT("fatal error");
2156+
}
2157+
break;
2158+
case GGML_OP_GLU:
2159+
switch (ggml_get_glu_op(node)) {
2160+
case GGML_GLU_OP_REGLU:
2161+
case GGML_GLU_OP_GEGLU:
2162+
case GGML_GLU_OP_SWIGLU:
21502163
{
21512164
n_tasks = n_threads;
21522165
} break;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8308,15 +8308,31 @@ void ggml_compute_forward_unary(
83088308
{
83098309
ggml_compute_forward_exp(params, dst);
83108310
} break;
8311-
case GGML_UNARY_OP_REGLU:
8311+
default:
8312+
{
8313+
GGML_ABORT("fatal error");
8314+
}
8315+
}
8316+
}
8317+
8318+
//ggml_compute_forward_glu
8319+
8320+
void ggml_compute_forward_glu(
8321+
const ggml_compute_params * params,
8322+
ggml_tensor * dst) {
8323+
8324+
const ggml_glu_op op = ggml_get_glu_op(dst);
8325+
8326+
switch (op) {
8327+
case GGML_GLU_OP_REGLU:
83128328
{
83138329
ggml_compute_forward_reglu(params, dst);
83148330
} break;
8315-
case GGML_UNARY_OP_GEGLU:
8331+
case GGML_GLU_OP_GEGLU:
83168332
{
83178333
ggml_compute_forward_geglu(params, dst);
83188334
} break;
8319-
case GGML_UNARY_OP_SWIGLU:
8335+
case GGML_GLU_OP_SWIGLU:
83208336
{
83218337
ggml_compute_forward_swiglu(params, dst);
83228338
} break;

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st
9393
void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
9494
void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
9595
void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
96+
void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
9697
void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
9798
void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
9899
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2246,13 +2246,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
22462246
case GGML_UNARY_OP_EXP:
22472247
ggml_cuda_op_exp(ctx, dst);
22482248
break;
2249-
case GGML_UNARY_OP_REGLU:
2249+
default:
2250+
return false;
2251+
}
2252+
break;
2253+
case GGML_OP_GLU:
2254+
switch (ggml_get_glu_op(dst)) {
2255+
case GGML_GLU_OP_REGLU:
22502256
ggml_cuda_op_reglu(ctx, dst);
22512257
break;
2252-
case GGML_UNARY_OP_GEGLU:
2258+
case GGML_GLU_OP_GEGLU:
22532259
ggml_cuda_op_geglu(ctx, dst);
22542260
break;
2255-
case GGML_UNARY_OP_SWIGLU:
2261+
case GGML_GLU_OP_SWIGLU:
22562262
ggml_cuda_op_swiglu(ctx, dst);
22572263
break;
22582264
default:
@@ -3048,9 +3054,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30483054
case GGML_UNARY_OP_TANH:
30493055
case GGML_UNARY_OP_EXP:
30503056
return ggml_is_contiguous(op->src[0]);
3051-
case GGML_UNARY_OP_REGLU:
3052-
case GGML_UNARY_OP_GEGLU:
3053-
case GGML_UNARY_OP_SWIGLU:
3057+
default:
3058+
return false;
3059+
}
3060+
break;
3061+
case GGML_OP_GLU:
3062+
switch (ggml_get_glu_op(op)) {
3063+
case GGML_GLU_OP_REGLU:
3064+
case GGML_GLU_OP_GEGLU:
3065+
case GGML_GLU_OP_SWIGLU:
30543066
return ggml_is_contiguous_1(op->src[0]);
30553067
default:
30563068
return false;

ggml/src/ggml.c

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
984984
"CROSS_ENTROPY_LOSS",
985985
"CROSS_ENTROPY_LOSS_BACK",
986986
"OPT_STEP_ADAMW",
987+
988+
"GLU",
987989
};
988990

989991
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
@@ -1080,6 +1082,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10801082
"cross_entropy_loss(x,y)",
10811083
"cross_entropy_loss_back(x,y)",
10821084
"adamw(x)",
1085+
1086+
"glu(x)",
10831087
};
10841088

10851089
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
@@ -1103,12 +1107,18 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
11031107
"HARDSIGMOID",
11041108
"EXP",
11051109
"GELU_ERF",
1110+
};
1111+
1112+
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
1113+
1114+
1115+
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
11061116
"REGLU",
11071117
"GEGLU",
11081118
"SWIGLU",
11091119
};
11101120

1111-
static_assert(GGML_UNARY_OP_COUNT == 18, "GGML_UNARY_OP_COUNT != 18");
1121+
static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
11121122

11131123

11141124
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -1213,11 +1223,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
12131223
return GGML_UNARY_OP_NAME[op];
12141224
}
12151225

1226+
const char * ggml_glu_op_name(enum ggml_glu_op op) {
1227+
return GGML_GLU_OP_NAME[op];
1228+
}
1229+
12161230
const char * ggml_op_desc(const struct ggml_tensor * t) {
12171231
if (t->op == GGML_OP_UNARY) {
12181232
enum ggml_unary_op uop = ggml_get_unary_op(t);
12191233
return ggml_unary_op_name(uop);
12201234
}
1235+
if (t->op == GGML_OP_GLU) {
1236+
enum ggml_glu_op gop = ggml_get_glu_op(t);
1237+
return ggml_glu_op_name(gop);
1238+
}
12211239
return ggml_op_name(t->op);
12221240
}
12231241

@@ -1736,6 +1754,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
17361754
return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
17371755
}
17381756

1757+
enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
1758+
GGML_ASSERT(tensor->op == GGML_OP_GLU);
1759+
return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
1760+
}
1761+
17391762
const char * ggml_get_name(const struct ggml_tensor * tensor) {
17401763
return tensor->name;
17411764
}
@@ -2615,58 +2638,47 @@ struct ggml_tensor * ggml_exp_inplace(
26152638
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
26162639
}
26172640

2618-
// ggml_reglu
2641+
// ggml_glu
26192642

2620-
struct ggml_tensor * ggml_reglu(
2643+
struct ggml_tensor * ggml_glu(
26212644
struct ggml_context * ctx,
2622-
struct ggml_tensor * a) {
2645+
struct ggml_tensor * a,
2646+
enum ggml_glu_op op) {
26232647
GGML_ASSERT(ggml_is_contiguous_1(a));
26242648

26252649
int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
26262650
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
26272651

2628-
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_REGLU);
2652+
ggml_set_op_params_i32(result, 0, (int32_t) op);
26292653

2630-
result->op = GGML_OP_UNARY;
2654+
result->op = GGML_OP_GLU;
26312655
result->src[0] = a;
26322656

26332657
return result;
26342658
}
26352659

2636-
// ggml_geglu
2660+
// ggml_reglu
26372661

2638-
struct ggml_tensor * ggml_geglu(
2662+
struct ggml_tensor * ggml_reglu(
26392663
struct ggml_context * ctx,
26402664
struct ggml_tensor * a) {
2641-
GGML_ASSERT(ggml_is_contiguous_1(a));
2642-
2643-
int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
2644-
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
2645-
2646-
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_GEGLU);
2665+
return ggml_glu(ctx, a, GGML_GLU_OP_REGLU);
2666+
}
26472667

2648-
result->op = GGML_OP_UNARY;
2649-
result->src[0] = a;
2668+
// ggml_geglu
26502669

2651-
return result;
2670+
struct ggml_tensor * ggml_geglu(
2671+
struct ggml_context * ctx,
2672+
struct ggml_tensor * a) {
2673+
return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU);
26522674
}
26532675

26542676
// ggml_swiglu
26552677

26562678
struct ggml_tensor * ggml_swiglu(
26572679
struct ggml_context * ctx,
26582680
struct ggml_tensor * a) {
2659-
GGML_ASSERT(ggml_is_contiguous_1(a));
2660-
2661-
int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
2662-
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
2663-
2664-
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU);
2665-
2666-
result->op = GGML_OP_UNARY;
2667-
result->src[0] = a;
2668-
2669-
return result;
2681+
return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU);
26702682
}
26712683

26722684
// ggml_norm

0 commit comments

Comments (0)