@@ -17,12 +17,11 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-struct llama_memory_context_i;
+struct llama_memory_state_i;
 
-class llama_kv_cache_unified_context;
-class llama_kv_cache_unified_iswa_context;
-class llama_memory_recurrent_context;
-class llama_memory_hybrid_context;
+class llama_kv_cache_unified_state;
+class llama_kv_cache_unified_iswa_state;
+class llama_kv_cache_recurrent_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -38,7 +37,6 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
     LLM_FFN_GEGLU,
-    LLM_FFN_REGLU,
 };
 
 enum llm_ffn_gate_type {
@@ -96,14 +94,14 @@ class llm_graph_input_embd : public llm_graph_input_i {
 
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
+    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos = nullptr; // I32 [n_batch]
 
-    const uint32_t n_pos_per_embd = 1;
+    const int64_t n_pos_per_embd = 1;
 };
 
 // temperature tuning, used by llama4
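The `uint32_t` → `int64_t` change above matters because ggml tensor dimensions are `int64_t`. A minimal sketch of the size computation this protects (the helper name `make_pos_input` is hypothetical, not part of this patch):

```cpp
#include "ggml.h"

// Illustrative only: with n_pos_per_embd as int64_t, the size expression below
// stays in 64-bit arithmetic and matches the int64_t dimension parameter of
// ggml_new_tensor_1d, with no implicit narrowing.
static ggml_tensor * make_pos_input(ggml_context * ctx, int64_t n_tokens, int64_t n_pos_per_embd) {
    // I32 [n_batch * n_pos_per_embd] -- mirrors the "I32 [n_batch]" comment on `pos`
    return ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens * n_pos_per_embd);
}
```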
@@ -137,16 +135,15 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
+            const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
-
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -191,16 +188,28 @@ class llm_graph_input_cls : public llm_graph_input_i {
     const llama_cparams & cparams;
 };
 
-class llm_graph_input_rs : public llm_graph_input_i {
+class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
-    virtual ~llm_graph_input_rs() = default;
+    llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
+    virtual ~llm_graph_input_s_copy() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_memory_recurrent_context * mctx;
+    const llama_kv_cache_recurrent_state * kv_state;
+};
+
+class llm_graph_input_s_mask : public llm_graph_input_i {
+public:
+    llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
+    virtual ~llm_graph_input_s_mask() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_mask; // F32 [1, n_kv]
+
+    const llama_kv_cache_recurrent_state * kv_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
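For review context, a hedged sketch of how `set_input` for the two classes added above is typically implemented; the recurrent state is assumed to expose `get_n_kv()` and `s_copy(i)` accessors (the authoritative definitions live in `llama-graph.cpp` in this revision):

```cpp
void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch); // copy sources come from the cache state, not the batch

    const int64_t n_kv = kv_state->get_n_kv(); // assumed accessor

    if (s_copy) {
        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
        int32_t * data = (int32_t *) s_copy->data;

        // for each cell, record which cell its recurrent state is copied from
        for (int64_t i = 0; i < n_kv; ++i) {
            data[i] = kv_state->s_copy(i); // assumed accessor
        }
    }
}
```

`llm_graph_input_s_mask::set_input` follows the same pattern with `float` data, writing 1.0 for live states and 0.0 for states that must be cleared.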
@@ -240,10 +249,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_context * mctx) :
+            const llama_kv_cache_unified_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        mctx(mctx) {
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified() = default;
 
@@ -257,18 +266,18 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_context * mctx) :
+            const llama_kv_cache_unified_iswa_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        mctx(mctx) {
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;
 
@@ -285,7 +294,7 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_iswa_context * mctx;
+    const llama_kv_cache_unified_iswa_state * kv_state;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -303,44 +312,6 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
     const llama_cross * cross = nullptr;
 };
 
-class llm_graph_input_mem_hybrid : public llm_graph_input_i {
-public:
-    llm_graph_input_mem_hybrid(
-            const llama_hparams & hparams,
-            const llama_cparams & cparams,
-            const llama_memory_hybrid_context * mctx) :
-        hparams(hparams),
-        cparams(cparams),
-        mctx(mctx) {
-    }
-    virtual ~llm_graph_input_mem_hybrid() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * s_copy; // I32 [kv_size]
-
-    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
-
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
-
-    const llama_hparams & hparams;
-    const llama_cparams & cparams;
-
-    const llama_memory_hybrid_context * mctx;
-};
-
-// TODO: remove this when ggml_scale_add is implemented
-class llm_graph_input_one : public llm_graph_input_i {
-public:
-    llm_graph_input_one() {}
-    virtual ~llm_graph_input_one() = default;
-
-    void set_input(const llama_ubatch *) override;
-
-    ggml_tensor * one = nullptr; // F32
-};
-
 //
 // llm_graph_result
 //
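Per its own TODO, the deleted `llm_graph_input_one` existed only to feed a constant `1.0f` into the graph while ggml lacks a fused `ggml_scale_add`. A hedged sketch of the one thing it did at `set_input` time (the free-function form is illustrative):

```cpp
// Upload a scalar constant into a graph input tensor; this is the whole job
// the removed class performed for its `one` tensor.
static void set_constant_one(ggml_tensor * one) {
    GGML_ASSERT(one != nullptr);
    const float v = 1.0f;
    ggml_backend_tensor_set(one, &v, 0, sizeof(v));
}
```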
@@ -414,12 +385,12 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
-    const llama_adapter_cvec     * cvec;
-    const llama_adapter_loras    * loras;
-    const llama_memory_context_i * mctx;
-    const llama_cross            * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
-    uint32_t n_outputs;
+    int32_t n_outputs;
 
     const llm_graph_cb & cb;
 };
@@ -453,8 +424,8 @@ struct llm_graph_context {
     const float norm_eps;
     const float norm_rms_eps;
 
-    const int64_t n_tokens;
-    const int64_t n_outputs;
+    const int32_t n_tokens;
+    const int32_t n_outputs;
     const int32_t n_ctx_orig; // yarn
 
     const enum llama_pooling_type pooling_type;
@@ -466,17 +437,18 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    const llama_adapter_cvec     * cvec;
-    const llama_adapter_loras    * loras;
-    const llama_memory_context_i * mctx;
-    const llama_cross            * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
     const llm_graph_cb & cb_func;
 
     std::unique_ptr<llm_graph_result> res;
 
     llm_graph_context(const llm_graph_params & params);
-    virtual ~llm_graph_context() = default;
+
+    int64_t n_pos_per_embd() const;
 
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
@@ -548,14 +520,14 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
+    ggml_tensor * build_inp_s_copy() const;
+    ggml_tensor * build_inp_s_mask() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
 
-    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
-
     //
     // attention
     //
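A short usage sketch for the two builders added above (the surrounding build function is a placeholder from a typical recurrent model, not part of this diff):

```cpp
// inside a recurrent model's graph-build function (illustrative)
ggml_tensor * state_copy = build_inp_s_copy(); // I32 [kv_size]
ggml_tensor * state_mask = build_inp_s_mask(); // F32 [1, n_kv]
```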
@@ -602,15 +574,14 @@ struct llm_graph_context {
 
     llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
-    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
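Since this hunk changes the `build_attn` contract (k_cur/v_cur are no longer optional), here is a hedged call-site sketch; `inp_attn`, `Qcur`/`Kcur`/`Vcur`, and the `model.layers[il]` weights are placeholders from a typical transformer block, not from this patch:

```cpp
// typical call shape (illustrative fragment from a layer build function)
cur = build_attn(inp_attn, gf,
        model.layers[il].wo, model.layers[il].bo, // output projection + bias
        Qcur, Kcur, Vcur, // each [n_embd_head, n_head, n_tokens]; now mandatory
        nullptr,          // kq_b:  no relative-position bias for this model
        nullptr,          // v_mla: only set by MLA models
        1.0f/sqrtf(float(n_embd_head)), il);
```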
@@ -631,62 +602,23 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    ggml_tensor * build_attn(
-            llm_graph_input_mem_hybrid * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale,
-            int il) const;
     //
     // recurrent
     //
 
-    // TODO: avoid notion of "kv"
-    // TODO: move this implementation to llama_memory_recurrent.
-    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
-    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
-    //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
-    //       `llama_memory_recurrent`
-    ggml_tensor * build_rs(
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            ggml_tensor * state_copy,
-            int32_t  state_size,
-            int32_t  n_seqs,
-            uint32_t n_kv,
-            uint32_t kv_head,
-            uint32_t kv_size,
-            int32_t  rs_zero,
-            bool     avoid_copies = false) const;
-
-    llm_graph_input_rs * build_rs_inp() const;
-
-    ggml_tensor * build_rs(
-            llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            int32_t state_size,
-            int32_t n_seqs,
-            bool    avoid_copies = false) const;
-
-    ggml_tensor * build_rs(
-            llm_graph_input_mem_hybrid * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            int32_t state_size,
-            int32_t n_seqs,
-            bool    avoid_copies = false) const;
+    ggml_tensor * build_copy_mask_state(
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            int32_t n_state,
+            int32_t n_seqs) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
-            llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
-            const llama_ubatch & ubatch,
+            ggml_cgraph * gf,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
             int il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(
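Finally, a hedged sketch of how the restored `build_copy_mask_state` is meant to be called from a recurrent (e.g. Mamba-style) layer; `conv_states_all`, `n_seqs`, and `hparams.n_embd_k_s()` are assumptions borrowed from typical llama.cpp recurrent code, not part of this diff:

```cpp
// gather the previous recurrent state for each sequence in the ubatch, then
// mask (zero) the states of sequences that start fresh (illustrative fragment)
ggml_tensor * state_copy = build_inp_s_copy();
ggml_tensor * state_mask = build_inp_s_mask();

ggml_tensor * conv = build_copy_mask_state(
        gf, conv_states_all, state_copy, state_mask,
        hparams.n_embd_k_s(), // n_state: per-sequence conv state size (assumed accessor)
        n_seqs);
```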