
Commit 4efe989

context : pass embeddings tensor from encoder to decoder
ggml-ci
1 parent e2b3294 commit 4efe989

File tree

2 files changed: +29 −23 lines

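In short: before this commit, llama_cross carried a raw host pointer (float * embd_enc) plus a count (n_outputs), and the decoder copied that buffer into a freshly allocated input tensor with ggml_backend_tensor_set(). After this commit, the encoder hands over its output ggml_tensor (t_embd) and the decoder references it through a view, so no host round-trip is needed. A minimal sketch of the new handoff (hypothetical helper name; assumes a graph context ctx0 and the encoder output tensor t_embd are available):

    #include "ggml.h"

    // Sketch only: the decoder-side input becomes a view of the encoder output,
    // sharing its data instead of allocating a new input tensor and copying into it.
    static ggml_tensor * cross_embd_from_encoder(ggml_context * ctx0, ggml_tensor * t_embd) {
        return ggml_view_tensor(ctx0, t_embd);
    }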

src/llama-context.cpp

Lines changed: 25 additions & 20 deletions
@@ -4540,6 +4540,7 @@ size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_
 // llama_context_enc
 //
 
+// TODO: avoid copy-paste of the entire encode() function
 int llama_context_enc::encode(llama_batch & inp_batch) {
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -4671,8 +4672,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
 
-    cross->n_outputs = n_tokens;
-    cross->embd_enc = embd;
+    cross->t_embd = t_embd;
 
     // remember the sequence ids used during the encoding - needed for cross attention later
     cross->seq_ids_enc.resize(n_tokens);
@@ -4692,9 +4692,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
 
 void llama_context_dec::reserve() {
     // simulate full KV cache
-    cross->n_outputs = cparams.n_ubatch;
-
-    LLAMA_LOG_DEBUG("%s: n_outputs = %u\n", __func__, cross->n_outputs);
+    cross->t_embd = nullptr;
 
     llama_context_kv_self::reserve();
 }
@@ -4703,15 +4701,15 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
     // call base functionality
     llama_context_kv_self::input_set(ubatch);
 
-    if (inp.cross_embd) {
-        assert(inp.cross_embd->type == GGML_TYPE_F32);
-        assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
+    //if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
+    //    assert(inp.cross_embd->type == GGML_TYPE_F32);
+    //    assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
 
-        ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
-    }
+    //    ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
+    //}
 
     if (inp.cross_kq_mask) {
-        const int64_t n_output_enc = cross->n_outputs;
+        const int64_t n_enc = inp.cross_kq_mask->ne[0];
         const int64_t n_tokens = ubatch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer));
@@ -4721,21 +4719,21 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_output_enc; ++i) {
+                for (int i = 0; i < n_enc; ++i) {
                     float f = -INFINITY;
                     for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[j][s];
                         if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
                             f = 0.0f;
                         }
                     }
-                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
+                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
                 }
             }
 
             for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_output_enc; ++j) {
-                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
+                for (int j = 0; j < n_enc; ++j) {
+                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
                 }
             }
         }
@@ -4750,12 +4748,19 @@ ggml_cgraph * llama_context_dec::graph_init() {
 
 ggml_tensor * llama_context_dec::build_inp_cross_embd(
         ggml_context * ctx0) {
+    // if we have the output embeddings from the encoder, use them directly
+    if (cross->t_embd) {
+        inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
+
+        return inp.cross_embd;
+    }
+
     const auto & hparams = model.hparams;
-    const int64_t n_embd = hparams.n_embd;
 
-    const int32_t n_outputs_enc = cross->n_outputs;
+    const auto n_embd = hparams.n_embd;
+    const auto n_enc = hparams.n_ctx_train;
 
-    inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
+    inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
     ggml_set_input(inp.cross_embd);
 
     return inp.cross_embd;
@@ -4768,9 +4773,9 @@ void llama_context_dec::build_attn_inp(
         bool swa) {
     llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa);
 
-    const int32_t n_outputs_enc = cross->n_outputs;
+    const int32_t n_enc = cross->t_embd ? cross->t_embd->ne[1] : model.hparams.n_ctx_train;
 
-    inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     ggml_set_input(inp.cross_kq_mask);
 
     inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask;
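The cross-attention mask logic in llama_context_dec::input_set() keeps the same rule; only the encoder length now comes from the mask tensor itself (inp.cross_kq_mask->ne[0]) instead of cross->n_outputs. As a standalone illustration of that rule (a sketch over plain containers, not the library code): decoder token j may attend to encoder position i only if the two share at least one sequence id.

    #include <cmath>
    #include <cstdint>
    #include <set>
    #include <vector>

    // Hypothetical helper restating the masking rule from input_set():
    // row j = decoder token, column i = encoder position; 0.0f means "may attend",
    // -INFINITY means the position is masked out.
    static std::vector<float> build_cross_kq_mask(
            const std::vector<std::set<int32_t>> & seq_ids_enc,   // per encoder position
            const std::vector<std::set<int32_t>> & seq_ids_dec) { // per decoder token
        const size_t n_enc    = seq_ids_enc.size();
        const size_t n_tokens = seq_ids_dec.size();

        std::vector<float> mask(n_tokens*n_enc, -INFINITY);

        for (size_t j = 0; j < n_tokens; ++j) {
            for (size_t i = 0; i < n_enc; ++i) {
                for (const int32_t seq_id : seq_ids_dec[j]) {
                    if (seq_ids_enc[i].count(seq_id) > 0) {
                        mask[j*n_enc + i] = 0.0f;
                        break;
                    }
                }
            }
        }

        return mask;
    }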

src/llama-context.h

Lines changed: 4 additions & 3 deletions
@@ -748,11 +748,12 @@ class llama_context_recurrent : public llama_context_base {
     llama_kv_cache_recurrent kv_self;
 };
 
-// TODO: tmp - need something better
+// TODO: tmp - need something better to pass the data from the encoder to the decoder
 struct llama_cross {
-    int32_t n_outputs;
-    float * embd_enc;
+    // the output embeddings from the encoder
+    ggml_tensor * t_embd = nullptr;
 
+    // needed to construct the cross-attention mask in the decoder
     std::vector<std::set<llama_seq_id>> seq_ids_enc;
 };
 
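Dropping cross->n_outputs means the decoder has to pick the encoder length on its own: when an encoder pass has already run, it reads it from the embeddings tensor (t_embd->ne[1]); during reserve() there is no encoder output yet (t_embd is reset to nullptr), so build_attn_inp() falls back to the training context length as a worst-case size. A sketch of that selection (hypothetical helper mirroring the ternary in build_attn_inp()):

    #include <cstdint>

    #include "ggml.h"

    // Hypothetical helper: prefer the actual encoder output length, fall back to
    // n_ctx_train when no encoder output is available (e.g. while reserving buffers).
    static int64_t cross_n_enc(const ggml_tensor * t_embd, int64_t n_ctx_train) {
        // t_embd has shape [n_embd, n_enc], so ne[1] is the number of encoder positions
        return t_embd ? t_embd->ne[1] : n_ctx_train;
    }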
