
Commit 3f3b9a2

model : support qwen3 rerank and embeddings
1 parent: 3a07714

File tree: 4 files changed (+77, -4 lines)


convert_hf_to_gguf.py (58 additions, 0 deletions)

```diff
@@ -3061,6 +3061,64 @@ def prepare_tensors(self):
 class Qwen3Model(Qwen2Model):
     model_arch = gguf.MODEL_ARCH.QWEN3
 
+    # extra logic for rerank models
+    token_false_id: int | None = None
+    token_true_id: int | None = None
+    sep_token_id: int = 0
+    is_tied_embeddings: bool = False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # a bit hacky, but currently the only way to detect if this is a rerank model
+        readme_path = self.dir_model / "README.md"
+        readme_text = ""
+        if readme_path.exists():
+            with readme_path.open("r", encoding="utf-8") as f:
+                readme_text = f.read()
+        if "# Qwen3-Reranker" in readme_text:
+            self._find_rerank_config()
+
+    def _find_rerank_config(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        self.token_false_id = tokenizer.convert_tokens_to_ids("no")
+        self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
+        self.sep_token_id = tokenizer.convert_tokens_to_ids("\\n")  # unused, but needed for rerank check
+        self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
+        logger.info(f"gguf: token_false_id = {self.token_false_id}, token_true_id = {self.token_true_id}")
+        logger.info(f"gguf: sep_token_id = {self.sep_token_id}")
+        logger.info(f"gguf: is_tied_embeddings = {self.is_tied_embeddings}")
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        is_rerank = self.token_false_id is not None and self.token_true_id is not None
+        if is_rerank:
+            self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
+            self.gguf_writer.add_sep_token_id(self.sep_token_id)
+            self.gguf_writer.add_uint32(gguf.Keys.Classifier.OUTPUT_LABELS, 2)
+
+    def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
+        # extract "yes" and "no" tokens from the output lm_head tensor
+        assert self.token_false_id is not None and self.token_true_id is not None
+        false_row = data_torch[self.token_false_id]
+        true_row = data_torch[self.token_true_id]
+        return torch.stack([true_row, false_row], dim=0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        is_rerank = self.token_false_id is not None and self.token_true_id is not None
+
+        if is_rerank:
+            if self.is_tied_embeddings and "embed_tokens" in name:
+                return [
+                    (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight", self._get_cls_out_tensor(data_torch)),
+                    (self.map_tensor_name(name), data_torch),
+                ]
+            if not self.is_tied_embeddings and "lm_head" in name:
+                # this is the lm_head tensor, we need to extract the cls_out tensor
+                return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight", self._get_cls_out_tensor(data_torch))]
+
+        return super().modify_tensors(data_torch, name, bid)
+
 
 @ModelBase.register("Qwen3MoeForCausalLM")
 class Qwen3MoeModel(Qwen2MoeModel):
```
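The classifier head here is synthesized rather than loaded: rerank checkpoints ship no dedicated classifier tensor, so the converter slices the "yes" and "no" rows out of lm_head (or out of embed_tokens when embeddings are tied, in which case the embedding tensor is emitted alongside the new head). A minimal PyTorch sketch of that extraction, with made-up sizes and token ids:

```python
import torch

# toy stand-ins; the real values come from the tokenizer and the model's lm_head
n_vocab, n_embd = 8, 4
lm_head = torch.randn(n_vocab, n_embd)   # [n_vocab, n_embd]
token_true_id, token_false_id = 5, 2     # ids of "yes" / "no"

# same operation as _get_cls_out_tensor: keep only the two rows that score
# "yes" and "no", stacked so that row 0 is "yes" and row 1 is "no"
cls_out = torch.stack([lm_head[token_true_id], lm_head[token_false_id]], dim=0)
print(cls_out.shape)  # torch.Size([2, 4]), stored as cls.output.weight
```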

src/llama-arch.cpp (1 addition, 0 deletions)

```diff
@@ -629,6 +629,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD,   "token_embd" },
             { LLM_TENSOR_OUTPUT_NORM,  "output_norm" },
             { LLM_TENSOR_OUTPUT,       "output" },
+            { LLM_TENSOR_CLS_OUT,      "cls.output" }, // rerank
             { LLM_TENSOR_ATTN_NORM,    "blk.%d.attn_norm" },
             { LLM_TENSOR_ATTN_Q,       "blk.%d.attn_q" },
             { LLM_TENSOR_ATTN_Q_NORM,  "blk.%d.attn_q_norm" },
```
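With this entry, `tn(LLM_TENSOR_CLS_OUT, "weight")` on the C++ side resolves to `cls.output.weight`, the same name the convert script writes via `gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight"`, so the synthesized head round-trips through the GGUF file.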

src/llama-graph.cpp (9 additions, 4 deletions)

```diff
@@ -1577,10 +1577,15 @@ void llm_graph_context::build_pooling(
                     cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
                 }
             } else if (cls_out) {
-                // Single layer classification head (direct projection)
-                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
-                GGML_ASSERT(cls_out_b != nullptr);
-                cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
+                if (arch == LLM_ARCH_QWEN3) {
+                    cur = ggml_mul_mat(ctx0, cls_out, inp);
+                    cur = ggml_soft_max(ctx0, cur); // qwen3 uses softmax on the output
+                } else {
+                    // Single layer classification head (direct projection)
+                    // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
+                    GGML_ASSERT(cls_out_b != nullptr);
+                    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
+                }
             } else {
                 GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
             }
```
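Since the convert script stacks the "yes" row first, the Qwen3 branch amounts to a two-way softmax over the yes/no projections, with the first component acting as the relevance probability. A minimal sketch of the same math in plain PyTorch, using random stand-in values:

```python
import torch

n_embd = 4                              # toy size
cls_out = torch.randn(2, n_embd)        # row 0: "yes" weights, row 1: "no" weights
inp = torch.randn(n_embd)               # pooled hidden state of one sequence

logits = cls_out @ inp                  # ggml_mul_mat(ctx0, cls_out, inp)
probs = torch.softmax(logits, dim=0)    # ggml_soft_max(ctx0, cur)
relevance = probs[0]                    # P("yes") for the query/document pair
```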

src/llama-model.cpp (9 additions, 0 deletions)

```diff
@@ -819,7 +819,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_QWEN3:
             {
+                // default for embeddings, will be overwritten if model is rerank
+                hparams.pooling_type = LLAMA_POOLING_TYPE_LAST;
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+                ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
+
                 switch (hparams.n_layer) {
                     case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break;
                     case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break;
@@ -2463,6 +2469,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
+                    // output rerank
+                    cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 2}, TENSOR_NOT_REQUIRED);
+
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                     output      = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
```
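Because cls.output is created with TENSOR_NOT_REQUIRED, the same loader path serves both variants: embedding conversions simply lack the tensor and keep the LAST-pooling default, while rerank conversions carry it along with the RANK pooling metadata. A quick sanity check of which variant a converted file is, sketched with gguf-py's GGUFReader (the file name is a placeholder):

```python
from gguf import GGUFReader

reader = GGUFReader("qwen3-reranker-0.6b.gguf")  # placeholder path
names = {t.name for t in reader.tensors}

if "cls.output.weight" in names:
    print("rerank variant: cls.output present, pooling_type should be RANK")
else:
    print("embedding variant: no cls.output, loader defaults to LAST pooling")
```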
