@@ -4540,6 +4540,7 @@ size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_
 // llama_context_enc
 //
 
+// TODO: avoid copy-paste of the entire encode() function
 int llama_context_enc::encode(llama_batch & inp_batch) {
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -4671,8 +4672,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
 
-    cross->n_outputs = n_tokens;
-    cross->embd_enc = embd;
+    cross->t_embd = t_embd;
 
     // remember the sequence ids used during the encoding - needed for cross attention later
     cross->seq_ids_enc.resize(n_tokens);
@@ -4692,9 +4692,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
 
 void llama_context_dec::reserve() {
     // simulate full KV cache
-    cross->n_outputs = cparams.n_ubatch;
-
-    LLAMA_LOG_DEBUG("%s: n_outputs = %u\n", __func__, cross->n_outputs);
+    cross->t_embd = nullptr;
 
     llama_context_kv_self::reserve();
 }
@@ -4703,15 +4701,15 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
     // call base functionality
     llama_context_kv_self::input_set(ubatch);
 
-    if (inp.cross_embd) {
-        assert(inp.cross_embd->type == GGML_TYPE_F32);
-        assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
+    //if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
+    //    assert(inp.cross_embd->type == GGML_TYPE_F32);
+    //    assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
 
-        ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
-    }
+    //    ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
+    //}
 
     if (inp.cross_kq_mask) {
-        const int64_t n_output_enc = cross->n_outputs;
+        const int64_t n_enc    = inp.cross_kq_mask->ne[0];
         const int64_t n_tokens = ubatch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer));
@@ -4721,21 +4719,21 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_output_enc; ++i) {
+                for (int i = 0; i < n_enc; ++i) {
                     float f = -INFINITY;
                     for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[j][s];
                         if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
                             f = 0.0f;
                         }
                     }
-                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
+                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
                 }
             }
 
             for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_output_enc; ++j) {
-                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
+                for (int j = 0; j < n_enc; ++j) {
+                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
                 }
             }
         }
@@ -4750,12 +4748,19 @@ ggml_cgraph * llama_context_dec::graph_init() {
 
 ggml_tensor * llama_context_dec::build_inp_cross_embd(
         ggml_context * ctx0) {
+    // if we have the output embeddings from the encoder, use them directly
+    if (cross->t_embd) {
+        inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
+
+        return inp.cross_embd;
+    }
+
     const auto & hparams = model.hparams;
-    const int64_t n_embd = hparams.n_embd;
 
-    const int32_t n_outputs_enc = cross->n_outputs;
+    const auto n_embd = hparams.n_embd;
+    const auto n_enc  = hparams.n_ctx_train;
 
-    inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
+    inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
     ggml_set_input(inp.cross_embd);
 
     return inp.cross_embd;
@@ -4768,9 +4773,9 @@ void llama_context_dec::build_attn_inp(
         bool swa) {
     llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa);
 
-    const int32_t n_outputs_enc = cross->n_outputs;
+    const int32_t n_enc = cross->t_embd ? cross->t_embd->ne[1] : model.hparams.n_ctx_train;
 
-    inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     ggml_set_input(inp.cross_kq_mask);
 
     inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask;
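
Note: below is a minimal standalone sketch of the cross-attention KQ mask rule that llama_context_dec::input_set() fills in the hunk above. The helper name build_cross_kq_mask, the std::vector/std::set stand-ins for the ubatch and cross state, and the omission of the GGML_KQ_MASK_PAD padding rows are simplifications for illustration only, not part of this patch: a decoder token j may attend to encoder position i only if the two share a sequence id; every other entry is masked with -INFINITY.

// illustrative sketch only: same masking rule as llama_context_dec::input_set() above,
// with std::vector/std::set stand-ins for the ubatch and cross state
#include <cmath>
#include <set>
#include <vector>

// returns a row-major [n_tokens x n_enc] mask: entry (j, i) is 0.0f when decoder token j
// shares a sequence id with encoder position i, and -INFINITY otherwise
static std::vector<float> build_cross_kq_mask(
        int n_enc,                                          // number of encoder positions
        const std::vector<std::set<int>> & seq_ids_enc,     // sequence ids per encoder position (size n_enc)
        const std::vector<std::vector<int>> & seq_ids_dec) {// sequence ids per decoder token
    const int n_tokens = (int) seq_ids_dec.size();

    std::vector<float> mask((size_t) n_tokens*n_enc, -INFINITY);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_enc; ++i) {
            for (const int seq_id : seq_ids_dec[j]) {
                if (seq_ids_enc[i].count(seq_id) > 0) {
                    mask[(size_t) j*n_enc + i] = 0.0f;
                    break;
                }
            }
        }
    }

    return mask;
}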