
Commit 8f974bc

graph : refactor context to not pass gf explicitly (#14629)
ggml-ci
1 parent 09651d0 commit 8f974bc
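
At a glance, this refactor stops threading the ggml_cgraph * gf pointer through every graph-building helper (build_attn, build_rs, build_pooling, ...). Instead, llm_graph_context caches gf as a member, taken from the llm_graph_result it builds into, and the llm_graph_result_i interface is dropped in favor of the concrete llm_graph_result class. The sketch below is illustrative only, using hypothetical stand-in types rather than the llama.cpp API, to show the parameter-to-member pattern the commit applies.

// Illustrative stand-in types only, not the llama.cpp API.
#include <cstdio>

struct graph  { int n_nodes = 0; };                        // stands in for ggml_cgraph
struct result { graph g; graph * get_gf() { return &g; } };

// before: every helper had to be handed the graph explicitly
struct builder_before {
    int add_node(graph * gf) const { return ++gf->n_nodes; }
};

// after: the builder caches the graph once, taken from the result it fills
struct builder_after {
    graph * gf;
    explicit builder_after(result & res) : gf(res.get_gf()) {}
    int add_node() const { return ++gf->n_nodes; }         // call sites no longer pass gf
};

int main() {
    result res;
    builder_after b(res);
    b.add_node();
    b.add_node();
    std::printf("nodes: %d\n", res.g.n_nodes);             // prints 2
    return 0;
}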

5 files changed (+295, -341 lines)

src/llama-context.cpp

Lines changed: 2 additions & 2 deletions
@@ -694,7 +694,7 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;

@@ -1363,7 +1363,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_params llama_context::graph_params(
-        llm_graph_result_i * res,
+        llm_graph_result * res,
         const llama_ubatch & ubatch,
         const llama_memory_context_i * mctx,
         llm_graph_type gtype) const {
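
Since process_ubatch now returns the concrete llm_graph_result, callers inside llama_context can reach its accessors (get_logits, get_embd, ...) without going through a virtual interface. A minimal, hypothetical sketch of that caller-side pattern, with stand-in types rather than the real llama.cpp structs:

#include <cstdio>

struct tensor { float value = 0.0f; };                     // stands in for ggml_tensor
enum status { STATUS_SUCCESS, STATUS_FAILED };             // stands in for ggml_status

struct graph_result {                                      // concrete result, no interface
    tensor logits;
    tensor * get_logits() { return &logits; }
};

// mirrors the shape of process_ubatch: fills the result and reports a status
graph_result * process(graph_result & storage, status & ret) {
    ret = STATUS_SUCCESS;
    storage.logits.value = 1.5f;
    return &storage;
}

int main() {
    graph_result storage;
    status ret = STATUS_FAILED;
    graph_result * res = process(storage, ret);
    if (ret != STATUS_SUCCESS || res == nullptr) {
        return 1;                                          // computation failed
    }
    std::printf("logits[0] = %.1f\n", res->get_logits()->value);
    return 0;
}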

src/llama-context.h

Lines changed: 2 additions & 2 deletions
@@ -94,7 +94,7 @@ struct llama_context {
     // if memory_context is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
-    llm_graph_result_i * process_ubatch(
+    llm_graph_result * process_ubatch(
             const llama_ubatch & ubatch,
             llm_graph_type gtype,
             llama_memory_context_i * mctx,

@@ -199,7 +199,7 @@ struct llama_context {
 
 private:
     llm_graph_params graph_params(
-            llm_graph_result_i * res,
+            llm_graph_result * res,
             const llama_ubatch & ubatch,
             const llama_memory_context_i * mctx,
             llm_graph_type gtype) const;
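
The header comment spells out the error contract that survives the return-type change: the returned pointer is null only when the status out-parameter reports a failure. A hedged sketch of that contract with stand-in types (not ggml types):

#include <cassert>

enum status { STATUS_SUCCESS, STATUS_FAILED };

struct result {};

result * process_ubatch_like(bool ok, result & storage, status & ret) {
    if (!ok) {
        ret = STATUS_FAILED;
        return nullptr;            // nullptr only together with a failed status
    }
    ret = STATUS_SUCCESS;
    return &storage;
}

int main() {
    result storage;
    status ret = STATUS_FAILED;

    result * res = process_ubatch_like(true, storage, ret);
    assert(ret == STATUS_SUCCESS && res != nullptr);

    res = process_ubatch_like(false, storage, ret);
    assert(ret == STATUS_FAILED && res == nullptr);

    return 0;
}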

src/llama-graph.cpp

Lines changed: 15 additions & 19 deletions
@@ -486,6 +486,10 @@ llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
     return inputs.back().get();
 }
 
+void llm_graph_result::set_params(const llm_graph_params & params) {
+    this->params = params;
+}
+
 //
 // llm_graph_context
 //

@@ -527,9 +531,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     mctx (params.mctx),
     cross (params.cross),
     cb_func (params.cb),
-    res (static_cast<llm_graph_result *>(params.res)),
-    ctx0 (res->get_ctx()) {
-    res->params = params;
+    res (params.res),
+    ctx0 (res->get_ctx()),
+    gf (res->get_gf()) {
+    res->set_params(params);
 }
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
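
The constructor change is where the new member comes from: res is stored as the concrete type without a static_cast, ctx0 and gf are both pulled from the result in the initializer list, and the parameters are recorded through the new set_params() setter instead of assigning the (now private) field directly. A minimal sketch of that initializer-list pattern, with hypothetical stand-in types rather than the llama.cpp classes:

// Stand-in types; this mirrors only the shape of the llm_graph_context constructor.
struct context_t {};                       // stands in for ggml_context
struct cgraph_t  {};                       // stands in for ggml_cgraph

struct result_t {
    context_t ctx;
    cgraph_t  graph;

    context_t * get_ctx() { return &ctx; }
    cgraph_t  * get_gf()  { return &graph; }

    void set_params(int p) { params = p; }
    int  get_params() const { return params; }

private:
    int params = 0;                        // mirrors the params field moving behind private:
};

struct params_t {
    result_t * res;                        // concrete pointer: no interface, no cast
    int        value;
};

struct graph_context_t {
    result_t  * res;
    context_t * ctx0;
    cgraph_t  * gf;

    explicit graph_context_t(const params_t & params) :
        res (params.res),                  // was: static_cast<llm_graph_result *>(params.res)
        ctx0(res->get_ctx()),
        gf  (res->get_gf()) {              // new: the graph is cached once here
        res->set_params(params.value);     // was: res->params = params
    }
};

int main() {
    result_t r;
    params_t p{ &r, 7 };
    graph_context_t gc(p);
    return (gc.gf == r.get_gf() && r.get_params() == 7) ? 0 : 1;
}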
@@ -1119,7 +1124,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 ggml_tensor * llm_graph_context::build_attn_mha(
-        ggml_cgraph * gf,
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,

@@ -1253,7 +1257,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1281,7 +1284,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {

@@ -1337,7 +1340,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1370,7 +1372,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {

@@ -1390,7 +1392,6 @@ ggml_tensor * llm_graph_context::build_attn(
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1437,7 +1438,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {

@@ -1470,7 +1471,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_cross * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1492,7 +1492,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {

@@ -1550,7 +1550,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-        ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
         int32_t state_size,

@@ -1608,21 +1607,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
 
 ggml_tensor * llm_graph_context::build_rs(
         llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
         ggml_tensor * s,
         int32_t state_size,
         int32_t n_seqs,
         const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
         const llama_ubatch & ubatch,
-        int il) const {
+        int il) const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;

@@ -1632,7 +1629,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
     ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
-            inp, gf, token_shift_all,
+            inp, token_shift_all,
             hparams.n_embd_r(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

@@ -1672,7 +1669,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 }
 
 void llm_graph_context::build_pooling(
-        ggml_cgraph * gf,
         ggml_tensor * cls,
         ggml_tensor * cls_b,
         ggml_tensor * cls_out,
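
Inside the builders, the graph is now reached through the cached member rather than an argument, so internal expansion calls (presumably ggml_build_forward_expand()) keep working while every build_* signature loses one parameter. Below is a hedged sketch of that internal pattern against the real ggml C API; it assumes ggml.h is on the include path and ggml is linked, and the tiny_builder type is hypothetical, not part of llama.cpp:

// Hedged sketch only: a builder that caches its graph and expands into it.
#include "ggml.h"

struct tiny_builder {
    ggml_context * ctx0;
    ggml_cgraph  * gf;   // cached once, like the new llm_graph_context member

    ggml_tensor * build_add(ggml_tensor * a, ggml_tensor * b) const {
        ggml_tensor * cur = ggml_add(ctx0, a, b);
        ggml_build_forward_expand(gf, cur);   // no gf argument threaded through call sites
        return cur;
    }
};

int main() {
    ggml_init_params ip = { /*mem_size=*/ 16*1024*1024, /*mem_buffer=*/ nullptr, /*no_alloc=*/ false };
    ggml_context * ctx = ggml_init(ip);
    ggml_cgraph  * gf  = ggml_new_graph(ctx);

    ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);

    tiny_builder tb{ ctx, gf };
    tb.build_add(a, b);   // the node lands in gf without being passed explicitly

    ggml_free(ctx);
    return 0;
}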

src/llama-graph.h

Lines changed: 20 additions & 44 deletions
@@ -371,31 +371,11 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
 // these are used by the llama_context to extact the relevant data, based on the compute parameters
 
-// TODO: this interface seems redundant - remove it
-class llm_graph_result_i {
-public:
-    virtual ~llm_graph_result_i() = default;
-
-    virtual ggml_tensor * get_tokens() const = 0;
-    virtual ggml_tensor * get_logits() const = 0;
-    virtual ggml_tensor * get_embd() const = 0;
-    virtual ggml_tensor * get_embd_pooled() const = 0;
-
-    virtual ggml_cgraph * get_gf() = 0;
-    virtual ggml_context * get_ctx() = 0;
-
-    virtual void reset() = 0;
-
-    virtual void set_inputs(const llama_ubatch * ubatch) = 0;
-
-    virtual bool can_reuse(const llm_graph_params & params) = 0;
-};
-
-using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
-
 // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
 using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
 
+class llm_graph_result;
+
 struct llm_graph_params {
     llm_arch arch = LLM_ARCH_UNKNOWN;
 
@@ -418,8 +398,7 @@ struct llm_graph_params {
 
     llm_graph_cb cb;
 
-    // TODO: temporary
-    llm_graph_result_i * res;
+    llm_graph_result * res;
 
     // return true if the "other" params would result in a graph with the same topology as with the current params
     // having the same topology allows us to reuse the graph in some cases
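
With the interface class gone, llm_graph_params only needs the forward declaration class llm_graph_result; because it stores just a pointer. A small, self-contained illustration of why a forward declaration suffices for a pointer member (hypothetical names, not the llama.cpp headers):

// header-like part: a forward declaration is enough to declare a pointer member
class graph_result;                        // analogous to: class llm_graph_result;

struct graph_params {
    graph_result * res = nullptr;          // ok: only the pointer size is needed here
    // graph_result by_value;              // would not compile: incomplete type
};

// definition-like part: the full type only has to be visible where it is used
class graph_result {
public:
    int max_nodes = 0;
};

int main() {
    graph_result r;
    r.max_nodes = 8192;

    graph_params p;
    p.res = &r;

    return p.res->max_nodes == 8192 ? 0 : 1;
}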
@@ -464,35 +443,37 @@ struct llm_graph_params {
     }
 };
 
-class llm_graph_result : public llm_graph_result_i {
+class llm_graph_result {
 public:
     llm_graph_result(int64_t max_nodes);
 
     virtual ~llm_graph_result() = default;
 
-    ggml_tensor * get_tokens() const override { return t_tokens; }
-    ggml_tensor * get_logits() const override { return t_logits; }
-    ggml_tensor * get_embd() const override { return t_embd; }
-    ggml_tensor * get_embd_pooled() const override { return t_embd_pooled; }
+    ggml_tensor * get_tokens() const { return t_tokens; }
+    ggml_tensor * get_logits() const { return t_logits; }
+    ggml_tensor * get_embd() const { return t_embd; }
+    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
 
-    ggml_cgraph * get_gf() override { return gf; }
-    ggml_context * get_ctx() override { return ctx_compute.get(); }
+    ggml_cgraph * get_gf() const { return gf; }
+    ggml_context * get_ctx() const { return ctx_compute.get(); }
 
     int64_t get_max_nodes() const;
 
-    void reset() override;
+    void reset();
 
-    void set_inputs(const llama_ubatch * ubatch) override;
+    void set_inputs(const llama_ubatch * ubatch);
 
     // try to update the existing graph result using the new graph parameters in order to reuse it
     // this can only be done if we determine that the resulting graph using the new graph parameters
     // would be identical to the existing graph. in that case, we simply have to update the memory
     // contexts of the input tensors of the graph and we can reuse it for another computation
     // return true if the graph was updated and can be reused
-    bool can_reuse(const llm_graph_params & params) override;
+    bool can_reuse(const llm_graph_params & params);
 
     llm_graph_input_i * add_input(llm_graph_input_ptr input);
 
+    void set_params(const llm_graph_params & params);
+
     // important graph nodes
     ggml_tensor * t_tokens = nullptr;
     ggml_tensor * t_logits = nullptr;
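
The comment block on can_reuse() above describes the reuse strategy: if the new parameters would produce a graph with the same topology, the existing graph is kept and only the inputs are re-pointed at the new data. A rough, hypothetical sketch of that idea (stand-in types and a made-up topology check, not the actual llama.cpp logic):

#include <vector>

// Stand-in parameter set: only the fields that affect topology matter for reuse.
struct build_params {
    int n_tokens  = 0;   // assumption: changes the graph topology
    int data_seed = 0;   // assumption: only changes the input contents
    bool same_topology(const build_params & other) const {
        return n_tokens == other.n_tokens;
    }
};

struct reusable_graph {
    build_params       params;      // copy of the params the graph was built with
    std::vector<float> inputs;      // stands in for the graph's input tensors
    int                rebuilds = 0;

    // mirrors the can_reuse() idea: update inputs in place when topology matches
    bool can_reuse(const build_params & next) {
        if (!params.same_topology(next)) {
            return false;           // caller must rebuild the graph from scratch
        }
        params = next;
        inputs.assign(inputs.size(), static_cast<float>(next.data_seed));
        return true;
    }

    void rebuild(const build_params & next) {
        params = next;
        inputs.assign(static_cast<size_t>(next.n_tokens), 0.0f);
        ++rebuilds;
    }
};

int main() {
    reusable_graph g;

    build_params first;  first.n_tokens  = 32; first.data_seed  = 1;
    g.rebuild(first);

    build_params next;   next.n_tokens   = 32; next.data_seed   = 2;
    if (!g.can_reuse(next))   g.rebuild(next);     // reused: same topology

    build_params bigger; bigger.n_tokens = 64; bigger.data_seed = 3;
    if (!g.can_reuse(bigger)) g.rebuild(bigger);   // rebuilt: topology changed

    return g.rebuilds == 2 ? 0 : 1;
}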
@@ -510,6 +491,7 @@ class llm_graph_result : public llm_graph_result_i {
 
     int64_t max_nodes;
 
+private:
     // keep a copy of the previous graph parameters
     // we will use this to determine whether the graph can be reused by comparing them with the new parameters
     // note: these are updated after constructing the new graph

@@ -519,6 +501,8 @@ class llm_graph_result : public llm_graph_result_i {
     int debug = 0;
 };
 
+using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
+
 //
 // llm_graph_context
 //

@@ -576,6 +560,7 @@ struct llm_graph_context {
     llm_graph_result * res;
 
     ggml_context * ctx0 = nullptr;
+    ggml_cgraph * gf = nullptr;
 
     llm_graph_context(const llm_graph_params & params);
     virtual ~llm_graph_context() = default;

@@ -661,7 +646,6 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_attn_mha(
-            ggml_cgraph * gf,
             ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)

@@ -674,7 +658,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_no_cache * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -689,7 +672,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -705,7 +687,6 @@ struct llm_graph_context {
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -720,7 +701,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_cross * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -742,7 +722,6 @@ struct llm_graph_context {
     // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     // `llama_memory_recurrent`
     ggml_tensor * build_rs(
-            ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
             int32_t state_size,

@@ -757,17 +736,15 @@ struct llm_graph_context {
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             ggml_tensor * s,
             int32_t state_size,
             int32_t n_seqs,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             const llama_ubatch & ubatch,
-            int il) const;
+            int il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(
             ggml_tensor * token_shift,

@@ -784,7 +761,6 @@ struct llm_graph_context {
     //
 
     void build_pooling(
-            ggml_cgraph * gf,
             ggml_tensor * cls,
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
