@@ -171,7 +171,7 @@ struct llama_context : public llama_graph_i {
     // graph
     //

-    // zero-out inputs and create the ctx_context for the compute graph
+    // zero-out inputs and create the ctx_compute for the compute graph
     virtual ggml_cgraph * graph_init();

     // TODO: add encode/decode graphs
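
Note: for orientation, here is a minimal sketch of what the corrected comment describes `graph_init()` doing: drop the previous graph's input tensors, then (re)create `ctx_compute` so the next graph is built into fresh metadata. The scratch-buffer name `buf_compute_meta` and the node limit are assumptions, not the verbatim implementation.

// sketch only: names other than ctx_compute/graph_init are assumptions
ggml_cgraph * llama_context::graph_init() {
    // "zero-out inputs": the input tensors of the previous graph would be reset here

    ggml_init_params params = {
        /*.mem_size   =*/ buf_compute_meta.size(), // assumed preallocated metadata buffer
        /*.mem_buffer =*/ buf_compute_meta.data(),
        /*.no_alloc   =*/ true, // tensors carry metadata only; no data is allocated here
    };

    ctx_compute.reset(ggml_init(params));

    const size_t max_nodes = 8192; // hypothetical node limit

    return ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
}
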
@@ -187,73 +187,74 @@ struct llama_context : public llama_graph_i {

     ggml_context_ptr ctx_compute;

+public:
     //
-    // graph build API (generic)
+    // graph build
     //

     virtual void build_cb(
             ggml_tensor * cur,
             const char * name,
             const llama_ubatch & ubatch,
-            int il);
+            int il) override;

     // apply control vector for layer il
     virtual ggml_tensor * build_cvec(
             ggml_context * ctx0,
             ggml_tensor * cur,
-            int il);
+            int il) override;
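
Note: the control-vector hook is the simplest of these builders; conceptually it just offsets the hidden state of layer `il` by a learned direction. A sketch of the assumed semantics (the per-layer accessor is hypothetical):

// inside a hypothetical build_cvec(ctx0, cur, il):
ggml_tensor * layer_dir = cvec.tensor_for(il); // hypothetical per-layer lookup
if (layer_dir != nullptr) {
    cur = ggml_add(ctx0, cur, layer_dir); // steer the hidden state of layer il
}
return cur;
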

     // do mat_mul, while optionally apply lora
     virtual ggml_tensor * build_lora_mm(
             ggml_context * ctx0,
             ggml_tensor * w,
-            ggml_tensor * cur);
+            ggml_tensor * cur) override;
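
Note: as the comment says, build_lora_mm is a mat_mul that optionally folds in active LoRA adapters. The usual formulation is the base product plus scale * B(A·cur) per adapter; a sketch under that assumption (the `adapters` container and its fields are stand-ins):

// inside a hypothetical build_lora_mm(ctx0, w, cur):
ggml_tensor * res = ggml_mul_mat(ctx0, w, cur); // base product: w @ cur

for (const auto & ad : adapters) { // hypothetical list of adapters active for w
    // low-rank update B @ (A @ cur), scaled by the adapter strength
    ggml_tensor * ab = ggml_mul_mat(ctx0, ad.b, ggml_mul_mat(ctx0, ad.a, cur));
    res = ggml_add(ctx0, res, ggml_scale(ctx0, ab, ad.scale));
}
return res;
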

     // do mat_mul_id, while optionally apply lora
     virtual ggml_tensor * build_lora_mm_id(
             ggml_context * ctx0,
             ggml_tensor * w,   // struct ggml_tensor * as
             ggml_tensor * cur, // struct ggml_tensor * b
-            ggml_tensor * ids);
+            ggml_tensor * ids) override;

-    virtual ggml_tensor * build_rope_factors(int il);
+    virtual ggml_tensor * build_rope_factors(int il) override;

     virtual ggml_tensor * build_rope_shift(
             ggml_context * ctx0,
             ggml_tensor * cur,
             ggml_tensor * shift,
             ggml_tensor * factors,
-            ggml_backend_buffer * bbuf);
+            ggml_backend_buffer * bbuf) override;
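
Note: build_rope_shift exists so cached K data can be moved without recomputing it: the cached slice is re-rotated by per-cell position deltas. A conceptual sketch, assuming the RoPE hyperparameters are taken from the model's hparams:

// inside a hypothetical build_rope_shift(ctx0, cur, shift, factors, bbuf):
// re-apply RoPE to the cached keys with the per-cell deltas in `shift`
ggml_tensor * k_shifted = ggml_rope_ext(ctx0, cur, shift, factors,
        n_rot, rope_type, n_ctx_orig,   // assumed to come from hparams
        freq_base, freq_scale, ext_factor,
        attn_factor, beta_fast, beta_slow);
return k_shifted;
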

     virtual ggml_tensor * build_inp_embd(
             ggml_context * ctx0,
             ggml_tensor * tok_embd,
-            const llama_ubatch & ubatch);
+            const llama_ubatch & ubatch) override;

     virtual ggml_tensor * build_inp_pos(
             ggml_context * ctx0,
-            int32_t n_tokens);
+            int32_t n_tokens) override;

     virtual ggml_tensor * build_inp_pos_bucket(
             ggml_context * ctx0,
-            int32_t n_tokens);
+            int32_t n_tokens) override;

     virtual ggml_tensor * build_inp_out_ids(
-            ggml_context * ctx0);
+            ggml_context * ctx0) override;

     virtual ggml_tensor * build_inp_mean(
             ggml_context * ctx0,
-            int32_t n_tokens);
+            int32_t n_tokens) override;

     virtual ggml_tensor * build_inp_cls(
             ggml_context * ctx0,
-            int32_t n_tokens);
+            int32_t n_tokens) override;

     virtual void build_attn_inp(
             ggml_context * ctx0,
             int32_t n_tokens,
             bool causal,
-            bool swa);
+            bool swa) override;

     virtual ggml_tensor * build_attn(
             ggml_context * ctx0,
@@ -266,7 +267,17 @@ struct llama_context : public llama_graph_i {
             ggml_tensor * kq_b,
             int32_t n_tokens,
             float kq_scale,
-            int il);
+            int il) override;
+
+protected:
+    virtual void build_kv_self_shift(
+            ggml_context * ctx0,
+            ggml_cgraph * gf);
+
+    // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
+    virtual void build_kv_self_defrag(
+            ggml_context * ctx0,
+            ggml_cgraph * gf);

 public:
     //
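
Note: the comment on build_kv_self_defrag describes a two-pointer compaction: holes are located from the front of the cache and filled with live cells taken from the back. A self-contained sketch of just that planning step (`kv_cell` and `plan_defrag` are illustrative stand-ins, not the actual types):

#include <cstdint>
#include <utility>
#include <vector>

struct kv_cell { bool used; }; // stand-in for a KV cache slot

// returns (src, dst) moves that compact all used cells to the front
std::vector<std::pair<uint32_t, uint32_t>> plan_defrag(const std::vector<kv_cell> & cells) {
    std::vector<std::pair<uint32_t, uint32_t>> moves;
    uint32_t lo = 0;
    uint32_t hi = (uint32_t) cells.size();
    while (true) {
        while (lo < hi &&  cells[lo].used)     { ++lo; } // first hole from the front
        while (hi > lo && !cells[hi - 1].used) { --hi; } // last used cell from the back
        if (lo >= hi) {
            break; // no hole left before the last used cell
        }
        moves.push_back({hi - 1, lo}); // move the tail cell into the hole
        ++lo;
        --hi;
    }
    return moves;
}
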
@@ -434,6 +445,7 @@ class llama_context_kv_self : public llama_context {
     virtual ggml_cgraph * graph_init() override;

+public:
     //
     // graph build
     //
@@ -463,6 +475,7 @@ class llama_context_kv_self : public llama_context {
             float kq_scale,
             int il) override;

+protected:
     virtual void build_kv_self_shift(
             ggml_context * ctx0,
             ggml_cgraph * gf) override;
@@ -548,6 +561,7 @@ class llama_context_recurrent : public llama_context {
     virtual ggml_cgraph * graph_init() override;

+public:
     //
     // graph build
     //
@@ -600,6 +614,7 @@ class llama_context_recurrent : public llama_context {
             const llama_ubatch & ubatch,
             int il) override;

+protected:
     //
     // state save/load
     //