
Commit 4863fc5

ggerganov authored and Minh141120 committed
llama : rework embeddings logic (ggml-org#14208)
* llama : rework embeddings logic  ggml-ci
* cont : fix rerank  ggml-ci
* cont : engrish  [no ci]
* cont : fix rerank  ggml-ci
* server : support both embeddings and completions with single model  ggml-ci
* cont : avoid embeddings_org  ggml-ci
1 parent e9bf3df · commit 4863fc5

13 files changed: 117 additions, 45 deletions

common/common.h

Lines changed: 0 additions & 1 deletion
@@ -359,7 +359,6 @@ struct common_params {
     int32_t embd_normalize = 2;        // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out   = "";       // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep   = "\n";     // separator of embeddings
-    std::string cls_sep    = "\t";     // separator of classification sequences

     // server params
     int32_t port = 8080;               // server listens on this network port

include/llama.h

Lines changed: 4 additions & 1 deletion
@@ -258,6 +258,10 @@ extern "C" {
     //   - if embeddings: all tokens are output
     //   - if not:        only the last token is output
     // )
+    // (if set to NULL:
+    //   - if embeddings: all tokens are output
+    //   - if not:        only the last token is output
+    // )
     //
     typedef struct llama_batch {
         int32_t n_tokens;
@@ -968,7 +972,6 @@ extern "C" {
     LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);

     // Set whether the context outputs embeddings or not
-    // TODO: rename to avoid confusion with llama_get_embeddings()
     LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);

     // Set whether to use causal attention or not
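
The new comment on llama_batch.logits documents the default output selection when the caller leaves the field NULL. A minimal caller-side sketch (not part of this commit; it assumes an existing llama_context * ctx created with a pooling type other than LLAMA_POOLING_TYPE_NONE, and an already tokenized prompt) showing how that default interacts with llama_set_embeddings():

#include "llama.h"

#include <vector>

// decode a prompt with batch.logits left as NULL and read back a pooled embedding
static const float * embed_prompt(llama_context * ctx, std::vector<llama_token> & tokens) {
    // llama_batch_get_one() leaves batch.logits == NULL, so the documented default applies:
    // all tokens are output in embeddings mode, only the last token otherwise
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    llama_set_embeddings(ctx, true);          // embeddings mode -> all tokens are output

    if (llama_decode(ctx, batch) != 0) {
        return nullptr;                       // decode failed
    }

    return llama_get_embeddings_seq(ctx, 0);  // pooled embedding of sequence 0
}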

src/llama-batch.cpp

Lines changed: 26 additions & 4 deletions
@@ -637,7 +637,8 @@ llama_batch_allocr::llama_batch_allocr() {
 bool llama_batch_allocr::init(
         const llama_batch & batch_inp,
         const llama_vocab & vocab,
-        const llama_memory_i * memory) {
+        const llama_memory_i * memory,
+        bool embd_all) {
     clear();

     batch = batch_inp;
@@ -716,10 +717,31 @@ bool llama_batch_allocr::init(
     }

     if (!batch.logits) {
-        // by default return the output only for the last token
-        output.resize(batch.n_tokens);
-        output[output.size() - 1] = true;
+        if (embd_all) {
+            // return the output for all tokens
+            output.resize(batch.n_tokens, true);
+        } else {
+            // return the output only for the last token
+            output.resize(batch.n_tokens, false);
+            output[output.size() - 1] = true;
+        }
+
         batch.logits = output.data();
+    } else if (embd_all) {
+        bool warn = false;
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.logits[i] == 0) {
+                warn = true;
+            }
+        }
+
+        if (warn) {
+            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
+
+            output.resize(batch.n_tokens, true);
+            batch.logits = output.data();
+        }
     }

     //
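
For reference, the output-selection rule implemented above can be written in a standalone form. The sketch below is illustrative only (the helper name and std::vector<int8_t> return type are not the library's internals); it mirrors the added branches: NULL logits fall back to a default that depends on embd_all, and explicit logits are overridden with a warning when embeddings require every token.

#include <cstdint>
#include <cstdio>
#include <vector>

// illustrative stand-in for the decision in llama_batch_allocr::init();
// assumes n_tokens > 0
static std::vector<int8_t> select_outputs(const int8_t * logits, int32_t n_tokens, bool embd_all) {
    std::vector<int8_t> output;

    if (logits == nullptr) {
        if (embd_all) {
            output.assign(n_tokens, 1);        // embeddings: every token is an output
        } else {
            output.assign(n_tokens, 0);        // completions: only the last token is an output
            output.back() = 1;
        }
        return output;
    }

    output.assign(logits, logits + n_tokens);  // start from the caller-provided flags

    if (embd_all) {
        for (int32_t i = 0; i < n_tokens; ++i) {
            if (output[i] == 0) {
                // mirrors the LLAMA_LOG_WARN + override added in this commit
                std::fprintf(stderr, "embeddings require all tokens as outputs -> overriding\n");
                output.assign(n_tokens, 1);
                break;
            }
        }
    }

    return output;
}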

src/llama-batch.h

Lines changed: 2 additions & 1 deletion
@@ -92,7 +92,8 @@ class llama_batch_allocr {
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
-            const llama_memory_i * memory);
+            const llama_memory_i * memory,
+            bool embd_all);

     const llama_batch & get_batch() const;

src/llama-context.cpp

Lines changed: 8 additions & 11 deletions
@@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     }

     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -899,7 +899,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }

-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
+
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -916,12 +919,9 @@ int llama_context::decode(const llama_batch & batch_inp) {

     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

-    // this indicates we are doing pooled embedding
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();

-    if (embd_pooled) {
+    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -950,7 +950,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_context_ptr mctx;

     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
             return -2;
         }
@@ -2052,14 +2052,11 @@ void llama_context::opt_epoch_iter(

         n_queued_tokens += n_tokens_all;

-        // this indicates we are doing pooled embedding
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
         embd_seq.clear();

         uint32_t n_outputs_all = n_tokens_all;

-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
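
The net effect in decode() is that the all-token output requirement is now driven by the embeddings flag alone rather than by pooled embeddings. A small standalone comparison (illustrative helper names, not library code), using the llama_pooling_type enum from llama.h:

#include "llama.h"

// before this commit: only pooled embeddings forced every token to be an output
static bool embd_all_old(bool embeddings, enum llama_pooling_type pooling) {
    return embeddings && pooling != LLAMA_POOLING_TYPE_NONE;
}

// after this commit: any embeddings run outputs all tokens,
// including LLAMA_POOLING_TYPE_NONE (per-token embeddings)
static bool embd_all_new(bool embeddings, enum llama_pooling_type pooling) {
    (void) pooling; // the pooling type no longer matters for this decision
    return embeddings;
}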

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 2 additions & 2 deletions
@@ -95,8 +95,8 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }

-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+    GGML_UNUSED(embd_all);

     // first try simple split
     do {

src/llama-kv-cache-unified-iswa.h

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
     llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;

     llama_memory_context_ptr init_full() override;

src/llama-kv-cache-unified.cpp

Lines changed: 2 additions & 2 deletions
@@ -339,8 +339,8 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_context_ptr llama_kv_cache_unified::init_batch(
         llama_batch_allocr & balloc,
         uint32_t n_ubatch,
-        bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+        bool embd_all) {
+    GGML_UNUSED(embd_all);

     do {
         auto sbatch = llama_sbatch(batch, hparams.n_embd, true);

src/llama-kv-cache-unified.h

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ class llama_kv_cache_unified : public llama_memory_i {
     llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;

     llama_memory_context_ptr init_full() override;

src/llama-memory-recurrent.cpp

Lines changed: 7 additions & 9 deletions
@@ -362,21 +362,19 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }

-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
-
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false);

     std::vector<llama_ubatch> ubatches;
     while (true) {
         llama_ubatch ubatch;

-        if (embd_all) {
-            // if all tokens are output, split by sequence
-            ubatch = balloc.split_seq(n_ubatch);
-        } else {
-            ubatch = balloc.split_equal(n_ubatch);
-        }
+        if (embd_all) {
+            // if all tokens are output, split by sequence
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }

         if (ubatch.n_tokens == 0) {
             break;
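
The recurrent-memory path keeps the same splitting rule under the new flag name: when every token is an output, micro-batches are split per sequence so each sequence stays intact for pooling; otherwise tokens are cut into fixed-size chunks. A greatly simplified sketch of that choice (not the library's actual split_seq/split_equal implementations; the function name and container types are hypothetical, and n_ubatch > 0 is assumed):

#include <algorithm>
#include <cstdint>
#include <vector>

// tokens_per_seq[s] holds the token ids of sequence s;
// returns micro-batches, each a flat list of token ids
static std::vector<std::vector<int32_t>> split_ubatches(
        const std::vector<std::vector<int32_t>> & tokens_per_seq,
        size_t n_ubatch,
        bool embd_all) {
    std::vector<std::vector<int32_t>> ubatches;

    if (embd_all) {
        // all tokens are output: split by sequence, never mixing sequences in one micro-batch
        for (const auto & seq : tokens_per_seq) {
            ubatches.push_back(seq);
        }
    } else {
        // otherwise: flatten the sequences and cut into chunks of at most n_ubatch tokens
        std::vector<int32_t> flat;
        for (const auto & seq : tokens_per_seq) {
            flat.insert(flat.end(), seq.begin(), seq.end());
        }
        for (size_t i = 0; i < flat.size(); i += n_ubatch) {
            const size_t end = std::min(flat.size(), i + n_ubatch);
            ubatches.emplace_back(flat.begin() + i, flat.begin() + end);
        }
    }

    return ubatches;
}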
