
Commit bc0a20c

graph : refactor context to not pass gf explicitly
ggml-ci
1 parent 3d28b3b commit bc0a20c

File tree

5 files changed (+280, -326 lines)
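The commit collapses the `llm_graph_result_i` interface into the concrete `llm_graph_result` class and stops threading `ggml_cgraph * gf` through every graph-building helper; instead, `llm_graph_context` caches the graph once from the result at construction time. Below is a minimal sketch of the before/after shape, using simplified stand-in types rather than the real llama.cpp declarations:

```cpp
// Sketch only: placeholder types mirroring the shape of the refactor,
// not the actual llama.cpp classes.
struct ggml_cgraph;   // opaque, as in ggml
struct ggml_tensor;   // opaque, as in ggml

struct graph_result {
    ggml_cgraph * get_gf() const { return gf; }
    ggml_cgraph * gf = nullptr;
};

struct graph_context {
    graph_result * res = nullptr;
    ggml_cgraph  * gf  = nullptr; // after: cached once from the result

    explicit graph_context(graph_result * res) : res(res), gf(res->get_gf()) {}

    // before: ggml_tensor * build_attn(ggml_cgraph * gf, ggml_tensor * q, ...);
    // after:  the builder reads the gf member instead of taking it as a parameter
    ggml_tensor * build_attn(ggml_tensor * q, ggml_tensor * k, ggml_tensor * v);
};
```

Since `gf` was identical for every call within one graph build and always came from the same result object that already backs `ctx0`, keeping it as a member removes a parameter that carried no per-call information.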

src/llama-context.cpp

Lines changed: 2 additions & 2 deletions

@@ -674,7 +674,7 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;

@@ -1324,7 +1324,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_params llama_context::graph_params(
-        llm_graph_result_i * res,
+        llm_graph_result * res,
         const llama_ubatch & ubatch,
         const llama_memory_context_i * mctx,
         llm_graph_type gtype) const {

src/llama-context.h

Lines changed: 2 additions & 2 deletions

@@ -94,7 +94,7 @@ struct llama_context {
     // if memory_context is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
-    llm_graph_result_i * process_ubatch(
+    llm_graph_result * process_ubatch(
             const llama_ubatch & ubatch,
             llm_graph_type gtype,
             llama_memory_context_i * mctx,

@@ -196,7 +196,7 @@ struct llama_context {
 
 private:
     llm_graph_params graph_params(
-            llm_graph_result_i * res,
+            llm_graph_result * res,
             const llama_ubatch & ubatch,
             const llama_memory_context_i * mctx,
             llm_graph_type gtype) const;
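Both context-side changes are pure signature updates: `process_ubatch` now returns, and `graph_params` now receives, the concrete `llm_graph_result` instead of the removed `llm_graph_result_i` interface. A hedged, self-contained illustration of the effect (hypothetical names, not the actual llama-context.cpp code): with a concrete return type the caller reaches class members directly, where the old interface forced virtual calls or a `static_cast` like the one dropped from the `llm_graph_context` constructor below.

```cpp
// Hypothetical stand-ins illustrating interface-vs-concrete return types.
class result {
public:
    float * get_logits() { return &logits; }
private:
    float logits = 0.0f;
};

class context {
public:
    // before: an abstract result_i * was returned, so non-virtual members of the
    // concrete result (e.g. a set_params() helper) were reachable only via a cast
    result * process_ubatch() { return &res; }
private:
    result res;
};

int main() {
    context ctx;
    return *ctx.process_ubatch()->get_logits() == 0.0f ? 0 : 1;
}
```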

src/llama-graph.cpp

Lines changed: 11 additions & 19 deletions

@@ -448,9 +448,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     mctx    (params.mctx),
     cross   (params.cross),
     cb_func (params.cb),
-    res     (static_cast<llm_graph_result *>(params.res)),
-    ctx0    (res->get_ctx()) {
-    res->params = params;
+    res     (params.res),
+    ctx0    (res->get_ctx()),
+    gf      (res->get_gf()) {
+    res->set_params(params);
 }
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {

@@ -1040,7 +1041,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 ggml_tensor * llm_graph_context::build_attn_mha(
-        ggml_cgraph * gf,
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,

@@ -1170,7 +1170,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1194,7 +1193,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {

@@ -1249,7 +1248,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1282,7 +1280,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {

@@ -1302,7 +1300,6 @@ ggml_tensor * llm_graph_context::build_attn(
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1349,7 +1346,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {

@@ -1382,7 +1379,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_cross * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1404,7 +1400,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {

@@ -1460,7 +1456,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-        ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
         int32_t state_size,

@@ -1518,21 +1513,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
 
 ggml_tensor * llm_graph_context::build_rs(
         llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
         ggml_tensor * s,
         int32_t state_size,
         int32_t n_seqs,
         const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
         const llama_ubatch & ubatch,
-         int il) const {
+        int il) const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;

@@ -1542,7 +1535,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
     ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
-            inp, gf, token_shift_all,
+            inp, token_shift_all,
             hparams.n_embd_r(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

@@ -1582,7 +1575,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 }
 
 void llm_graph_context::build_pooling(
-        ggml_cgraph * gf,
         ggml_tensor * cls,
         ggml_tensor * cls_b,
         ggml_tensor * cls_out,
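In the implementation file, the constructor now initializes `gf` from `res->get_gf()` and calls `res->set_params(params)` instead of writing `res->params` directly, and every `build_attn` / `build_rs` / `build_pooling` overload loses its `ggml_cgraph * gf` parameter, so internal calls such as `build_attn_mha(...)` drop the argument. A runnable stand-in showing the pattern (placeholder types; `expand()` stands in for the kind of `ggml_build_forward_expand` call the real builders make on `gf`):

```cpp
// Stand-in sketch, not the real llm_graph_context: shows the gf member being
// cached in the constructor and then used inside a builder, instead of being
// passed explicitly by every caller.
#include <vector>

struct node { int id; };

struct cgraph {
    std::vector<node> nodes;
    void expand(node n) { nodes.push_back(n); } // role of ggml_build_forward_expand
};

struct graph_result {
    cgraph * get_gf() { return &g; }
    cgraph g;
};

struct graph_context {
    graph_result * res;
    cgraph       * gf; // cached once; all build_* helpers read it from here

    explicit graph_context(graph_result * r) : res(r), gf(r->get_gf()) {}

    node build_attn_mha(node q, node k, node v) {
        node kqv{q.id + k.id + v.id}; // placeholder for the real attention subgraph
        gf->expand(kqv);              // previously: gf arrived as a function parameter
        return kqv;
    }

    node build_attn(node q, node k, node v) {
        // the call no longer forwards gf, matching the diff above
        return build_attn_mha(q, k, v);
    }
};

int main() {
    graph_result res;
    graph_context ctx(&res);
    ctx.build_attn(node{1}, node{2}, node{3});
    return static_cast<int>(res.g.nodes.size()) == 1 ? 0 : 1;
}
```

Because `gf` always belonged to the same `llm_graph_result` that already provides `ctx0`, this is plumbing only; the built graph should be unchanged.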

src/llama-graph.h

Lines changed: 22 additions & 44 deletions

@@ -371,31 +371,11 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
 // these are used by the llama_context to extact the relevant data, based on the compute parameters
 
-// TODO: this interface seems redundant - remove it
-class llm_graph_result_i {
-public:
-    virtual ~llm_graph_result_i() = default;
-
-    virtual ggml_tensor * get_tokens()      const = 0;
-    virtual ggml_tensor * get_logits()      const = 0;
-    virtual ggml_tensor * get_embd()        const = 0;
-    virtual ggml_tensor * get_embd_pooled() const = 0;
-
-    virtual ggml_cgraph  * get_gf()  = 0;
-    virtual ggml_context * get_ctx() = 0;
-
-    virtual void reset() = 0;
-
-    virtual void set_inputs(const llama_ubatch * ubatch) = 0;
-
-    virtual bool can_reuse(const llm_graph_params & params) = 0;
-};
-
-using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
-
 // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
 using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
 
+class llm_graph_result;
+
 struct llm_graph_params {
     llm_arch arch = LLM_ARCH_UNKNOWN;
 

@@ -418,8 +398,7 @@ struct llm_graph_params {
 
     llm_graph_cb cb;
 
-    // TODO: temporary
-    llm_graph_result_i * res;
+    llm_graph_result * res;
 
     // return true if the "other" params would result in a graph with the same topology as with the current params
     // having the same topology allows us to reuse the graph in some cases

@@ -462,27 +441,27 @@
     }
 };
 
-class llm_graph_result : public llm_graph_result_i {
+class llm_graph_result {
 public:
     llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
         reset();
     }
 
     virtual ~llm_graph_result() = default;
 
-    ggml_tensor * get_tokens()      const override { return t_tokens; }
-    ggml_tensor * get_logits()      const override { return t_logits; }
-    ggml_tensor * get_embd()        const override { return t_embd; }
-    ggml_tensor * get_embd_pooled() const override { return t_embd_pooled; }
+    ggml_tensor * get_tokens()      const { return t_tokens; }
+    ggml_tensor * get_logits()      const { return t_logits; }
+    ggml_tensor * get_embd()        const { return t_embd; }
+    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
 
-    ggml_cgraph  * get_gf()  override { return gf; }
-    ggml_context * get_ctx() override { return ctx_compute.get(); }
+    ggml_cgraph  * get_gf()  { return gf; }
+    ggml_context * get_ctx() { return ctx_compute.get(); }
 
     void set_max_nodes(int64_t max_nodes) {
         this->max_nodes = max_nodes;
     }
 
-    void reset() override {
+    void reset() {
         t_tokens = nullptr;
         t_logits = nullptr;
         t_embd   = nullptr;

@@ -503,7 +482,7 @@ class llm_graph_result : public llm_graph_result_i {
         gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
     }
 
-    void set_inputs(const llama_ubatch * ubatch) override {
+    void set_inputs(const llama_ubatch * ubatch) {
         for (auto & input : inputs) {
             input->set_input(ubatch);
         }

@@ -514,7 +493,7 @@ class llm_graph_result : public llm_graph_result_i {
     // would be identical to the existing graph. in that case, we simply have to update the memory
     // contexts of the input tensors of the graph and we can reuse it for another computation
     // return true if the graph was updated and can be reused
-    bool can_reuse(const llm_graph_params & params) override {
+    bool can_reuse(const llm_graph_params & params) {
         if (!this->params.allow_reuse(params)) {
             return false;
         }

@@ -533,6 +512,10 @@ class llm_graph_result : public llm_graph_result_i {
         return inputs.back().get();
     }
 
+    void set_params(const llm_graph_params & params) {
+        this->params = params;
+    }
+
     // important graph nodes
     ggml_tensor * t_tokens = nullptr;
     ggml_tensor * t_logits = nullptr;

@@ -550,12 +533,15 @@ class llm_graph_result : public llm_graph_result_i {
 
     int64_t max_nodes;
 
+private:
     // keep a copy of the previous graph parameters
     // we will use this to determine whether the graph can be reused by comparing them with the new parameters
     // note: these are updated after constructing the new graph
     llm_graph_params params;
 };
 
+using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
+
 //
 // llm_graph_context
 //

@@ -613,6 +599,7 @@ struct llm_graph_context {
     llm_graph_result * res;
 
     ggml_context * ctx0 = nullptr;
+    ggml_cgraph  * gf   = nullptr;
 
     llm_graph_context(const llm_graph_params & params);
     virtual ~llm_graph_context() = default;

@@ -698,7 +685,6 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_attn_mha(
-            ggml_cgraph * gf,
             ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)

@@ -711,7 +697,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_no_cache * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -726,7 +711,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -742,7 +726,6 @@ struct llm_graph_context {
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -757,7 +740,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_cross * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -779,7 +761,6 @@ struct llm_graph_context {
     // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     // `llama_memory_recurrent`
     ggml_tensor * build_rs(
-            ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
             int32_t state_size,

@@ -794,17 +775,15 @@ struct llm_graph_context {
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             ggml_tensor * s,
             int32_t state_size,
             int32_t n_seqs,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             const llama_ubatch & ubatch,
-             int il) const;
+            int il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(
             ggml_tensor * token_shift,

@@ -821,7 +800,6 @@ struct llm_graph_context {
     //
 
     void build_pooling(
-            ggml_cgraph * gf,
             ggml_tensor * cls,
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
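The header change removes `llm_graph_result_i` entirely: a forward declaration `class llm_graph_result;` is now enough for the `res` pointer in `llm_graph_params`, the cached `params` member becomes private behind the new `set_params()`, the accessors become non-virtual, and `llm_graph_result_ptr` aliases the concrete class. A small sketch of that header-level pattern, with simplified stand-in names rather than the real declarations:

```cpp
// Simplified stand-ins for the header-level pattern; not the real llama.cpp declarations.
#include <memory>

class graph_result; // forward declaration: a pointer member needs no full definition,
                    // which is all the params struct ever required from the old interface

struct graph_params {
    int            n_tokens = 0;
    graph_result * res      = nullptr; // was: llm_graph_result_i *
};

class graph_result {
public:
    void set_params(const graph_params & p) { params = p; } // replaces writing res->params directly

private:
    graph_params params; // private now; updated only through set_params()
};

using graph_result_ptr = std::unique_ptr<graph_result>; // mirrors llm_graph_result_ptr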
