@@ -7650,6 +7650,50 @@ struct llm_build_context {
7650
7650
return lctx.inp_s_seq;
7651
7651
}
7652
7652
7653
+ struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
7654
+ // find result_norm tensor for input
7655
+ struct ggml_tensor * inp = nullptr;
7656
+ for (int i = gf->n_nodes - 1; i >= 0; --i) {
7657
+ inp = gf->nodes[i];
7658
+ if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
7659
+ break;
7660
+ } else {
7661
+ inp = nullptr;
7662
+ }
7663
+ }
7664
+ GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
7665
+
7666
+ struct ggml_tensor * cur;
7667
+
7668
+ switch (pooling_type) {
7669
+ case LLAMA_POOLING_TYPE_MEAN:
7670
+ {
7671
+ struct ggml_tensor * inp_mean = build_inp_mean();
7672
+ cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
7673
+ } break;
7674
+ case LLAMA_POOLING_TYPE_CLS:
7675
+ case LLAMA_POOLING_TYPE_LAST:
7676
+ {
7677
+ struct ggml_tensor * inp_cls = build_inp_cls();
7678
+ cur = ggml_get_rows(ctx0, inp, inp_cls);
7679
+ } break;
7680
+ case LLAMA_POOLING_TYPE_NONE:
7681
+ {
7682
+ cur = inp;
7683
+ } break;
7684
+ default:
7685
+ {
7686
+ GGML_ASSERT(false && "unknown pooling type");
7687
+ } break;
7688
+ }
7689
+
7690
+ cb(cur, "result_embd_pooled", -1);
7691
+
7692
+ ggml_build_forward_expand(gf, cur);
7693
+
7694
+ return gf;
7695
+ }
7696
+
7653
7697
struct ggml_cgraph * build_llama() {
7654
7698
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
7655
7699
@@ -8630,8 +8674,6 @@ struct llm_build_context {
8630
8674
if (model.arch != LLM_ARCH_JINA_BERT_V2) {
8631
8675
inp_pos = build_inp_pos();
8632
8676
}
8633
- struct ggml_tensor * inp_mean = build_inp_mean();
8634
- struct ggml_tensor * inp_cls = build_inp_cls();
8635
8677
8636
8678
// construct input embeddings (token, type, position)
8637
8679
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@@ -8806,28 +8848,6 @@ struct llm_build_context {
8806
8848
cur = inpL;
8807
8849
cb(cur, "result_embd", -1);
8808
8850
8809
- // pooling layer
8810
- switch (pooling_type) {
8811
- case LLAMA_POOLING_TYPE_NONE:
8812
- {
8813
- // nop
8814
- } break;
8815
- case LLAMA_POOLING_TYPE_MEAN:
8816
- {
8817
- cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
8818
- cb(cur, "result_embd_pooled", -1);
8819
- } break;
8820
- case LLAMA_POOLING_TYPE_CLS:
8821
- {
8822
- cur = ggml_get_rows(ctx0, cur, inp_cls);
8823
- cb(cur, "result_embd_pooled", -1);
8824
- } break;
8825
- case LLAMA_POOLING_TYPE_UNSPECIFIED:
8826
- {
8827
- GGML_ASSERT(false && "Invalid pooling type");
8828
- } break;
8829
- }
8830
-
8831
8851
ggml_build_forward_expand(gf, cur);
8832
8852
8833
8853
return gf;
@@ -11912,6 +11932,11 @@ static struct ggml_cgraph * llama_build_graph(
11912
11932
GGML_ASSERT(false);
11913
11933
}
11914
11934
11935
+ // add on pooling layer
11936
+ if (lctx.cparams.embeddings) {
11937
+ result = llm.append_pooling(result);
11938
+ }
11939
+
11915
11940
llm.free();
11916
11941
11917
11942
return result;
@@ -12001,7 +12026,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
12001
12026
// (!a || b) is a logical implication (a -> b)
12002
12027
// !hparams.causal_attn -> !cparams.causal_attn
12003
12028
(hparams.causal_attn || !cparams.causal_attn) &&
12004
- "causal attention with embedding models is not supported"
12029
+ "causal attention is not supported by this model "
12005
12030
);
12006
12031
12007
12032
if (lctx.inp_KQ_mask) {
@@ -12133,6 +12158,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
12133
12158
}
12134
12159
}
12135
12160
12161
+ if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
12162
+ const int64_t n_tokens = batch.n_tokens;
12163
+
12164
+ GGML_ASSERT(lctx.inp_cls);
12165
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
12166
+
12167
+ uint32_t * data = (uint32_t *) lctx.inp_cls->data;
12168
+ memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
12169
+
12170
+ std::vector<int> last_pos(n_tokens, -1);
12171
+ std::vector<int> last_row(n_tokens, -1);
12172
+
12173
+ for (int i = 0; i < n_tokens; ++i) {
12174
+ const llama_seq_id seq_id = batch.seq_id[i][0];
12175
+ const llama_pos pos = batch.pos[i];
12176
+
12177
+ GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
12178
+
12179
+ if (pos >= last_pos[seq_id]) {
12180
+ last_pos[seq_id] = pos;
12181
+ last_row[seq_id] = i;
12182
+ }
12183
+ }
12184
+
12185
+ for (int i = 0; i < n_tokens; ++i) {
12186
+ if (last_row[i] >= 0) {
12187
+ data[i] = last_row[i];
12188
+ }
12189
+ }
12190
+ }
12191
+
12136
12192
if (kv_self.recurrent) {
12137
12193
const int64_t n_kv = kv_self.n;
12138
12194
@@ -12194,8 +12250,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
12194
12250
const auto n_embd = hparams.n_embd;
12195
12251
12196
12252
// TODO: use a per-batch flag for logits presence instead
12197
- const bool has_logits = cparams.causal_attn ;
12198
- const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
12253
+ const bool has_logits = ! cparams.embeddings ;
12254
+ const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
12199
12255
12200
12256
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
12201
12257
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
@@ -12325,11 +12381,13 @@ static int llama_decode_internal(
12325
12381
std::vector<std::vector<llama_seq_id>> seq_id;
12326
12382
12327
12383
// count outputs
12328
- if (batch_all.logits) {
12384
+ if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
12385
+ n_outputs = n_tokens_all;
12386
+ } else if (batch_all.logits) {
12329
12387
for (uint32_t i = 0; i < n_tokens_all; ++i) {
12330
12388
n_outputs += batch_all.logits[i] != 0;
12331
12389
}
12332
- } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) ) {
12390
+ } else if (lctx.logits_all) {
12333
12391
n_outputs = n_tokens_all;
12334
12392
} else {
12335
12393
// keep last output only
@@ -12460,30 +12518,13 @@ static int llama_decode_internal(
12460
12518
// no output
12461
12519
res = nullptr;
12462
12520
embd = nullptr;
12463
- } else if (!hparams.causal_attn) {
12464
- res = nullptr; // do not extract logits for embedding models such as BERT
12465
-
12466
- // token or sequence embeddings
12467
- embd = gf->nodes[gf->n_nodes - 1];
12468
-
12469
- GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
12470
12521
} else if (cparams.embeddings) {
12471
- // the embeddings could be in the second to last tensor, or any of the previous tensors
12472
- int i_embd = gf->n_nodes - 2;
12473
- for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
12474
- i_embd = gf->n_nodes - i;
12475
- if (i_embd < 0) { break; }
12476
- embd = gf->nodes[i_embd];
12477
- }
12478
- GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
12479
-
12480
- // TODO: use a per-batch flag to know when to skip logits while keeping embeddings
12481
- if (!cparams.causal_attn) {
12482
- res = nullptr; // do not extract logits when not needed
12483
- // skip computing logits
12484
- // TODO: is this safe?
12485
- gf->n_nodes = i_embd + 1;
12522
+ res = nullptr; // do not extract logits for embedding case
12523
+ embd = gf->nodes[gf->n_nodes - 1];
12524
+ if (strcmp(embd->name, "result_embd_pooled") != 0) {
12525
+ embd = gf->nodes[gf->n_nodes - 2];
12486
12526
}
12527
+ GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
12487
12528
} else {
12488
12529
embd = nullptr; // do not extract embeddings when not needed
12489
12530
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
@@ -12552,11 +12593,10 @@ static int llama_decode_internal(
12552
12593
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
12553
12594
}
12554
12595
} break;
12555
- case LLAMA_POOLING_TYPE_CLS:
12556
12596
case LLAMA_POOLING_TYPE_MEAN:
12597
+ case LLAMA_POOLING_TYPE_CLS:
12598
+ case LLAMA_POOLING_TYPE_LAST:
12557
12599
{
12558
- GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
12559
-
12560
12600
// extract sequence embeddings
12561
12601
auto & embd_seq_out = lctx.embd_seq;
12562
12602
embd_seq_out.clear();
@@ -18112,6 +18152,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
18112
18152
ctx->abort_callback_data = abort_callback_data;
18113
18153
}
18114
18154
18155
+ void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
18156
+ ctx->cparams.embeddings = embeddings;
18157
+ }
18158
+
18115
18159
void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
18116
18160
ctx->cparams.causal_attn = causal_attn;
18117
18161
}
0 commit comments