@@ -4673,6 +4673,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
 
     cross->t_embd = t_embd;
+    cross->v_embd = embd;
 
     // remember the sequence ids used during the encoding - needed for cross attention later
     cross->seq_ids_enc.resize(n_tokens);
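For context, here is a minimal sketch of the encoder-to-decoder state these hunks write into. The field names follow the diff, but the struct layout and the exact types are assumptions, not the actual llama.cpp definition:

    #include "ggml.h"
    #include "llama.h"

    #include <set>
    #include <vector>

    // Hypothetical sketch of the cross-attention state: t_embd keeps the
    // encoder output tensor so the decoder can size its input, while v_embd
    // keeps a pointer to the host-side embedding data so the decoder can
    // upload the values into its own graph.
    struct llama_cross {
        ggml_tensor * t_embd = nullptr; // encoder output [n_embd, n_enc] (shape only)
        float       * v_embd = nullptr; // host buffer with the embedding values

        // sequence ids used during encoding - consumed by the cross-attn mask
        std::vector<std::set<llama_seq_id>> seq_ids_enc;
    };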
@@ -4701,12 +4702,11 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
     // call base functionality
     llama_context_kv_self::input_set(ubatch);
 
-    // if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
-    //     assert(inp.cross_embd->type == GGML_TYPE_F32);
-    //     assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
+    if (inp.cross_embd && cross->t_embd) {
+        assert(inp.cross_embd->type == GGML_TYPE_F32);
 
-    //     ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
-    // }
+        ggml_backend_tensor_set(inp.cross_embd, cross->v_embd, 0, ggml_nbytes(inp.cross_embd));
+    }
 
     if (inp.cross_kq_mask) {
         const int64_t n_enc = inp.cross_kq_mask->ne[0];
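The copy in input_set follows the standard ggml pattern: create a graph input tensor, allocate backend memory for it, then upload host data with ggml_backend_tensor_set. Below is a standalone sketch of that pattern under the same assumptions (a recent ggml where ggml_backend_cpu_init lives in ggml-cpu.h; the v_embd buffer is a stand-in for cross->v_embd):

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-cpu.h"

    #include <cassert>
    #include <vector>

    int main() {
        const int64_t n_embd = 8, n_enc = 4;

        // metadata-only context: tensor data is allocated in backend memory below
        ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead()*8,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        ggml_context * ctx = ggml_init(params);

        ggml_tensor * cross_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_enc);
        ggml_set_input(cross_embd);

        ggml_backend_t        backend = ggml_backend_cpu_init();
        ggml_backend_buffer_t buffer  = ggml_backend_alloc_ctx_tensors(ctx, backend);

        // stand-in for cross->v_embd: host-side encoder output values
        std::vector<float> v_embd(n_embd*n_enc, 0.5f);

        // the call from the diff: copy host bytes into the backend tensor,
        // which may live in device memory for non-CPU backends
        assert(cross_embd->type == GGML_TYPE_F32);
        ggml_backend_tensor_set(cross_embd, v_embd.data(), 0, ggml_nbytes(cross_embd));

        ggml_backend_buffer_free(buffer);
        ggml_backend_free(backend);
        ggml_free(ctx);
        return 0;
    }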
@@ -4749,16 +4749,17 @@ ggml_cgraph * llama_context_dec::graph_init() {
 ggml_tensor * llama_context_dec::build_inp_cross_embd(
         ggml_context * ctx0) {
     // if we have the output embeddings from the encoder, use them directly
-    if (cross->t_embd) {
-        inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
+    // TODO: needs more work to be correct, for now just use the tensor shape
+    // if (cross->t_embd) {
+    //     inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
 
-        return inp.cross_embd;
-    }
+    //     return inp.cross_embd;
+    // }
 
     const auto & hparams = model.hparams;
 
-    const auto n_embd = hparams.n_embd;
-    const auto n_enc  = hparams.n_ctx_train;
+    const auto n_embd = cross->t_embd ? cross->t_embd->ne[0] : hparams.n_embd;
+    const auto n_enc  = cross->t_embd ? cross->t_embd->ne[1] : hparams.n_ctx_train;
 
     inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
     ggml_set_input(inp.cross_embd);
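The shape fallback in build_inp_cross_embd can be spelled out as a small helper. This is a hypothetical function, not part of the diff; it only restates the ternary expressions above, using the fact that ggml stores dimensions in ne[] with ne[0] as the innermost dimension, so an encoder output of shape [n_embd, n_enc] yields n_embd = ne[0] and n_enc = ne[1]:

    #include "ggml.h"

    #include <cstdint>
    #include <utility>

    // returns {n_embd, n_enc} for the decoder's cross-embedding input
    static std::pair<int64_t, int64_t> cross_embd_shape(
            const ggml_tensor * t_embd,   // encoder output, may be null
            int64_t n_embd_train,         // fallback: hparams.n_embd
            int64_t n_ctx_train) {        // fallback: hparams.n_ctx_train
        if (t_embd) {
            return { t_embd->ne[0], t_embd->ne[1] }; // exact shape of last encode
        }
        return { n_embd_train, n_ctx_train };        // worst-case training shape
    }

Note the design choice: the old code returned a view of the encoder tensor directly, while the new code uses that tensor only for its shape, allocates a fresh graph input, and fills it from the host copy cross->v_embd in input_set, presumably because viewing a tensor owned by the encoder's graph from the decoder's graph is not yet correct (per the TODO).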