@@ -140,6 +140,7 @@ struct llama_context : public llama_graph_i {
 
     virtual void input_set(const llama_ubatch & ubatch);
 
+private:
     struct {
         // base input tensors
         ggml_tensor * tokens; // I32 [n_batch]
@@ -155,6 +156,7 @@ struct llama_context : public llama_graph_i {
         ggml_tensor * kq_mask_cnv; // [n_tokens, n_batch]
     } inp;
 
+protected:
     //
     // output
     //
@@ -192,71 +194,71 @@ struct llama_context : public llama_graph_i {
     // graph build
     //
 
-    virtual void build_cb(
+    void build_cb(
             ggml_tensor * cur,
             const char * name,
             const llama_ubatch & ubatch,
             int il) override;
 
     // apply control vector for layer il
-    virtual ggml_tensor * build_cvec(
+    ggml_tensor * build_cvec(
             ggml_context * ctx0,
             ggml_tensor * cur,
             int il) override;
 
     // do mat_mul, while optionally apply lora
-    virtual ggml_tensor * build_lora_mm(
+    ggml_tensor * build_lora_mm(
             ggml_context * ctx0,
             ggml_tensor * w,
             ggml_tensor * cur) override;
 
     // do mat_mul_id, while optionally apply lora
-    virtual ggml_tensor * build_lora_mm_id(
+    ggml_tensor * build_lora_mm_id(
             ggml_context * ctx0,
             ggml_tensor * w,   // struct ggml_tensor * as
             ggml_tensor * cur, // struct ggml_tensor * b
             ggml_tensor * ids) override;
 
-    virtual ggml_tensor * build_rope_factors(int il) override;
+    ggml_tensor * build_rope_factors(int il) override;
 
-    virtual ggml_tensor * build_rope_shift(
+    ggml_tensor * build_rope_shift(
            ggml_context * ctx0,
            ggml_tensor * cur,
            ggml_tensor * shift,
            ggml_tensor * factors,
            ggml_backend_buffer * bbuf) override;
 
-    virtual ggml_tensor * build_inp_embd(
+    ggml_tensor * build_inp_embd(
            ggml_context * ctx0,
            ggml_tensor * tok_embd,
            const llama_ubatch & ubatch) override;
 
-    virtual ggml_tensor * build_inp_pos(
+    ggml_tensor * build_inp_pos(
            ggml_context * ctx0,
            int32_t n_tokens) override;
 
-    virtual ggml_tensor * build_inp_pos_bucket(
+    ggml_tensor * build_inp_pos_bucket(
            ggml_context * ctx0,
            int32_t n_tokens) override;
 
-    virtual ggml_tensor * build_inp_out_ids(
+    ggml_tensor * build_inp_out_ids(
            ggml_context * ctx0) override;
 
-    virtual ggml_tensor * build_inp_mean(
+    ggml_tensor * build_inp_mean(
            ggml_context * ctx0,
            int32_t n_tokens) override;
 
-    virtual ggml_tensor * build_inp_cls(
+    ggml_tensor * build_inp_cls(
            ggml_context * ctx0,
            int32_t n_tokens) override;
 
-    virtual void build_attn_inp(
+    void build_attn_inp(
            ggml_context * ctx0,
            int32_t n_tokens,
            bool causal,
            bool swa) override;
 
-    virtual ggml_tensor * build_attn(
+    ggml_tensor * build_attn(
            ggml_context * ctx0,
            ggml_cgraph * gf,
            ggml_tensor * wo,
@@ -270,6 +272,9 @@ struct llama_context : public llama_graph_i {
            int il) override;
 
 protected:
+    virtual ggml_tensor * build_inp_self_k_shift(
+           ggml_context * ctx0);
+
     virtual void build_kv_self_shift(
            ggml_context * ctx0,
            ggml_cgraph * gf);
@@ -288,6 +293,7 @@ struct llama_context : public llama_graph_i {
     virtual void perf_reset();
 
 protected:
+    // TODO: become private
     mutable int64_t t_start_us = 0;
     mutable int64_t t_load_us = 0;
     mutable int64_t t_p_eval_us = 0;
@@ -346,6 +352,7 @@ struct llama_context : public llama_graph_i {
     //
     // members
     //
+    // TODO: become private / move to llama_graph_i
 
     const llama_model & model;
 
@@ -412,24 +419,25 @@ class llama_context_kv_self : public llama_context {
     virtual ~llama_context_kv_self();
 
 protected:
-    virtual void reserve() override;
+    void reserve() override;
 
 public:
-    virtual llama_kv_cache * get_kv_self() override;
-    virtual const llama_kv_cache * get_kv_self() const override;
+    llama_kv_cache * get_kv_self() override;
+    const llama_kv_cache * get_kv_self() const override;
 
-    virtual void kv_self_update() override;
+    void kv_self_update() override;
 
-    virtual int encode(llama_batch & inp_batch) override;
-    virtual int decode(llama_batch & inp_batch) override;
+    int encode(llama_batch & inp_batch) override;
+    int decode(llama_batch & inp_batch) override;
 
 protected:
     //
     // input
     //
 
-    virtual void input_set(const llama_ubatch & ubatch) override;
+    void input_set(const llama_ubatch & ubatch) override;
 
+private:
     struct {
         ggml_tensor * self_pos_bucket; // I32 [n_kv, n_batch]
         ggml_tensor * self_kq_mask;    // F32 [n_kv, n_batch]
@@ -443,26 +451,24 @@ class llama_context_kv_self : public llama_context {
     // graph
     //
 
-    virtual ggml_cgraph * graph_init() override;
+    ggml_cgraph * graph_init() override;
 
 public:
     //
     // graph build
     //
 
-    virtual ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override;
-
-    virtual ggml_tensor * build_inp_pos_bucket(
+    ggml_tensor * build_inp_pos_bucket(
            ggml_context * ctx0,
            int32_t n_tokens) override;
 
-    virtual void build_attn_inp(
+    void build_attn_inp(
            ggml_context * ctx0,
            int32_t n_tokens,
            bool causal,
            bool swa) override;
 
-    virtual ggml_tensor * build_attn(
+    ggml_tensor * build_attn(
            ggml_context * ctx0,
            ggml_cgraph * gf,
            ggml_tensor * wo,
@@ -476,16 +482,22 @@ class llama_context_kv_self : public llama_context {
            int il) override;
 
 protected:
-    virtual void build_kv_self_shift(
+    ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override;
+
+    void build_kv_self_shift(
            ggml_context * ctx0,
            ggml_cgraph * gf) override;
 
     // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
-    virtual void build_kv_self_defrag(
+    void build_kv_self_defrag(
            ggml_context * ctx0,
            ggml_cgraph * gf) override;
 
+    // =======================================================
     // === encoder-decoder ===
+    //
+    // TODO: this is temporary here, it will be moved
+    //
 
     // whether we are computing encoder output or decoder output
     bool is_encoding = false;
@@ -497,23 +509,25 @@ class llama_context_kv_self : public llama_context {
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch]
 
-    virtual ggml_tensor * build_inp_embd_enc(
+    ggml_tensor * build_inp_embd_enc(
            ggml_context * ctx0) override;
 
-    virtual ggml_tensor * build_inp_kq_mask_cross(
+    ggml_tensor * build_inp_kq_mask_cross(
            ggml_context * ctx0,
            int32_t n_tokens) override;
+    // ======================================================
 
     //
     // state save/load
     //
 
-    virtual size_t state_get_data(llama_io_write_i & io) override;
-    virtual size_t state_set_data(llama_io_read_i & io) override;
+    size_t state_get_data(llama_io_write_i & io) override;
+    size_t state_set_data(llama_io_read_i & io) override;
 
-    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
-    virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
+    size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
+    size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
 
+private:
     //
     // members
     //
@@ -532,24 +546,25 @@ class llama_context_recurrent : public llama_context {
     virtual ~llama_context_recurrent();
 
 protected:
-    virtual void reserve() override;
+    void reserve() override;
 
 public:
-    virtual llama_kv_cache * get_kv_self() override;
-    virtual const llama_kv_cache * get_kv_self() const override;
+    llama_kv_cache * get_kv_self() override;
+    const llama_kv_cache * get_kv_self() const override;
 
-    virtual void kv_self_update() override;
+    void kv_self_update() override;
 
-    virtual int encode(llama_batch & inp_batch) override;
-    virtual int decode(llama_batch & inp_batch) override;
+    int encode(llama_batch & inp_batch) override;
+    int decode(llama_batch & inp_batch) override;
 
 protected:
     //
     // input
     //
 
-    virtual void input_set(const llama_ubatch & ubatch) override;
+    void input_set(const llama_ubatch & ubatch) override;
 
+private:
     struct {
         ggml_tensor * s_copy; // I32 [kv_size]
         ggml_tensor * s_mask; // F32 [1, n_kv]
@@ -559,20 +574,20 @@ class llama_context_recurrent : public llama_context {
     // graph
     //
 
-    virtual ggml_cgraph * graph_init() override;
+    ggml_cgraph * graph_init() override;
 
 public:
     //
     // graph build
     //
 
-    virtual ggml_tensor * build_inp_s_copy(
+    ggml_tensor * build_inp_s_copy(
            ggml_context * ctx0) override;
 
-    virtual ggml_tensor * build_inp_s_mask(
+    ggml_tensor * build_inp_s_mask(
            ggml_context * ctx0) override;
 
-    virtual ggml_tensor * build_copy_mask_state(
+    ggml_tensor * build_copy_mask_state(
            ggml_context * ctx0,
            ggml_cgraph * gf,
            ggml_tensor * s,
@@ -581,7 +596,7 @@ class llama_context_recurrent : public llama_context {
            int32_t n_state,
            int32_t n_seqs) override;
 
-    virtual ggml_tensor * build_mamba_layer(
+    ggml_tensor * build_mamba_layer(
            ggml_context * ctx0,
            ggml_cgraph * gf,
            ggml_tensor * cur,
@@ -590,21 +605,21 @@ class llama_context_recurrent : public llama_context {
            const llama_ubatch & ubatch,
            int il) override;
 
-    virtual ggml_tensor * build_rwkv_token_shift_load(
+    ggml_tensor * build_rwkv_token_shift_load(
            ggml_context * ctx0,
            ggml_cgraph * gf,
            ggml_tensor * state_copy,
            ggml_tensor * state_mask,
            const llama_ubatch & ubatch,
            int il) override;
 
-    virtual ggml_tensor * build_rwkv_token_shift_store(
+    ggml_tensor * build_rwkv_token_shift_store(
            ggml_context * ctx0,
            ggml_tensor * token_shift,
            const llama_ubatch & ubatch,
            int il) override;
 
-    virtual ggml_tensor * build_rwkv6_time_mix(
+    ggml_tensor * build_rwkv6_time_mix(
            ggml_context * ctx0,
            ggml_cgraph * gf,
            ggml_tensor * cur,
@@ -619,12 +634,13 @@ class llama_context_recurrent : public llama_context {
     // state save/load
     //
 
-    virtual size_t state_get_data(llama_io_write_i & io) override;
-    virtual size_t state_set_data(llama_io_read_i & io) override;
+    size_t state_get_data(llama_io_write_i & io) override;
+    size_t state_set_data(llama_io_read_i & io) override;
 
-    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
-    virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
+    size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
+    size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
 
+private:
     //
     // members
     //
@@ -646,7 +662,7 @@ class llama_context_enc_dec : public llama_context {
 
     virtual ~llama_context_enc_dec();
 
-protected:
+private:
     llama_context_kv_self ctx_dec;
 };
 
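A note on the main pattern this diff applies throughout: in C++11 and later, a member function marked override is implicitly virtual, so the extra virtual keyword on the old derived-class declarations was redundant and the patch drops it, keeping virtual only where a function is first introduced (such as the new protected build_inp_self_k_shift hook on the base class). Below is a minimal, self-contained sketch of the idiom; the class names graph_i and context are illustrative stand-ins, not types from this patch:

    // devirt.cpp - hypothetical example of 'override' without 'virtual'
    #include <cstdio>

    struct graph_i {
        virtual ~graph_i() = default;
        // 'virtual' belongs here, where the function is introduced
        virtual void build_cb(int il) = 0;
    };

    struct context : graph_i {
        // 'override' alone keeps the function virtual, and the compiler
        // rejects the declaration if it ever stops matching the base
        void build_cb(int il) override {
            std::printf("build_cb for layer %d\n", il);
        }
    };

    int main() {
        context ctx;
        graph_i & g = ctx;
        g.build_cb(0); // virtual dispatch to context::build_cb
    }

The diff also tightens encapsulation along the same lines: the per-context input tensor structs and the state members move behind private:, and build_inp_self_k_shift becomes a protected virtual on the base that llama_context_kv_self overrides in its own protected section, so subclass-only hooks no longer sit on the public surface.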