
llama : support qwen3 rerank and embeddings #14029


Open · wants to merge 10 commits into base: master
Changes from 8 commits
6 changes: 5 additions & 1 deletion common/common.cpp
@@ -907,8 +907,12 @@ struct common_init_result common_init_from_params(common_params & params) {

bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
bool has_rerank_prompt = llama_model_chat_template(model, "rerank_prefix") != NULL ||
llama_model_chat_template(model, "rerank_suffix") != NULL;

if (!has_eos && !has_sep) {
if (has_rerank_prompt) {
// OK, do nothing
} else if (!has_eos && !has_sep) {
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
ok = false;
} else if (!has_eos) {
70 changes: 70 additions & 0 deletions convert_hf_to_gguf.py
@@ -809,6 +809,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
# ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
res = "qwen2"

if res is None:
logger.warning("\n")
@@ -3061,6 +3064,73 @@ def prepare_tensors(self):
class Qwen3Model(Qwen2Model):
model_arch = gguf.MODEL_ARCH.QWEN3

# extra logic for rerank models
token_false_id: int | None = None
token_true_id: int | None = None
sep_token_id: int = 0
is_tied_embeddings: bool = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# a bit hacky, but currently the only way to detect if this is a rerank model
# ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
readme_path = self.dir_model / "README.md"
readme_text = ""
if readme_path.exists():
with readme_path.open("r", encoding="utf-8") as f:
readme_text = f.read()
if "# Qwen3-Reranker" in readme_text:
self._find_rerank_config()

def _find_rerank_config(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
self.token_false_id = tokenizer.convert_tokens_to_ids("no")
self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
logger.info(f"gguf: token_false_id = {self.token_false_id}, token_true_id = {self.token_true_id}")
logger.info(f"gguf: sep_token_id = {self.sep_token_id}")
logger.info(f"gguf: is_tied_embeddings = {self.is_tied_embeddings}")

def set_gguf_parameters(self):
super().set_gguf_parameters()
is_rerank = self.token_false_id is not None and self.token_true_id is not None
if is_rerank:
self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
self.gguf_writer.add_classifier_output_labels(["yes", "no"])
self.gguf_writer.add_chat_template([{
"name": "rerank_prefix",
"template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n",
}, {
"name": "rerank_suffix",
"template": "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
}])

def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
# extract "yes" and "no" tokens from the output lm_head tensor
assert self.token_false_id is not None and self.token_true_id is not None
false_row = data_torch[self.token_false_id]
true_row = data_torch[self.token_true_id]
return torch.stack([true_row, false_row], dim=0)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
is_rerank = self.token_false_id is not None and self.token_true_id is not None

if not name.startswith("model."):
name = "model." + name

if is_rerank:
if self.is_tied_embeddings and "embed_tokens" in name:
return [
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight", self._get_cls_out_tensor(data_torch)),
(self.map_tensor_name(name), data_torch),
]
if not self.is_tied_embeddings and "lm_head" in name:
# this is the lm_head tensor, we need to extract the cls_out tensor
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight", self._get_cls_out_tensor(data_torch))]

return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen3MoeForCausalLM")
class Qwen3MoeModel(Qwen2MoeModel):
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -137,6 +137,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
]


3 changes: 2 additions & 1 deletion src/llama-arch.cpp
@@ -200,7 +200,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template." }, // FIXME: cannot add %s because it will be replaced by arch name
Collaborator:

Actually, you can use LLM_KV_TOKENIZER_CHAT_TEMPLATE with suffix:

src/llama-arch.cpp, lines 1722 to 1725 @ c02f53d:

if (suffix != nullptr) {
name += ".";
name += suffix;
}

Collaborator (Author):

I'm doing this, but it doesn't work; maybe it's buggy somewhere else:

    const auto key = name
        ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)
        : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);

Collaborator:

I was looking in the wrong place, this is where it's broken:

src/llama-arch.cpp, lines 1709 to 1712 @ e83ba3e:

std::string LLM_KV::operator()(llm_kv kv) const {
return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
}
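
A minimal standalone sketch (plain snprintf rather than llama.cpp's ::format helper, with illustrative values) of why that call cannot produce a suffixed key: the LLM_KV_NAMES pattern contains a single %s, which is consumed by the arch name, so the suffix argument is silently ignored. This is the behaviour the FIXME in this diff works around by dropping the %s.

    #include <cstdio>

    int main() {
        const char * pattern = "tokenizer.chat_template.%s"; // the LLM_KV_NAMES entry before this PR
        const char * arch    = "qwen3";                      // stands in for LLM_ARCH_NAMES.at(arch)
        const char * suffix  = "rerank_prefix";              // the requested template name

        char key[128];
        // mirrors ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
        std::snprintf(key, sizeof(key), pattern, arch, suffix);

        // prints "tokenizer.chat_template.qwen3" instead of "tokenizer.chat_template.rerank_prefix"
        std::printf("%s\n", key);
        return 0;
    }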

Collaborator:

Fixed in #14050

{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
@@ -629,6 +629,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_CLS_OUT, "cls.output" }, // rerank
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
63 changes: 37 additions & 26 deletions src/llama-graph.cpp
@@ -167,9 +167,15 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
}

void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
if (!cparams.embeddings) {
return;
}

const bool is_last_tok = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
arch == LLM_ARCH_QWEN3; // qwen3 reranking & embedding models use last token

if (is_last_tok) {
// set output to the last token of each sequence
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
@@ -180,23 +186,33 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
uint32_t * data = (uint32_t *) cls->data;
memset(cls->data, 0, n_tokens * ggml_element_size(cls));

std::vector<int> last_pos(n_tokens, -1);
std::vector<int> last_row(n_tokens, -1);

for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];

// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");

for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];

if (pos == 0) {
data[seq_id] = s*n_seq_tokens + i;
if (pos >= last_pos[seq_id]) {
last_pos[seq_id] = pos;
last_row[seq_id] = s*n_seq_tokens + i;
}
}
}
}

if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
for (int i = 0; i < n_tokens; ++i) {
if (last_row[i] >= 0) {
data[i] = last_row[i];
}
}

} else {
// set output to first token of each sequence
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
@@ -207,30 +223,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
uint32_t * data = (uint32_t *) cls->data;
memset(cls->data, 0, n_tokens * ggml_element_size(cls));

std::vector<int> last_pos(n_tokens, -1);
std::vector<int> last_row(n_tokens, -1);

for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];

// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");

for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];

if (pos >= last_pos[seq_id]) {
last_pos[seq_id] = pos;
last_row[seq_id] = s*n_seq_tokens + i;
if (pos == 0) {
data[seq_id] = s*n_seq_tokens + i;
}
}
}

for (int i = 0; i < n_tokens; ++i) {
if (last_row[i] >= 0) {
data[i] = last_row[i];
}
}
}
}
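
As an illustration of what the reworked set_input produces, here is a small standalone sketch (hypothetical two-sequence ubatch, plain std::vector instead of the ggml tensor) contrasting the last-token selection now used for LAST pooling and Qwen3 against the original first-token selection used for CLS/RANK:

    #include <cstdio>
    #include <vector>

    int main() {
        // flattened batch: row i belongs to sequence seq_id[i] and carries position pos[i]
        const std::vector<int> seq_id = {0, 0, 0, 1, 1};
        const std::vector<int> pos    = {0, 1, 2, 0, 1};

        std::vector<int> cls_last(2, 0), cls_first(2, 0), last_pos(2, -1);

        for (size_t i = 0; i < seq_id.size(); ++i) {
            const int s = seq_id[i];
            // last-token selection (LAST pooling; per this PR also Qwen3 rerank/embeddings)
            if (pos[i] >= last_pos[s]) { last_pos[s] = pos[i]; cls_last[s] = (int) i; }
            // first-token selection (CLS/RANK, as before)
            if (pos[i] == 0) { cls_first[s] = (int) i; }
        }

        std::printf("last : cls[0]=%d cls[1]=%d\n", cls_last[0],  cls_last[1]);  // rows 2 and 4
        std::printf("first: cls[0]=%d cls[1]=%d\n", cls_first[0], cls_first[1]); // rows 0 and 3
        return 0;
    }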

@@ -943,7 +949,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
}

ggml_tensor * llm_graph_context::build_inp_cls() const {
auto inp = std::make_unique<llm_graph_input_cls>(cparams);
auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);

auto & cur = inp->cls;

@@ -1577,10 +1583,15 @@ void llm_graph_context::build_pooling(
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
}
} else if (cls_out) {
// Single layer classification head (direct projection)
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
GGML_ASSERT(cls_out_b != nullptr);
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
if (arch == LLM_ARCH_QWEN3) {
cur = ggml_mul_mat(ctx0, cls_out, inp);
cur = ggml_soft_max(ctx0, cur); // qwen3 uses softmax on the output
Collaborator (Author):

@ggerganov I think there is a bug with build_inp_cls(). It's supposed to contain only the indices of the output tokens (the last token), but in this case it actually contains all tokens. This makes the output score incorrect at the moment, as it returns the score for the first token. WDYT?

Member:

I think you can make a quick fix for now like this, similar to build_bert():

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index afef84870..8b11197df 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7043,7 +7043,7 @@ struct llm_build_qwen3 : public llm_graph_context {
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
-            if (il == n_layer - 1) {
+            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
                 // skip computing output for unused tokens
                 ggml_tensor * inp_out_ids = build_inp_out_ids();
                 cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);

Collaborator (Author):

No, that still doesn't work as I expected.

For example, if my sequence has only one output token, then I expect the inp tensor here to have shape [n_embd, 1], but in reality, it has shape [n_embd, n_tokens]

Maybe I misunderstood something here?

Collaborator (Author):

Hmm, OK, I think I got it. The main problem is that Qwen's rerank model uses causal attention; it's simply a normal next-token prediction model which outputs either a "yes" or a "no" token.

I think the assumption in llama.cpp is that CLS and RANK are non-causal, hence only the first token is marked as output.

Not sure what the best way to support this is, though.

Collaborator (Author) @ngxson, Jun 6, 2025:

OK, I found a hack around this: for Qwen3, I force the position to last (only the position, not the pooling) in 030dc3b.

We should probably separate the notions of "pooling" and "output position" in the future.

Member:

> I think the assumption in llama.cpp is that CLS and RANK are non-causal, hence only the first token is marked as output

The idea is that the llm_build_ functions will compute the embeddings for all tokens in the batch. The notion of "output ids" is purely an optimization trick to avoid unnecessary computation in the last layer, and when doing any kind of pooling it should generally be disabled.

For Qwen3 rerank, what you seem to need is to pool using last and apply the classification head on the result - the latter is missing, so it has to be added. We just haven't encountered models with pooling last and a classification head at the same time.

And it seems we should remove LLAMA_POOLING_TYPE_RANK - it's a bit redundant. Instead CLS and LAST should do the same thing - i.e. apply a classification head if there is one.

Collaborator (Author):

Hmm, OK, I got it. The problem is that I don't have much time for the rest of the day. Do you think we can clean this up in a follow-up PR?

> And it seems we should remove LLAMA_POOLING_TYPE_RANK - it's a bit redundant. Instead CLS and LAST should do the same thing - i.e. apply a classification head if there is one.

I think having the notion of LLAMA_TASK_* would be useful. For example, pooling CLS can be used for task types CLS and RANK. This can also be useful to block certain endpoints: a rerank model should only support /rerank and not /embeddings or /completion.

Member:

I think it's better to take the time and make it right; no need to merge it now.

} else {
// Single layer classification head (direct projection)
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
GGML_ASSERT(cls_out_b != nullptr);
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
}
} else {
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
}
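
For context on the new Qwen3 branch above, here is a standalone sketch of the score computation it implements: cls_out holds the "yes" and "no" rows extracted by the converter (stacked as [yes, no] in _get_cls_out_tensor), so cls_out times the pooled hidden state gives two logits and the softmax turns them into probabilities. The assumption here is that the reported relevance score is the probability of "yes" (row 0); the logit values are made up.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
        // dot products of the pooled hidden state with the "yes" / "no" lm_head rows
        const float logit_yes = 3.1f;
        const float logit_no  = -1.4f;

        const float m     = std::max(logit_yes, logit_no);   // subtract max for numerical stability
        const float e_yes = std::exp(logit_yes - m);
        const float e_no  = std::exp(logit_no  - m);
        const float score = e_yes / (e_yes + e_no);           // softmax probability of "yes"

        std::printf("rerank score = %.4f\n", score);          // ~0.989 for these logits
        return 0;
    }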
3 changes: 2 additions & 1 deletion src/llama-graph.h
@@ -177,13 +177,14 @@ class llm_graph_input_mean : public llm_graph_input_i {

class llm_graph_input_cls : public llm_graph_input_i {
public:
llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : arch(arch), cparams(cparams) {}
virtual ~llm_graph_input_cls() = default;

void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * cls; // I32 [n_batch]

const llm_arch arch;
const llama_cparams & cparams;
};

10 changes: 8 additions & 2 deletions src/llama-model.cpp
@@ -825,6 +825,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_QWEN3:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);

switch (hparams.n_layer) {
case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break;
case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break;
@@ -2468,6 +2470,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output rerank
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
@@ -7057,7 +7062,7 @@ struct llm_build_qwen3 : public llm_graph_context {
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}

if (il == n_layer - 1) {
if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -13788,7 +13793,8 @@ uint64_t llama_model_size(const llama_model * model) {
}

const char * llama_model_chat_template(const llama_model * model, const char * name) {
const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
const auto key = name
? LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + std::string(name)
Comment on lines +13796 to +13797
Collaborator:

Suggested change:
-    const auto key = name
-        ? LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + std::string(name)
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)

I wonder how long this has been broken?

Collaborator:

Ah, it seems it has never worked; broken since it was introduced in #11016.

Collaborator (Author):

It's not used by any of the examples, so we don't know if it ever worked in the first place (probably used in downstream projects, but idk).

Collaborator:

I'll make a PR.

: LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
const auto & it = model->gguf_kv.find(key);
if (it == model->gguf_kv.end()) {
2 changes: 1 addition & 1 deletion tools/server/server.cpp
@@ -4715,7 +4715,7 @@ int main(int argc, char ** argv) {
auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
tasks.reserve(tokenized_docs.size());
for (size_t i = 0; i < tokenized_docs.size(); i++) {
auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
auto tmp = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
49 changes: 35 additions & 14 deletions tools/server/utils.hpp
@@ -260,23 +260,44 @@ static size_t validate_utf8(const std::string& text) {
// template utils
//

// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
// format rerank task:
// - using SEP token: [BOS]query[EOS][SEP]doc[EOS]
// - using prompt: <rerank_prefix>query<rerank_suffix>doc
static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_tokens result;

// Get EOS token - use SEP token as fallback if EOS is not available
llama_token eos_token = llama_vocab_eos(vocab);
if (eos_token == LLAMA_TOKEN_NULL) {
eos_token = llama_vocab_sep(vocab);
}
if (llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL) {
// Get EOS token - use SEP token as fallback if EOS is not available
llama_token eos_token = llama_vocab_eos(vocab);
if (eos_token == LLAMA_TOKEN_NULL) {
eos_token = llama_vocab_sep(vocab);
}

result.reserve(doc.size() + query.size() + 4);
result.push_back(llama_vocab_bos(vocab));
result.insert(result.end(), query.begin(), query.end());
result.push_back(eos_token);
result.push_back(llama_vocab_sep(vocab));
result.insert(result.end(), doc.begin(), doc.end());
result.push_back(eos_token);
} else {
// using prompt template
const char * prefix = llama_model_chat_template(model, "rerank_prefix");
const char * suffix = llama_model_chat_template(model, "rerank_suffix");

if (prefix == NULL && suffix == NULL) {
throw std::runtime_error("Rerank prompt template not found in the model\n");
}

result.reserve(doc.size() + query.size() + 4);
result.push_back(llama_vocab_bos(vocab));
result.insert(result.end(), query.begin(), query.end());
result.push_back(eos_token);
result.push_back(llama_vocab_sep(vocab));
result.insert(result.end(), doc.begin(), doc.end());
result.push_back(eos_token);
const llama_tokens prefix_tokens = prefix ? common_tokenize(vocab, prefix, true, false) : llama_tokens();
const llama_tokens suffix_tokens = suffix ? common_tokenize(vocab, suffix, false, false) : llama_tokens();
result.reserve(prefix_tokens.size() + query.size() + suffix_tokens.size() + doc.size());
result.insert(result.end(), prefix_tokens.begin(), prefix_tokens.end());
result.insert(result.end(), query.begin(), query.end());
result.insert(result.end(), suffix_tokens.begin(), suffix_tokens.end());
result.insert(result.end(), doc.begin(), doc.end());
}

return result;
}
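
A hedged usage sketch of the new template branch (the query and document strings are made up; the calls mirror those in server.cpp and this file): a Qwen3 reranker GGUF has no SEP token, so format_rerank assembles prefix_tokens + query + suffix_tokens + doc, with the prefix/suffix read from the "rerank_prefix"/"rerank_suffix" chat templates that convert_hf_to_gguf.py stores in the model.

    // sketch only: assumes a loaded Qwen3 reranker model and the helpers used in this diff
    static llama_tokens make_qwen3_rerank_prompt(const llama_model * model) {
        const llama_vocab * vocab = llama_model_get_vocab(model);

        const llama_tokens query = common_tokenize(vocab, "what is panda?", /* add_special */ false, true);
        const llama_tokens doc   = common_tokenize(vocab, "The giant panda is a bear native to China.", false, true);

        // takes the template branch above, since llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL for this model
        return format_rerank(model, query, doc); // ready to be queued as a SERVER_TASK_TYPE_RERANK task
    }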