
Commit 952feed

context : disable encoder embd tensor for now
ggml-ci
1 parent 4efe989 commit 952feed

2 files changed, +18 -12 lines changed


src/llama-context.cpp

Lines changed: 12 additions & 11 deletions
@@ -4673,6 +4673,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());

     cross->t_embd = t_embd;
+    cross->v_embd = embd;

     // remember the sequence ids used during the encoding - needed for cross attention later
     cross->seq_ids_enc.resize(n_tokens);

@@ -4701,12 +4702,11 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
     // call base functionality
     llama_context_kv_self::input_set(ubatch);

-    //if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
-    //    assert(inp.cross_embd->type == GGML_TYPE_F32);
-    //    assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
+    if (inp.cross_embd && cross->t_embd) {
+        assert(inp.cross_embd->type == GGML_TYPE_F32);

-    //    ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
-    //}
+        ggml_backend_tensor_set(inp.cross_embd, cross->v_embd, 0, ggml_nbytes(inp.cross_embd));
+    }

     if (inp.cross_kq_mask) {
         const int64_t n_enc = inp.cross_kq_mask->ne[0];

@@ -4749,16 +4749,17 @@ ggml_cgraph * llama_context_dec::graph_init() {
 ggml_tensor * llama_context_dec::build_inp_cross_embd(
         ggml_context * ctx0) {
     // if we have the output embeddings from the encoder, use them directly
-    if (cross->t_embd) {
-        inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
+    // TODO: needs more work to be correct, for now just use the tensor shape
+    //if (cross->t_embd) {
+    //    inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);

-        return inp.cross_embd;
-    }
+    //    return inp.cross_embd;
+    //}

     const auto & hparams = model.hparams;

-    const auto n_embd = hparams.n_embd;
-    const auto n_enc = hparams.n_ctx_train;
+    const auto n_embd = cross->t_embd ? cross->t_embd->ne[0] : hparams.n_embd;
+    const auto n_enc = cross->t_embd ? cross->t_embd->ne[1] : hparams.n_ctx_train;

     inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
     ggml_set_input(inp.cross_embd);
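
For context on what the decoder-side hunks change: rather than viewing the encoder's output tensor directly in the decoder graph, the decoder now creates a fresh F32 input tensor (sized from cross->t_embd when available) and fills it from a host-memory copy during input_set. Below is a minimal, ggml-free sketch of that copy-to-host-then-upload pattern; the names EncoderOutput and make_decoder_cross_input are hypothetical stand-ins, not llama.cpp API.

    // Hypothetical sketch: the encoder leaves a host-memory copy of its output
    // embeddings, and the decoder allocates a fresh input buffer of the same
    // shape and fills it from that copy (mirroring cross->v_embd / inp.cross_embd).
    #include <cassert>
    #include <cstring>
    #include <vector>

    struct EncoderOutput {
        int n_embd = 0;                // embedding size per position
        int n_enc  = 0;                // number of encoded positions
        std::vector<float> host_embd;  // host-memory copy (role of v_embd)
    };

    // Decoder side: size the input from the encoder output when present,
    // otherwise fall back to model defaults (like hparams.n_embd / n_ctx_train).
    std::vector<float> make_decoder_cross_input(const EncoderOutput * enc,
                                                int n_embd_default, int n_enc_default) {
        const int n_embd = enc ? enc->n_embd : n_embd_default;
        const int n_enc  = enc ? enc->n_enc  : n_enc_default;

        std::vector<float> inp((size_t) n_embd * n_enc, 0.0f);
        if (enc) {
            assert(enc->host_embd.size() == inp.size());
            std::memcpy(inp.data(), enc->host_embd.data(), inp.size()*sizeof(float));
        }
        return inp;
    }

    int main() {
        EncoderOutput enc{4, 3, std::vector<float>(12, 1.0f)};
        auto inp = make_decoder_cross_input(&enc, /*n_embd_default=*/4, /*n_enc_default=*/8);
        assert(inp.size() == 12 && inp[0] == 1.0f);
        return 0;
    }

The fallback branch mirrors the hparams fallback in build_inp_cross_embd above, which presumably lets the graph be constructed before any encoder output exists.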

src/llama-context.h

Lines changed: 6 additions & 1 deletion
@@ -750,9 +750,14 @@ class llama_context_recurrent : public llama_context_base {

 // TODO: tmp - need something better to pass the data from the encoder to the decoder
 struct llama_cross {
-    // the output embeddings from the encoder
+    // the output embeddings from the encoder as a ggml tensor
+    // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
+    // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
     ggml_tensor * t_embd = nullptr;

+    // embeddings data copied to host memory (tmp)
+    float * v_embd = nullptr;
+
     // needed to construct the cross-attention mask in the decoder
     std::vector<std::set<llama_seq_id>> seq_ids_enc;
 };
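
A rough sketch of how the two fields are intended to work together according to this diff: t_embd is kept mainly for its shape, while v_embd carries the actual data the decoder uploads. The stand-in types below (cross_state, publish_encoder_output) are hypothetical and only illustrate the shape-plus-host-pointer handoff, not the real llama_cross usage.

    // Hypothetical stand-in for the handoff state (not the real ggml types):
    // the encoder records the output shape and a host copy of the data,
    // and the decoder expects the two to stay consistent.
    #include <cassert>
    #include <vector>

    struct cross_state {
        long ne[2] = {0, 0};            // shape metadata (role of t_embd->ne)
        const float * v_embd = nullptr; // host copy of the embeddings (role of v_embd)
    };

    // Encoder side: after computing embeddings, publish shape + host pointer.
    cross_state publish_encoder_output(const std::vector<float> & embd, long n_embd, long n_tokens) {
        assert((long) embd.size() == n_embd*n_tokens);
        return {{n_embd, n_tokens}, embd.data()};
    }

    int main() {
        std::vector<float> embd(8*2, 0.5f);
        cross_state cross = publish_encoder_output(embd, 8, 2);
        // the decoder would copy cross.ne[0]*cross.ne[1] floats from cross.v_embd
        assert(cross.ne[0]*cross.ne[1] == (long) embd.size());
        assert(cross.v_embd[0] == 0.5f);
        return 0;
    }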
