Skip to content

Commit 0c1df14

Browse files
authored
server : fix pooled embedding output (#14645)
1 parent: b3ad3a0 · commit: 0c1df14

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

tools/server/server.cpp

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2581,25 +2581,27 @@ struct server_context {
25812581
continue;
25822582
}
25832583

2584-
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
2585-
if (embd == NULL) {
2584+
const float * embd = nullptr;
2585+
if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
25862586
embd = llama_get_embeddings_ith(ctx, i);
2587+
} else {
2588+
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
25872589
}
25882590

2589-
if (embd == NULL) {
2591+
if (embd == nullptr) {
25902592
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
25912593

25922594
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
25932595
continue;
25942596
}
25952597

25962598
// normalize only when there is pooling
2597-
// TODO: configurable
25982599
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
25992600
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
26002601
res->embedding.push_back(embd_res);
2602+
break;
26012603
} else {
2602-
res->embedding.push_back({ embd, embd + n_embd });
2604+
res->embedding.emplace_back(embd, embd + n_embd);
26032605
}
26042606
}
26052607

0 commit comments

Comments
 (0)