Skip to content

Commit f5e8020

Browse files
committed
wip enc-dec
1 parent c4c0a4d commit f5e8020

File tree

6 files changed

+72
-14
lines changed

6 files changed

+72
-14
lines changed

src/llama-context.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
llama_context::llama_context(
1818
const llama_model & model,
19-
const llama_context_params & params) :
20-
model (model) {
19+
const llama_context_params & params,
20+
llama_graph_type gtype) :
21+
llama_graph_i(gtype),
22+
model(model) {
2123
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
2224

2325
t_start_us = model.t_start_us;
@@ -2279,8 +2281,9 @@ size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_
22792281

22802282
llama_context_kv_self::llama_context_kv_self(
22812283
const llama_model & model,
2282-
const llama_context_params & params) :
2283-
llama_context(model, params),
2284+
const llama_context_params & params,
2285+
llama_graph_type gtype) :
2286+
llama_context(model, params, gtype),
22842287
kv_self(model.hparams) {
22852288
LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
22862289

@@ -3750,8 +3753,9 @@ size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq
37503753

37513754
llama_context_recurrent::llama_context_recurrent(
37523755
const llama_model & model,
3753-
const llama_context_params & params) :
3754-
llama_context(model, params),
3756+
const llama_context_params & params,
3757+
llama_graph_type gtype) :
3758+
llama_context(model, params, gtype),
37553759
kv_self(model.hparams) {
37563760
LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
37573761

@@ -4619,6 +4623,22 @@ size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_s
46194623
return io.n_bytes();
46204624
}
46214625

4626+
//
4627+
// llama_context_enc_dec
4628+
//
4629+
4630+
llama_context_enc_dec::llama_context_enc_dec(
4631+
const llama_model & model,
4632+
const llama_context_params & params) :
4633+
llama_context(model, params, LLAMA_GRAPH_TYPE_ENCODER),
4634+
ctx_dec(model, params, LLAMA_GRAPH_TYPE_DECODER) {
4635+
LLAMA_LOG_INFO("%s: constructing llama_context_enc_dec\n", __func__);
4636+
}
4637+
4638+
llama_context_enc_dec::~llama_context_enc_dec() {
4639+
LLAMA_LOG_INFO("%s: destructing llama_context_enc_dec\n", __func__);
4640+
}
4641+
46224642
//
46234643
// interface implementation
46244644
//

src/llama-context.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ struct llama_context : public llama_graph_i {
2525
public:
2626
llama_context(
2727
const llama_model & model,
28-
const llama_context_params & params);
28+
const llama_context_params & params,
29+
llama_graph_type gtype);
2930

3031
virtual ~llama_context();
3132

@@ -388,7 +389,8 @@ class llama_context_kv_self : public llama_context {
388389
public:
389390
llama_context_kv_self(
390391
const llama_model & model,
391-
const llama_context_params & params);
392+
const llama_context_params & params,
393+
llama_graph_type gtype);
392394

393395
virtual ~llama_context_kv_self();
394396

@@ -500,7 +502,8 @@ class llama_context_recurrent : public llama_context {
500502
public:
501503
llama_context_recurrent(
502504
const llama_model & model,
503-
const llama_context_params & params);
505+
const llama_context_params & params,
506+
llama_graph_type gtype);
504507

505508
virtual ~llama_context_recurrent();
506509

@@ -604,6 +607,23 @@ class llama_context_recurrent : public llama_context {
604607
llama_kv_cache_recurrent kv_self;
605608
};
606609

610+
class llama_context_enc : public llama_context {
611+
public:
612+
using llama_context::llama_context;
613+
};
614+
615+
class llama_context_enc_dec : public llama_context {
616+
public:
617+
llama_context_enc_dec(
618+
const llama_model & model,
619+
const llama_context_params & params);
620+
621+
virtual ~llama_context_enc_dec();
622+
623+
protected:
624+
llama_context_kv_self ctx_dec;
625+
};
626+
607627
// For internal test use
608628
// TODO: remove
609629
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);

src/llama-graph.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include "llama-impl.h"
44

5+
llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
6+
57
ggml_tensor * llama_graph_i::build_attn(
68
ggml_context * ctx0,
79
ggml_cgraph * gf,

src/llama-graph.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ struct ggml_tensor;
1111
struct ggml_backend_buffer;
1212
struct llama_ubatch;
1313

14+
enum llama_graph_type {
15+
LLAMA_GRAPH_TYPE_DEFAULT,
16+
LLAMA_GRAPH_TYPE_ENCODER,
17+
LLAMA_GRAPH_TYPE_DECODER,
18+
};
19+
1420
struct llama_graph_result {
1521
// important graph nodes
1622
ggml_tensor * t_logits = nullptr;
@@ -20,6 +26,15 @@ struct llama_graph_result {
2026

2127
// TODO: can become more granular in the future
2228
class llama_graph_i {
29+
public:
30+
llama_graph_i(llama_graph_type type);
31+
virtual ~llama_graph_i() = default;
32+
33+
llama_graph_type get_type() const { return type; }
34+
35+
protected:
36+
llama_graph_type type;
37+
2338
public:
2439
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
2540
virtual void build_cb(

src/llama-model.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
#include "llama-hparams.h"
66
#include "llama-vocab.h"
77

8-
#include "ggml-cpp.h"
9-
108
#include <memory>
119
#include <string>
1210
#include <unordered_map>

src/llama.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,17 +331,20 @@ struct llama_context * llama_init_from_model(
331331
case LLM_ARCH_BERT:
332332
case LLM_ARCH_JINA_BERT_V2:
333333
case LLM_ARCH_NOMIC_BERT:
334-
ctx = new llama_context(*model, params);
334+
ctx = new llama_context_enc(*model, params, LLAMA_GRAPH_TYPE_DEFAULT);
335+
break;
336+
case LLM_ARCH_T5:
337+
ctx = new llama_context_enc_dec(*model, params);
335338
break;
336339
case LLM_ARCH_RWKV6:
337340
case LLM_ARCH_RWKV6QWEN2:
338341
case LLM_ARCH_MAMBA:
339342
GGML_ASSERT(llama_model_is_recurrent(model));
340-
ctx = new llama_context_recurrent(*model, params);
343+
ctx = new llama_context_recurrent(*model, params, LLAMA_GRAPH_TYPE_DEFAULT);
341344
break;
342345
default:
343346
GGML_ASSERT(!llama_model_is_recurrent(model));
344-
ctx = new llama_context_kv_self(*model, params);
347+
ctx = new llama_context_kv_self(*model, params, LLAMA_GRAPH_TYPE_DEFAULT);
345348
};
346349

347350
ctx->init();

0 commit comments

Comments
 (0)