
Commit 81cfb43

huydt84 authored and huydt-bti committed
add geglu activation function (ggml-org#14074)
Co-authored-by: dinhhuy <huy.dinh@brains-tech.co.jp>
1 parent 05322ab commit 81cfb43

File tree: src/llama-graph.cpp, src/llama-graph.h

2 files changed: +80, -126 lines


src/llama-graph.cpp

Lines changed: 22 additions & 0 deletions
@@ -613,6 +613,28 @@ ggml_tensor * llm_graph_context::build_ffn(
                 cur = ggml_reglu(ctx0, cur);
                 cb(cur, "ffn_reglu", il);
             } break;
+        case LLM_FFN_GEGLU:
+            {
+                // Split into two equal parts
+                int64_t split_point = cur->ne[0] / 2;
+                ggml_tensor * output_ffn_up = ggml_cont(ctx0, ggml_view_2d(
+                            ctx0, cur, split_point,
+                            cur->ne[1], cur->nb[1], 0
+                            ));
+                ggml_tensor * output_ffn_gate = ggml_cont(ctx0, ggml_view_2d(
+                            ctx0, cur, split_point,
+                            cur->ne[1], cur->nb[1],
+                            split_point * ggml_element_size(cur)
+                            ));
+
+                // Apply GELU activation function to the first part
+                output_ffn_up = ggml_gelu(ctx0, output_ffn_up);
+                cb(output_ffn_up, "ffn_gelu", il);
+
+                // Element-wise multiplication between the activated part and the gate part
+                cur = ggml_mul(ctx0, output_ffn_up, output_ffn_gate);
+                cb(cur, "ffn_geglu", il);
+            } break;
     }
 
     if (gate && type_gate == LLM_FFN_PAR) {
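For reference, the new GEGLU branch computes GELU(up) * gate: the ffn_up projection emits both halves in a single tensor, ggml_view_2d slices it at ne[0]/2 (the gate view starting at a byte offset of split_point * ggml_element_size(cur)), GELU is applied to the first half, and the result is multiplied element-wise with the second half. Below is a minimal scalar sketch of the same computation for one token; the helper names are hypothetical and not part of llama.cpp, and it uses the exact erf-based GELU whereas ggml_gelu uses a tanh approximation.

```cpp
#include <cmath>
#include <vector>

// exact (erf-based) GELU; ggml_gelu itself uses a tanh approximation
static float gelu_ref(float x) {
    return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}

// in:  2*d values for one token (first half = "up", second half = "gate")
// out: d values, GELU(up[i]) * gate[i], mirroring the split/activate/multiply above
static std::vector<float> geglu_ref(const std::vector<float> & in) {
    const size_t d = in.size() / 2;
    std::vector<float> out(d);
    for (size_t i = 0; i < d; ++i) {
        out[i] = gelu_ref(in[i]) * in[i + d];
    }
    return out;
}
```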

src/llama-graph.h

Lines changed: 58 additions & 126 deletions
@@ -17,12 +17,11 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-struct llama_memory_context_i;
+struct llama_memory_state_i;
 
-class llama_kv_cache_unified_context;
-class llama_kv_cache_unified_iswa_context;
-class llama_memory_recurrent_context;
-class llama_memory_hybrid_context;
+class llama_kv_cache_unified_state;
+class llama_kv_cache_unified_iswa_state;
+class llama_kv_cache_recurrent_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -38,7 +37,6 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
     LLM_FFN_GEGLU,
-    LLM_FFN_REGLU,
 };
 
 enum llm_ffn_gate_type {
@@ -96,14 +94,14 @@ class llm_graph_input_embd : public llm_graph_input_i {
 
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
+    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos = nullptr; // I32 [n_batch]
 
-    const uint32_t n_pos_per_embd = 1;
+    const int64_t n_pos_per_embd = 1;
 };
 
 // temperature tuning, used by llama4
@@ -137,16 +135,15 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
+            const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
-
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -191,16 +188,28 @@ class llm_graph_input_cls : public llm_graph_input_i {
     const llama_cparams & cparams;
 };
 
-class llm_graph_input_rs : public llm_graph_input_i {
+class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
-    virtual ~llm_graph_input_rs() = default;
+    llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
+    virtual ~llm_graph_input_s_copy() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_memory_recurrent_context * mctx;
+    const llama_kv_cache_recurrent_state * kv_state;
+};
+
+class llm_graph_input_s_mask : public llm_graph_input_i {
+public:
+    llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
+    virtual ~llm_graph_input_s_mask() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_mask; // F32 [1, n_kv]
+
+    const llama_kv_cache_recurrent_state * kv_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -240,10 +249,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_context * mctx) :
+            const llama_kv_cache_unified_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        mctx(mctx) {
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified() = default;
 
@@ -257,18 +266,18 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_context * mctx) :
+            const llama_kv_cache_unified_iswa_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        mctx(mctx) {
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;
 
@@ -285,7 +294,7 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_iswa_context * mctx;
+    const llama_kv_cache_unified_iswa_state * kv_state;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -303,44 +312,6 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
     const llama_cross * cross = nullptr;
 };
 
-class llm_graph_input_mem_hybrid : public llm_graph_input_i {
-public:
-    llm_graph_input_mem_hybrid(
-            const llama_hparams & hparams,
-            const llama_cparams & cparams,
-            const llama_memory_hybrid_context * mctx) :
-        hparams(hparams),
-        cparams(cparams),
-        mctx(mctx) {
-    }
-    virtual ~llm_graph_input_mem_hybrid() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * s_copy; // I32 [kv_size]
-
-    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
-
-    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
-
-    const llama_hparams & hparams;
-    const llama_cparams & cparams;
-
-    const llama_memory_hybrid_context * mctx;
-};
-
-// TODO: remove this when ggml_scale_add is implemented
-class llm_graph_input_one : public llm_graph_input_i {
-public:
-    llm_graph_input_one() {}
-    virtual ~llm_graph_input_one() = default;
-
-    void set_input(const llama_ubatch *) override;
-
-    ggml_tensor * one = nullptr; // F32
-};
-
 //
 // llm_graph_result
 //
@@ -414,12 +385,12 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
-    const llama_adapter_cvec     * cvec;
-    const llama_adapter_loras    * loras;
-    const llama_memory_context_i * mctx;
-    const llama_cross            * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
-    uint32_t n_outputs;
+    int32_t n_outputs;
 
     const llm_graph_cb & cb;
 };
@@ -453,8 +424,8 @@ struct llm_graph_context {
     const float norm_eps;
    const float norm_rms_eps;
 
-    const int64_t n_tokens;
-    const int64_t n_outputs;
+    const int32_t n_tokens;
+    const int32_t n_outputs;
     const int32_t n_ctx_orig; // yarn
 
     const enum llama_pooling_type pooling_type;
@@ -466,17 +437,18 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    const llama_adapter_cvec     * cvec;
-    const llama_adapter_loras    * loras;
-    const llama_memory_context_i * mctx;
-    const llama_cross            * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
     const llm_graph_cb & cb_func;
 
     std::unique_ptr<llm_graph_result> res;
 
     llm_graph_context(const llm_graph_params & params);
-    virtual ~llm_graph_context() = default;
+
+    int64_t n_pos_per_embd() const;
 
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
@@ -548,14 +520,14 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
+    ggml_tensor * build_inp_s_copy() const;
+    ggml_tensor * build_inp_s_mask() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
 
-    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
-
     //
     // attention
     //
@@ -602,15 +574,14 @@ struct llm_graph_context {
 
     llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
-    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
@@ -631,62 +602,23 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    ggml_tensor * build_attn(
-            llm_graph_input_mem_hybrid * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale,
-            int il) const;
-
     //
     // recurrent
     //
 
-    // TODO: avoid notion of "kv"
-    // TODO: move this implementation to llama_memory_recurrent.
-    // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
-    // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
-    // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
-    // `llama_memory_recurrent`
-    ggml_tensor * build_rs(
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            ggml_tensor * state_copy,
-            int32_t state_size,
-            int32_t n_seqs,
-            uint32_t n_kv,
-            uint32_t kv_head,
-            uint32_t kv_size,
-            int32_t rs_zero,
-            bool avoid_copies = false) const;
-
-    llm_graph_input_rs * build_rs_inp() const;
-
-    ggml_tensor * build_rs(
-            llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            int32_t state_size,
-            int32_t n_seqs,
-            bool avoid_copies = false) const;
-
-    ggml_tensor * build_rs(
-            llm_graph_input_mem_hybrid * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            int32_t state_size,
-            int32_t n_seqs,
-            bool avoid_copies = false) const;
+    ggml_tensor * build_copy_mask_state(
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            int32_t n_state,
+            int32_t n_seqs) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
-            llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
-            const llama_ubatch & ubatch,
+            ggml_cgraph * gf,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
             int il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(