Skip to content

Commit 79c137f

Browse files
authored
examples : allow extracting embeddings from decoder contexts (#13797)
ggml-ci
1 parent 2222931 commit 79c137f

File tree

4 files changed

+10
-16
lines changed

4 files changed

+10
-16
lines changed

examples/embedding/embedding.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
4141

4242
// run model
4343
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
44-
if (llama_encode(ctx, batch) < 0) {
45-
LOG_ERR("%s : failed to encode\n", __func__);
44+
if (llama_decode(ctx, batch) < 0) {
45+
LOG_ERR("%s : failed to process\n", __func__);
4646
}
4747

4848
for (int i = 0; i < batch.n_tokens; i++) {

examples/retrieval/retrieval.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
8181
}
8282
}
8383

84-
static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
84+
static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
8585
// clear previous kv_cache values (irrelevant for embeddings)
8686
llama_kv_self_clear(ctx);
8787

8888
// run model
8989
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
90-
if (llama_encode(ctx, batch) < 0) {
91-
LOG_ERR("%s : failed to encode\n", __func__);
90+
if (llama_decode(ctx, batch) < 0) {
91+
LOG_ERR("%s : failed to process\n", __func__);
9292
}
9393

9494
for (int i = 0; i < batch.n_tokens; i++) {
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
233233
// encode if at capacity
234234
if (batch.n_tokens + n_toks > n_batch) {
235235
float * out = emb + p * n_embd;
236-
batch_encode(ctx, batch, out, s, n_embd);
236+
batch_process(ctx, batch, out, s, n_embd);
237237
common_batch_clear(batch);
238238
p += s;
239239
s = 0;
@@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
246246

247247
// final batch
248248
float * out = emb + p * n_embd;
249-
batch_encode(ctx, batch, out, s, n_embd);
249+
batch_process(ctx, batch, out, s, n_embd);
250250

251251
// save embeddings to chunks
252252
for (int i = 0; i < n_chunks; i++) {
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
267267
batch_add_seq(query_batch, query_tokens, 0);
268268

269269
std::vector<float> query_emb(n_embd, 0);
270-
batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
270+
batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
271271

272272
common_batch_clear(query_batch);
273273

src/llama-context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,7 @@ int llama_context::encode(llama_batch & inp_batch) {
852852

853853
int llama_context::decode(llama_batch & inp_batch) {
854854
if (!memory) {
855-
LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
855+
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
856856
return encode(inp_batch);
857857
}
858858

tools/server/server.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3394,13 +3394,7 @@ struct server_context {
33943394
batch.logits + i,
33953395
};
33963396

3397-
int ret = 0;
3398-
3399-
if (do_encode) {
3400-
ret = llama_encode(ctx, batch_view);
3401-
} else {
3402-
ret = llama_decode(ctx, batch_view);
3403-
}
3397+
const int ret = llama_decode(ctx, batch_view);
34043398

34053399
metrics.on_decoded(slots);
34063400

0 commit comments

Comments
 (0)