
Commit e27f8d7

ggerganov authored and qnixsynapse committed
llama : rework embeddings logic (ggml-org#14208)

* llama : rework embeddings logic
  ggml-ci
* cont : fix rerank
  ggml-ci
* cont : engrish [no ci]
* cont : fix rerank
  ggml-ci
* server : support both embeddings and completions with single model
  ggml-ci
* cont : avoid embeddings_org
  ggml-ci
1 parent 1805319 commit e27f8d7

13 files changed: +99 −65 lines
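
With this rework, a context always reserves and returns logits for its output tokens and additionally returns embeddings when they are enabled, which is what lets the server handle both completions and embeddings with a single loaded model. A minimal sketch of toggling between the two modes through the public API (not part of the commit; it assumes `ctx` and `batch` were prepared elsewhere and only uses existing llama.h calls):

#include "llama.h"

// Sketch: one context serving both completions and embeddings.
// `ctx` and `batch` are assumed to be set up elsewhere (model load, tokenization, output marks).
static void serve_both(llama_context * ctx, llama_batch batch) {
    // completions: logits for the marked output tokens
    llama_set_embeddings(ctx, false);
    if (llama_decode(ctx, batch) == 0) {
        const float * logits = llama_get_logits_ith(ctx, -1); // last output token
        (void) logits;
    }

    // embeddings: after this commit the logits remain available as well
    llama_set_embeddings(ctx, true);
    if (llama_decode(ctx, batch) == 0) {
        const float * embd = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE
                ? llama_get_embeddings_seq(ctx, 0)   // pooled embedding for sequence 0
                : llama_get_embeddings_ith(ctx, -1); // token embedding for the last output
        (void) embd;
    }
}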

common/common.h
Lines changed: 0 additions & 7 deletions

@@ -8,7 +8,6 @@
 #include <string>
 #include <string_view>
 #include <vector>
-#include <map>
 #include <sstream>
 
 #ifdef _WIN32
@@ -200,9 +199,6 @@ struct common_params_speculative {
     float p_split = 0.1f;  // speculative decoding split probability
     float p_min   = 0.75f; // minimum speculative decoding probability (greedy)
 
-    ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
-    ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
-
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
 
@@ -359,7 +355,6 @@ struct common_params {
     int32_t embd_normalize = 2;  // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = "";   // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embeddings
-    std::string cls_sep  = "\t"; // separator of classification sequences
 
     // server params
     int32_t port = 8080; // server listens on this network port
@@ -382,8 +377,6 @@ struct common_params {
     std::string ssl_file_key  = ""; // NOLINT
     std::string ssl_file_cert = ""; // NOLINT
 
-    std::map<std::string, std::string> default_template_kwargs;
-
     // "advanced" endpoints are disabled by default for better security
     bool webui          = true;
    bool endpoint_slots = false;

include/llama.h
Lines changed: 3 additions & 9 deletions

@@ -390,7 +390,6 @@ extern "C" {
         void * imatrix;      // pointer to importance matrix data
         void * kv_overrides; // pointer to vector containing overrides
         void * tensor_types; // pointer to vector containing tensor types
-        void * prune_layers; // pointer to vector containing layer indices to prune
     } llama_model_quantize_params;
 
     typedef struct llama_logit_bias {
@@ -944,14 +943,12 @@ extern "C" {
     // Requires the context to have a memory.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    // Upon fatal-error or abort, the ubatches that managed to be been processed will remain in the memory state of the context
-    // To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
-    // Upon other return values, the memory state is restored to the state before this call
+    // Upon non-zero return values, the memory state is restored to the state before this call
     //    0 - success
     //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    //    2 - aborted (processed ubatches will remain in the context's memory)
+    //    2 - aborted
     //   -1 - invalid input batch
-    // < -1 - fatal error (processed ubatches will remain in the context's memory)
+    // < -1 - error
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
             struct llama_batch batch);
@@ -968,7 +965,6 @@ extern "C" {
     LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
 
     // Set whether the context outputs embeddings or not
-    // TODO: rename to avoid confusion with llama_get_embeddings()
     LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
 
     // Set whether to use causal attention or not
@@ -1047,7 +1043,6 @@ extern "C" {
 
     LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
     LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
-    LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);
 
     LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
     LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
@@ -1091,7 +1086,6 @@ extern "C" {
     /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
     /// @return Returns the number of tokens on success, no more than n_tokens_max
     /// @return Returns a negative number on failure - the number of tokens that would have been returned
-    /// @return Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
     /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
     /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
     ///                      as plaintext. Does not insert a leading space.
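
The header change above simplifies the llama_decode() contract: on any non-zero return the memory (KV cache) state is rolled back to what it was before the call, so callers no longer need to probe llama_memory_seq_pos_min()/llama_memory_seq_pos_max() after a failure. A hedged sketch of return-code handling that follows the updated comment (the suggested reactions are illustrative, not prescribed by the API):

#include "llama.h"

// Handle the llama_decode() return codes documented above.
static bool decode_checked(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);
    switch (ret) {
        case  0: return true;  // success
        case  1: return false; // no KV slot: reduce the batch size or increase the context, then retry
        case  2: return false; // aborted (memory state restored)
        case -1: return false; // invalid input batch
        default: return false; // ret < -1: error (memory state restored)
    }
}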

src/llama-batch.cpp
Lines changed: 26 additions & 4 deletions

@@ -299,7 +299,8 @@ llama_batch_allocr::llama_batch_allocr() {
 bool llama_batch_allocr::init(
         const llama_batch & batch_inp,
         const llama_vocab & vocab,
-        const llama_memory_i * memory) {
+        const llama_memory_i * memory,
+        bool embd_all) {
     clear();
 
     batch = batch_inp;
@@ -378,10 +379,31 @@ bool llama_batch_allocr::init(
     }
 
     if (!batch.logits) {
-        // by default return the output only for the last token
-        output.resize(batch.n_tokens);
-        output[output.size() - 1] = true;
+        if (embd_all) {
+            // return the output for all tokens
+            output.resize(batch.n_tokens, true);
+        } else {
+            // return the output only for the last token
+            output.resize(batch.n_tokens, false);
+            output[output.size() - 1] = true;
+        }
+
         batch.logits = output.data();
+    } else if (embd_all) {
+        bool warn = false;
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.logits[i] == 0) {
+                warn = true;
+            }
+        }
+
+        if (warn) {
+            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
+
+            output.resize(batch.n_tokens, true);
+            batch.logits = output.data();
+        }
     }
 
     //
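
The new embd_all parameter changes how llama_batch_allocr::init() fills in missing output marks: if the caller leaves batch.logits unset, either only the last token (completions) or every token (embeddings) becomes an output, and if the caller did set batch.logits while embeddings are required, any unmarked token is overridden with a warning. A sketch of the caller side, assuming the common_batch_add() helper from common.h and a pre-tokenized prompt (both assumptions, not shown in this commit):

#include <vector>
#include "llama.h"
#include "common.h" // for common_batch_add() (assumed helper from llama.cpp's common library)

// Sketch: mark only the last token as output. With cparams.embeddings enabled,
// llama_batch_allocr::init() would override unmarked tokens to "output" and log a warning.
static void decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_init(/*n_tokens*/ (int32_t) tokens.size(), /*embd*/ 0, /*n_seq_max*/ 1);

    for (size_t i = 0; i < tokens.size(); ++i) {
        const bool is_last = (i == tokens.size() - 1);
        common_batch_add(batch, tokens[i], /*pos*/ (llama_pos) i, /*seq_ids*/ { 0 }, /*output*/ is_last);
    }

    llama_decode(ctx, batch);
    llama_batch_free(batch);
}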

src/llama-batch.h
Lines changed: 2 additions & 1 deletion

@@ -88,7 +88,8 @@ class llama_batch_allocr {
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
-            const llama_memory_i * memory);
+            const llama_memory_i * memory,
+            bool embd_all);
 
     const llama_batch & get_batch() const;

src/llama-context.cpp
Lines changed: 11 additions & 15 deletions

@@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     }
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -894,7 +894,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
+
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -911,12 +914,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // this indicates we are doing pooled embedding
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_pooled) {
+    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +945,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
             return -2;
         }
@@ -1058,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
     //}
 
-    auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
+    auto * t_logits = res->get_logits();
     auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
     if (t_embd && res->get_embd_pooled()) {
@@ -1222,9 +1222,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_vocab = vocab.n_tokens();
     const auto n_embd  = hparams.n_embd;
 
-    // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
-    bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd   = cparams.embeddings;
 
     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
@@ -2044,14 +2043,11 @@ void llama_context::opt_epoch_iter(
 
         n_queued_tokens += n_tokens_all;
 
-        // this indicates we are doing pooled embedding
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
         embd_seq.clear();
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
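
After this change output_reserve() always sizes the output buffer for logits and adds room for embeddings only when cparams.embeddings is set, instead of treating the two as mutually exclusive. A rough sketch of the implied per-output sizing (an illustrative, hypothetical helper; the real code also special-cases the enc-dec T5 path shown above):

#include <cstddef>
#include <cstdint>

// Per-output element count implied by output_reserve() after this change.
static size_t floats_per_output(int32_t n_vocab, int32_t n_embd, bool embeddings) {
    const bool has_logits = true;       // previously: !cparams.embeddings
    const bool has_embd   = embeddings; // previously: embeddings && pooling_type == NONE
    return (has_logits ? (size_t) n_vocab : 0) + (has_embd ? (size_t) n_embd : 0);
}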

src/llama-kv-cache-unified-iswa.cpp
Lines changed: 2 additions & 2 deletions

@@ -95,8 +95,8 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+    GGML_UNUSED(embd_all);
 
     // first try simple split
     do {

src/llama-kv-cache-unified-iswa.h
Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;

src/llama-kv-cache-unified.cpp
Lines changed: 2 additions & 2 deletions

@@ -310,8 +310,8 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+            bool embd_all) {
+    GGML_UNUSED(embd_all);
 
     do {
         auto sbatch = llama_sbatch(batch, hparams.n_embd, true);

src/llama-kv-cache-unified.h
Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ class llama_kv_cache_unified : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;

src/llama-memory-recurrent.cpp
Lines changed: 3 additions & 5 deletions

@@ -359,18 +359,16 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
-
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
 
     while (sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
 
-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
+        if (embd_all) {
+            // if all tokens are output, split by sequence
             ubatch = sbatch.split_seq(n_ubatch);
         } else {
             ubatch = sbatch.split_equal(n_ubatch);
