Skip to content

[New Model]: Support Qwen3 Embedding & Reranker #19260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added test/__init__.py
Empty file.
80 changes: 80 additions & 0 deletions test/st.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def format_instruction(instruction, query, doc):
if instruction is None:
instruction = 'Given a web search query, retrieve relevant passages that answer the query'
output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
instruction=instruction, query=query, doc=doc)
return output


def process_inputs(pairs):
inputs = tokenizer(pairs,
padding=False,
truncation='longest_first',
return_attention_mask=False,
max_length=max_length - len(prefix_tokens) -
len(suffix_tokens))
for i, ele in enumerate(inputs['input_ids']):
inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
inputs = tokenizer.pad(inputs,
padding=True,
return_tensors="pt",
max_length=max_length)
for key in inputs:
inputs[key] = inputs[key].to(model.device)
return inputs


def compute_logits(inputs, **kwargs):
batch_scores = model(**inputs).logits[:, -1, :]
true_vector = batch_scores[:, token_true_id]
false_vector = batch_scores[:, token_false_id]
batch_scores = torch.stack([false_vector, true_vector], dim=1)
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
scores = batch_scores[:, 1].exp().tolist()
return scores


model_name = "Qwen/Qwen3-Reranker-4B"

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id = tokenizer.convert_tokens_to_ids("yes")
max_length = 8192

prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

if __name__ == '__main__':

task = 'Given a web search query, retrieve relevant passages that answer the query'

queries = [
"What is the capital of China?",
"Explain gravity",
]

documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]

pairs = [
format_instruction(task, query, doc)
for query, doc in zip(queries, documents)
]

# Tokenize the input texts
inputs = process_inputs(pairs)
scores = compute_logits(inputs)

print("scores: ", scores)
53 changes: 53 additions & 0 deletions test/vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")


def format_instruction(instruction, query, doc):
if instruction is None:
instruction = 'Given a web search query, retrieve relevant passages that answer the query'
output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
instruction=instruction, query=query, doc=doc)
return output


max_length = 8192

prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

instruction = 'Given a web search query, retrieve relevant passages that answer the query'


def process_inputs(query, doc):
messages = format_instruction(instruction, query, doc)
messages = prefix + messages + suffix
return messages


if __name__ == '__main__':
from vllm import LLM

model = LLM(model="Qwen/Qwen3-Reranker-4B",
task="score",
hf_overrides={
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"]
})

queries = [
"What is the capital of China?",
"Explain gravity",
]

documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]

outputs = model.score(queries, documents, process_inputs=process_inputs)

print([output.outputs.score for output in outputs])
8 changes: 8 additions & 0 deletions tests/models/language/pooling/test_gte.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
architecture="ModernBertModel",
enable_test=True),
########## Qwen3ForCausalLM
EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
architecture="Qwen3ForCausalLM",
dtype="float32",
enable_test=True),
EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
architecture="Qwen3ForCausalLM",
enable_test=True),
]


Expand Down
86 changes: 86 additions & 0 deletions tests/models/language/pooling/test_qwen3_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

model_name = "Qwen/Qwen3-Reranker-4B"

text_1 = "What is the capital of France?"
texts_2 = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
]


def vllm_reranker(model_name):
from vllm import LLM

model = LLM(model=model_name,
task="score",
hf_overrides={
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"]
},
dtype="float32")

text_1 = "What is the capital of France?"
texts_2 = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
]

outputs = model.score(text_1, texts_2)

return [output.outputs.score for output in outputs]


def hf_reranker(model_name):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id = tokenizer.convert_tokens_to_ids("yes")

max_length = 8192

def process_inputs(pairs):
inputs = tokenizer(pairs,
padding=False,
truncation='longest_first',
return_attention_mask=False,
max_length=max_length)
for i, ele in enumerate(inputs['input_ids']):
inputs['input_ids'][i] = ele
inputs = tokenizer.pad(inputs,
padding=True,
return_tensors="pt",
max_length=max_length)
for key in inputs:
inputs[key] = inputs[key].to(model.device)
return inputs

@torch.no_grad()
def compute_logits(inputs, **kwargs):
batch_scores = model(**inputs).logits[:, -1, :]
true_vector = batch_scores[:, token_true_id]
false_vector = batch_scores[:, token_false_id]
batch_scores = torch.stack([false_vector, true_vector], dim=1)
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
scores = batch_scores[:, 1].exp().tolist()
return scores

pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
inputs = process_inputs(pairs)
scores = compute_logits(inputs)

return scores


@pytest.mark.parametrize("model_name", [model_name])
def test_model(model_name):
hf_outputs = hf_reranker(model_name)
vllm_outputs = vllm_reranker(model_name)

assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
21 changes: 13 additions & 8 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ def _cross_encoding_score(
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
process_inputs: Optional[Callable] = None,
) -> list[ScoringRequestOutput]:

if isinstance(tokenizer, MistralTokenizer):
Expand All @@ -1150,12 +1151,17 @@ def _cross_encoding_score(
parsed_prompts = []

for q, t in input_pairs:
prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
if process_inputs is not None:
text = process_inputs(q, t)
prompt_inputs = tokenizer(text=text, **tokenization_kwargs)
else:
prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))

parsed_prompts.append(engine_prompt)

self._validate_and_add_requests(
Expand All @@ -1182,6 +1188,7 @@ def score(
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
process_inputs: Optional[Callable] = None,
) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs `<text,text_pair>`.

Expand Down Expand Up @@ -1258,11 +1265,9 @@ def ensure_str(prompt: SingletonPrompt):
_validate_score_input_lens(input_text_1, input_text_2)

if self.llm_engine.model_config.is_cross_encoder:
return self._cross_encoding_score(tokenizer, input_text_1,
input_text_2,
truncate_prompt_tokens, use_tqdm,
lora_request,
prompt_adapter_request)
return self._cross_encoding_score(
tokenizer, input_text_1, input_text_2, truncate_prompt_tokens,
use_tqdm, lora_request, prompt_adapter_request, process_inputs)
else:
return self._embedding_score(
tokenizer,
Expand Down
Loading