
Commit 62b2b82

server : output embeddings for all tokens when pooling = none

ggml-ci

Parent: d58f8a1

3 files changed: +32 / −5 lines

examples/server/server.cpp

Lines changed: 15 additions & 4 deletions
@@ -727,13 +727,21 @@ struct server_task_result_cmpl_partial : server_task_result {
 
 struct server_task_result_embd : server_task_result {
     int index = 0;
-    std::vector<float> embedding;
+    std::vector<std::vector<float>> embedding;
 
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
+        if (embedding.size() == 1){
+            // to be OAI compatible
+            return json {
+                {"index",     index},
+                {"embedding", embedding[0]},
+            };
+        }
+
         return json {
             {"index",     index},
             {"embedding", embedding},

@@ -2030,12 +2038,12 @@ struct server_context {
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
-                res->embedding = std::vector<float>(n_embd, 0.0f);
+                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
             }
 
             common_embd_normalize(embd, embd_res.data(), n_embd);
-            res->embedding = embd_res;
+            res->embedding.push_back(embd_res);
         }
 
         SLT_DBG(slot, "%s", "sending embeddings\n");

@@ -2648,7 +2656,10 @@ struct server_context {
 
         // add prompt tokens for processing in the current batch
         while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-            common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
+            // without pooling, we want to output the embeddings for all the tokens in the batch
+            const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+            common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
 
             if (slot.params.cache_prompt) {
                 slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
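
With this change, the shape of the `embedding` field in the /embeddings response depends on the pooling mode: with pooling enabled it stays a single flat vector (OAI-compatible), while with `--pooling none` it becomes a list of per-token vectors. A minimal client sketch; the host, port, and input string are assumptions, not part of the commit:

    # Sketch of a client for the modified /embeddings endpoint.
    # Assumes a llama.cpp server running an embedding model with
    # `--pooling none` on localhost:8080 (both are assumptions).
    import requests

    res = requests.post(
        "http://localhost:8080/embeddings",
        json={"input": "hello hello hello"},
    )
    res.raise_for_status()
    emb = res.json()["data"][0]["embedding"]

    # With pooling = none, `embedding` is a list of per-token vectors;
    # with any other pooling mode it is a single flat vector of floats.
    if emb and isinstance(emb[0], list):
        print(f"{len(emb)} token embeddings of dim {len(emb[0])}")
    else:
        print(f"one pooled embedding of dim {len(emb)}")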

examples/server/tests/unit/test_embedding.py

Lines changed: 12 additions & 0 deletions
@@ -45,6 +45,18 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1
 
 
+def test_embedding_pooling_none():
+    server = ServerPreset.bert_bge_small(pooling = 'none')
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "hello hello hello",
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 1
+    assert 'embedding' in res.body['data'][0]
+    assert len(res.body['data'][0]['embedding']) == 3
+
+
 def test_embedding_openai_library_single():
     global server
     server.start()
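
The three input words are expected to yield three per-token vectors. A hypothetical follow-up assertion (not part of the commit) that could be appended to the test body to check that each entry is itself a vector:

    # Hypothetical extension of test_embedding_pooling_none:
    # each of the 3 entries should itself be a list of floats.
    for tok_emb in res.body['data'][0]['embedding']:
        assert isinstance(tok_emb, list)
        assert len(tok_emb) > 1  # one float per embedding dimension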

examples/server/tests/utils.py

Lines changed: 5 additions & 1 deletion
@@ -65,6 +65,7 @@ class ServerProcess:
     server_reranking: bool | None = False
     server_metrics: bool | None = False
     server_slots: bool | None = False
+    pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
     response_format: str | None = None

@@ -132,6 +133,8 @@ def start(self, timeout_seconds: int = 10) -> None:
             server_args.append("--metrics")
         if self.server_slots:
             server_args.append("--slots")
+        if self.pooling:
+            server_args.extend(["--pooling", self.pooling])
         if self.model_alias:
             server_args.extend(["--alias", self.model_alias])
         if self.n_ctx:

@@ -272,7 +275,7 @@ def tinyllama2() -> ServerProcess:
         return server
 
     @staticmethod
-    def bert_bge_small() -> ServerProcess:
+    def bert_bge_small(pooling = 'last') -> ServerProcess:
         server = ServerProcess()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"

@@ -283,6 +286,7 @@ def bert_bge_small() -> ServerProcess:
         server.n_slots = 2
         server.seed = 42
         server.server_embeddings = True
+        server.pooling = pooling
         return server
 
     @staticmethod
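
The new field simply plumbs through to the server's `--pooling` CLI flag, so any test can opt into a different mode via the preset; a short sketch (the choice of 'mean' here is an assumption, not part of the commit):

    # Hypothetical: reuse the preset with a different pooling mode.
    server = ServerPreset.bert_bge_small(pooling='mean')
    server.start()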
