Skip to content

Commit 3c91015

Browse files
noooopminpeter
authored andcommitted
[New Model]: Support Qwen3 Embedding & Reranker (vllm-project#19260)
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent fd0ac82 commit 3c91015

File tree

8 files changed

+396
-19
lines changed

8 files changed

+396
-19
lines changed

docs/models/supported_models.md

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -387,18 +387,19 @@ See [this page](./pooling_models.md) for more information on how to use pooling
387387

388388
Specified using `--task embed`.
389389

390-
| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] |
391-
|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------|
392-
| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | |
393-
| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | |
394-
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
395-
| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. || |
396-
| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. |||
397-
| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. |||
398-
| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. |||
399-
| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ |
400-
| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ |
401-
| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | |
390+
| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] |
391+
|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|----------------------|---------------------------|
392+
| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | |
393+
| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | |
394+
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
395+
| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. || |
396+
| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. |||
397+
| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. |||
398+
| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. |||
399+
| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ |
400+
| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ |
401+
| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ |
402+
| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | |
402403

403404
!!! note
404405
`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
@@ -450,12 +451,19 @@ If your model is not in the above list, we will try to automatically convert the
450451

451452
Specified using `--task score`.
452453

453-
| Architecture | Models | Example HF Models |
454-
|---------------------------------------|-------------------|----------------------------------------------|
455-
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. |
456-
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. |
457-
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. |
454+
| Architecture | Models | Example HF Models |
455+
|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|
456+
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. |
457+
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. |
458+
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. |
459+
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. |
458460

461+
!!! note
462+
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: <gh-file:examples/offline_inference/qwen3_reranker.py>.
463+
464+
```bash
465+
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
466+
```
459467
[](){ #supported-mm-models }
460468

461469
## List of Multimodal Language Models
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# ruff: noqa: E501
4+
5+
from vllm import LLM
6+
7+
model_name = "Qwen/Qwen3-Reranker-0.6B"
8+
9+
# What is the difference between the official original version and one
10+
# that has been converted into a sequence classification model?
11+
# Qwen3-Reranker is a language model that doing reranker by using the
12+
# logits of "no" and "yes" tokens.
13+
# It needs to computing 151669 tokens logits, making this method extremely
14+
# inefficient, not to mention incompatible with the vllm score API.
15+
# A method for converting the original model into a sequence classification
16+
# model was proposed. See:https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
17+
# Models converted offline using this method can not only be more efficient
18+
# and support the vllm score API, but also make the init parameters more
19+
# concise, for example.
20+
# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")
21+
22+
# If you want to load the official original version, the init parameters are
23+
# as follows.
24+
25+
model = LLM(
26+
model=model_name,
27+
task="score",
28+
hf_overrides={
29+
"architectures": ["Qwen3ForSequenceClassification"],
30+
"classifier_from_token": ["no", "yes"],
31+
"is_original_qwen3_reranker": True,
32+
},
33+
)
34+
35+
# Why do we need hf_overrides for the official original version:
36+
# vllm converts it to Qwen3ForSequenceClassification when loaded for
37+
# better performance.
38+
# - Firstly, we need using `"architectures": ["Qwen3ForSequenceClassification"],`
39+
# to manually route to Qwen3ForSequenceClassification.
40+
# - Then, we will extract the vector corresponding to classifier_from_token
41+
# from lm_head using `"classifier_from_token": ["no", "yes"]`.
42+
# - Third, we will convert these two vectors into one vector. The use of
43+
# conversion logic is controlled by `using "is_original_qwen3_reranker": True`.
44+
45+
# Please use the query_template and document_template to format the query and
46+
# document for better reranker results.
47+
48+
prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
49+
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
50+
51+
query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
52+
document_template = "<Document>: {doc}{suffix}"
53+
54+
if __name__ == "__main__":
55+
instruction = (
56+
"Given a web search query, retrieve relevant passages that answer the query"
57+
)
58+
59+
queries = [
60+
"What is the capital of China?",
61+
"Explain gravity",
62+
]
63+
64+
documents = [
65+
"The capital of China is Beijing.",
66+
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
67+
]
68+
69+
queries = [
70+
query_template.format(prefix=prefix, instruction=instruction, query=query)
71+
for query in queries
72+
]
73+
documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]
74+
75+
outputs = model.score(queries, documents)
76+
77+
print([output.outputs.score for output in outputs])

tests/models/language/pooling/test_gte.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@
4545
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
4646
architecture="ModernBertModel",
4747
enable_test=True),
48+
########## Qwen3ForCausalLM
49+
EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
50+
architecture="Qwen3ForCausalLM",
51+
dtype="float32",
52+
enable_test=True),
53+
EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
54+
architecture="Qwen3ForCausalLM",
55+
dtype="float32",
56+
enable_test=False),
4857
]
4958

5059

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import pytest
3+
4+
model_name = "Qwen/Qwen3-Reranker-4B"
5+
6+
text_1 = "What is the capital of France?"
7+
texts_2 = [
8+
"The capital of Brazil is Brasilia.",
9+
"The capital of France is Paris.",
10+
]
11+
12+
13+
def vllm_reranker(model_name):
14+
from vllm import LLM
15+
16+
model = LLM(model=model_name,
17+
task="score",
18+
hf_overrides={
19+
"architectures": ["Qwen3ForSequenceClassification"],
20+
"classifier_from_token": ["no", "yes"],
21+
"is_original_qwen3_reranker": True,
22+
},
23+
dtype="float32")
24+
25+
text_1 = "What is the capital of France?"
26+
texts_2 = [
27+
"The capital of Brazil is Brasilia.",
28+
"The capital of France is Paris.",
29+
]
30+
31+
outputs = model.score(text_1, texts_2)
32+
33+
return [output.outputs.score for output in outputs]
34+
35+
36+
def hf_reranker(model_name):
37+
import torch
38+
from transformers import AutoModelForCausalLM, AutoTokenizer
39+
40+
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
41+
model = AutoModelForCausalLM.from_pretrained(model_name).eval()
42+
43+
token_false_id = tokenizer.convert_tokens_to_ids("no")
44+
token_true_id = tokenizer.convert_tokens_to_ids("yes")
45+
46+
max_length = 8192
47+
48+
def process_inputs(pairs):
49+
inputs = tokenizer(pairs,
50+
padding=False,
51+
truncation='longest_first',
52+
return_attention_mask=False,
53+
max_length=max_length)
54+
for i, ele in enumerate(inputs['input_ids']):
55+
inputs['input_ids'][i] = ele
56+
inputs = tokenizer.pad(inputs,
57+
padding=True,
58+
return_tensors="pt",
59+
max_length=max_length)
60+
for key in inputs:
61+
inputs[key] = inputs[key].to(model.device)
62+
return inputs
63+
64+
@torch.no_grad()
65+
def compute_logits(inputs, **kwargs):
66+
batch_scores = model(**inputs).logits[:, -1, :]
67+
true_vector = batch_scores[:, token_true_id]
68+
false_vector = batch_scores[:, token_false_id]
69+
batch_scores = torch.stack([false_vector, true_vector], dim=1)
70+
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
71+
scores = batch_scores[:, 1].exp().tolist()
72+
return scores
73+
74+
pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
75+
inputs = process_inputs(pairs)
76+
scores = compute_logits(inputs)
77+
78+
return scores
79+
80+
81+
@pytest.mark.parametrize("model_name", [model_name])
82+
def test_model(model_name):
83+
hf_outputs = hf_reranker(model_name)
84+
vllm_outputs = vllm_reranker(model_name)
85+
86+
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
87+
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import pytest
3+
4+
model_name = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
5+
6+
text_1 = "What is the capital of France?"
7+
texts_2 = [
8+
"The capital of Brazil is Brasilia.",
9+
"The capital of France is Paris.",
10+
]
11+
12+
13+
def vllm_reranker(model_name):
14+
from vllm import LLM
15+
16+
model = LLM(model=model_name, task="score")
17+
outputs = model.score(text_1, texts_2)
18+
19+
return [output.outputs.score for output in outputs]
20+
21+
22+
def hf_reranker(model_name):
23+
import torch
24+
from transformers import AutoModelForCausalLM, AutoTokenizer
25+
26+
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
27+
model = AutoModelForCausalLM.from_pretrained(model_name).eval()
28+
29+
token_false_id = tokenizer.convert_tokens_to_ids("no")
30+
token_true_id = tokenizer.convert_tokens_to_ids("yes")
31+
32+
max_length = 8192
33+
34+
def process_inputs(pairs):
35+
inputs = tokenizer(pairs,
36+
padding=False,
37+
truncation='longest_first',
38+
return_attention_mask=False,
39+
max_length=max_length)
40+
for i, ele in enumerate(inputs['input_ids']):
41+
inputs['input_ids'][i] = ele
42+
inputs = tokenizer.pad(inputs,
43+
padding=True,
44+
return_tensors="pt",
45+
max_length=max_length)
46+
for key in inputs:
47+
inputs[key] = inputs[key].to(model.device)
48+
return inputs
49+
50+
@torch.no_grad()
51+
def compute_logits(inputs, **kwargs):
52+
batch_scores = model(**inputs).logits[:, -1, :]
53+
true_vector = batch_scores[:, token_true_id]
54+
false_vector = batch_scores[:, token_false_id]
55+
batch_scores = torch.stack([false_vector, true_vector], dim=1)
56+
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
57+
scores = batch_scores[:, 1].exp().tolist()
58+
return scores
59+
60+
pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
61+
inputs = process_inputs(pairs)
62+
scores = compute_logits(inputs)
63+
64+
return scores
65+
66+
67+
@pytest.mark.parametrize("model_name", [model_name])
68+
def test_model(model_name):
69+
hf_outputs = hf_reranker(model_name)
70+
vllm_outputs = vllm_reranker(model_name)
71+
72+
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
73+
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def check_available_online(
238238
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
239239
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
240240
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
241+
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
241242
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
242243
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
243244
v0_only=True),

0 commit comments

Comments
 (0)