
Commit f5ba6ba

[Model] Qwen3 OpenAI API Server reranking and scoring backward compatible and support instruction
Signed-off-by: ElShaddollConstruct <1481345518@qq.com>
1 parent 7b1895e commit f5ba6ba

File tree (3 files changed: +218 -25 lines)

  tests/entrypoints/openai/test_score.py
  vllm/entrypoints/openai/protocol.py
  vllm/entrypoints/openai/serving_score.py

tests/entrypoints/openai/test_score.py

Lines changed: 134 additions & 8 deletions
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from torch import tensor
 
-from vllm.entrypoints.openai.protocol import ScoreResponse
+from vllm.entrypoints.openai.protocol import RerankResponse, ScoreResponse
 
 from ...utils import RemoteOpenAIServer
 
@@ -29,11 +29,35 @@ def v1(run_with_both_engines):
         "name": "BAAI/bge-base-en-v1.5",
         "is_cross_encoder": False
     },
+    {
+        "name": "Qwen/Qwen3-Reranker-0.6B",
+        "is_cross_encoder": True,
+        "is_qwen3_reranker": True,
+    },
 ]
 DTYPE = "half"
 
 
+def _run_qwen3_reranker_hf(hf_model, text_pairs, instruction):
+    """Helper to run Qwen3 reranker with HF, applying the template."""
+    prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+    suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+
+    formatted_pairs = []
+    for query, doc in text_pairs:
+        q_formatted = f"{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
+        d_formatted = f"<Document>: {doc}{suffix}"
+        formatted_pairs.append([q_formatted, d_formatted])
+
+    return hf_model.predict(formatted_pairs).tolist()
+
+
 def run_transformers(hf_model, model, text_pairs):
+    if model.get("is_qwen3_reranker"):
+        # The default instruction used in the server fixture.
+        default_instruction = "Given a web search query, retrieve relevant passages that answer the query"
+        return _run_qwen3_reranker_hf(hf_model, text_pairs,
+                                      default_instruction)
     if model["is_cross_encoder"]:
         return hf_model.predict(text_pairs).tolist()
     else:
@@ -53,21 +77,51 @@ def model(request):
 
 @pytest.fixture(scope="class")
 def server(model: dict[str, Any]):
-    args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
+    args = ["--enforce-eager", "--max-model-len", "256", "--dtype", DTYPE]
+    if model.get("is_qwen3_reranker"):
+        import json
+        prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+        default_instruction = "Given a web search query, retrieve relevant passages that answer the query"
+
+        hf_overrides = {
+            "architectures": ["Qwen3ForSequenceClassification"],
+            "classifier_from_token": ["no", "yes"],
+            "is_original_qwen3_reranker": True,
+            "score_template": {
+                "query_template":
+                f"{prefix}<Instruct>: {{instruction}}\n<Query>: {{query}}\n",
+                "document_template": f"<Document>: {{document}}{suffix}",
+                "default_context": {
+                    "instruction": default_instruction
+                }
+            }
+        }
+        args.extend(["--hf-overrides", json.dumps(hf_overrides)])
 
     with RemoteOpenAIServer(model["name"], args) as remote_server:
         yield remote_server
 
 
 @pytest.fixture(scope="class")
 def runner(model: dict[str, Any], hf_runner):
-    kwargs = {
-        "dtype": DTYPE,
-        "is_cross_encoder" if model["is_cross_encoder"]\
-        else "is_sentence_transformer": True
-    }
+    model_name = model["name"]
+    kwargs = {"dtype": DTYPE}
+    if model.get("is_qwen3_reranker"):
+        # For the HF reference, use the pre-converted Sequence Classification
+        # model to simplify the runner logic.
+        model_name = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
+        hf_runner_kwargs = {
+            "dtype": DTYPE,
+            "is_cross_encoder": True,
+            "trust_remote_code": True,
+        }
+    elif model["is_cross_encoder"]:
+        hf_runner_kwargs = {"dtype": DTYPE, "is_cross_encoder": True}
+    else:
+        hf_runner_kwargs = {"dtype": DTYPE, "is_sentence_transformer": True}
 
-    with hf_runner(model["name"], **kwargs) as hf_model:
+    with hf_runner(model_name, **hf_runner_kwargs) as hf_model:
         yield hf_model
 
 
@@ -191,3 +245,75 @@ def test_score_max_model_len(self, server: RemoteOpenAIServer,
         assert score_response.status_code == 400
         assert "Please, select a smaller truncation size." in \
             score_response.text
+
+    def test_rerank_with_template(self, server: RemoteOpenAIServer,
+                                  model: dict[str, Any], runner):
+        if not model.get("is_qwen3_reranker"):
+            pytest.skip("Test only for Qwen3 Reranker with template support.")
+
+        instruction = "Find the document that is most relevant to the query about national capitals."
+        query = "What is the capital of China?"
+        documents = [
+            "The capital of France is Paris.",
+            "The capital of China is Beijing."
+        ]
+
+        # vLLM run with custom instruction via kwargs
+        rerank_response = requests.post(
+            server.url_for("rerank"),
+            json={
+                "model": model["name"],
+                "query": query,
+                "documents": documents,
+                "score_template_kwargs": {
+                    "instruction": instruction
+                }
+            })
+        rerank_response.raise_for_status()
+        response_data = RerankResponse.model_validate(rerank_response.json())
+        vllm_outputs = {
+            res.document.text: res.relevance_score
+            for res in response_data.results
+        }
+
+        # HF reference run with the same custom instruction
+        text_pairs = [[query, doc] for doc in documents]
+        hf_outputs = _run_qwen3_reranker_hf(runner, text_pairs, instruction)
+
+        for i, doc in enumerate(documents):
+            assert vllm_outputs[doc] == pytest.approx(hf_outputs[i],
+                                                      rel=0.01)
+
+    def test_score_with_template(self, server: RemoteOpenAIServer,
+                                 model: dict[str, Any], runner):
+        if not model.get("is_qwen3_reranker"):
+            pytest.skip("Test only for Qwen3 Reranker with template support.")
+
+        instruction = "Find the document that is most relevant to the query about national capitals."
+        text_1 = "What is the capital of China?"
+        text_2 = [
+            "The capital of France is Paris.",
+            "The capital of China is Beijing."
+        ]
+
+        # vLLM run with custom instruction via kwargs
+        score_response = requests.post(
+            server.url_for("score"),
+            json={
+                "model": model["name"],
+                "text_1": text_1,
+                "text_2": text_2,
+                "score_template_kwargs": {
+                    "instruction": instruction
+                }
+            })
+        score_response.raise_for_status()
+        response_data = ScoreResponse.model_validate(score_response.json())
+        vllm_outputs = [res.score for res in response_data.data]
+
+        # HF reference run with the same custom instruction
+        text_pairs = [[text_1, doc] for doc in text_2]
+        hf_outputs = _run_qwen3_reranker_hf(runner, text_pairs, instruction)
+
+        for i in range(len(vllm_outputs)):
+            assert vllm_outputs[i] == pytest.approx(hf_outputs[i], rel=0.01)
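
Note (not part of the commit): a minimal client-side sketch of the usage the new tests exercise, assuming a vLLM server is already serving Qwen/Qwen3-Reranker-0.6B with the score_template --hf-overrides shown in the fixture above; the base URL, port, and instruction text are illustrative.

# Sketch only: call the rerank endpoint with a per-request instruction
# through the new score_template_kwargs field.
import requests

BASE_URL = "http://localhost:8000"  # assumed default `vllm serve` address

response = requests.post(
    f"{BASE_URL}/rerank",
    json={
        "model": "Qwen/Qwen3-Reranker-0.6B",
        "query": "What is the capital of China?",
        "documents": [
            "The capital of France is Paris.",
            "The capital of China is Beijing.",
        ],
        # Merged into the server-side score_template context.
        "score_template_kwargs": {
            "instruction": "Retrieve the passage that answers the question."
        },
    },
)
response.raise_for_status()
for result in response.json()["results"]:
    print(result["relevance_score"], result["document"]["text"])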

vllm/entrypoints/openai/protocol.py

Lines changed: 20 additions & 2 deletions
@@ -1194,7 +1194,16 @@ class ScoreRequest(OpenAIBaseModel):
             "default: 0). Any priority other than 0 will raise an error "
             "if the served model does not use priority scheduling."),
     )
-
+    score_template: Optional[dict[str, str]] = Field(
+        default=None,
+        description=("A dictionary containing query_template and "
+                     "document_template to format the scorer input."))
+    score_template_kwargs: Optional[dict[str, Any]] = Field(
+        default=None,
+        description=(
+            "Additional keyword args to pass to the template renderer. "
+            "Will be accessible by the score template."),
+    )
     # --8<-- [end:score-extra-params]
 
     def to_pooling_params(self):
@@ -1220,7 +1229,16 @@ class RerankRequest(OpenAIBaseModel):
             "default: 0). Any priority other than 0 will raise an error "
             "if the served model does not use priority scheduling."),
     )
-
+    rerank_template: Optional[dict[str, str]] = Field(
+        default=None,
+        description=("A dictionary containing query_template and "
+                     "document_template to format the reranker input.")
+    )
+    rerank_template_kwargs: Optional[dict[str, Any]] = Field(
+        default=None,
+        description=("A dictionary of key-value pairs to be formatted into "
+                     "the rerank model's template.")
+    )
     # --8<-- [end:rerank-extra-params]
 
     def to_pooling_params(self):
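
Note (not part of the commit): a sketch of how the new optional ScoreRequest fields could appear in a /score request body. The template strings and instruction are illustrative; the keys (query_template, document_template, default_context) follow the handling in serving_score.py below, and a per-request score_template takes precedence over one configured via --hf-overrides.

# Sketch only: illustrative /score payload using the new optional fields.
score_payload = {
    "model": "Qwen/Qwen3-Reranker-0.6B",
    "text_1": "What is the capital of China?",
    "text_2": ["The capital of China is Beijing."],
    # Per-request template; overrides any score_template from --hf-overrides.
    "score_template": {
        "query_template": "<Instruct>: {instruction}\n<Query>: {query}\n",
        "document_template": "<Document>: {document}",
        "default_context": {
            "instruction": "Judge whether the document answers the query."
        },
    },
    # Extra keys merged into the template context on the server.
    "score_template_kwargs": {
        "instruction": "Retrieve passages that answer the query."
    },
}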

vllm/entrypoints/openai/serving_score.py

Lines changed: 64 additions & 15 deletions
@@ -54,7 +54,7 @@ async def _embedding_score(
         texts_1: list[str],
         texts_2: list[str],
         request: Union[RerankRequest, ScoreRequest],
-        request_id=str,
+        request_id: str,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[Union[LoRARequest, None]] = None,
         prompt_adapter_request: Optional[Union[PromptAdapterRequest,
@@ -139,31 +139,58 @@ async def _embedding_score(
 
         return final_res_batch
 
-    async def _cross_encoding_score(
+    async def _preprocess_score(
         self,
-        tokenizer: Union[AnyTokenizer],
+        request: Union[RerankRequest, ScoreRequest],
+        tokenizer: AnyTokenizer,
         texts_1: list[str],
         texts_2: list[str],
-        request: Union[RerankRequest, ScoreRequest],
-        request_id=str,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
-        lora_request: Optional[Union[LoRARequest, None]] = None,
-        prompt_adapter_request: Optional[Union[PromptAdapterRequest,
-                                               None]] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-    ) -> list[PoolingRequestOutput]:
-
+    ) -> tuple[list[str], list[TokensPrompt]]:
         request_prompts: list[str] = []
         engine_prompts: list[TokensPrompt] = []
 
         if len(texts_1) == 1:
             texts_1 = texts_1 * len(texts_2)
 
-        input_pairs = [(t1, t2) for t1, t2 in zip(texts_1, texts_2)]
+        def identity_processor(t1: str, t2: str) -> tuple[str, str]:
+            return t1, t2
 
-        if isinstance(tokenizer, MistralTokenizer):
-            raise ValueError(
-                "MistralTokenizer not supported for cross-encoding")
+        pair_processor = identity_processor
+
+        template_config = (request.score_template
+                           or self.model_config.hf_config.get(
+                               "score_template"))
+
+        if isinstance(template_config, dict):
+
+            def template_processor(t1: str, t2: str) -> tuple[str, str]:
+                default_context = template_config.get("default_context", {})
+                context = default_context.copy() if isinstance(
+                    default_context, dict) else {}
+                if request.score_template_kwargs:
+                    context.update(request.score_template_kwargs)
+
+                context['query'] = t1
+                context['document'] = t2
+
+                query_template = template_config.get("query_template",
+                                                     "{query}")
+                doc_template = template_config.get("document_template",
                                                   "{document}")
+
+                formatted_t1 = query_template.format(
+                    **context) if "query_template" in template_config else t1
+                formatted_t2 = doc_template.format(
+                    **context
+                ) if "document_template" in template_config else t2
+                return formatted_t1, formatted_t2
+
+            pair_processor = template_processor
+
+        input_pairs = [
+            pair_processor(t1, t2) for t1, t2 in zip(texts_1, texts_2)
+        ]
 
         tokenize_async = make_async(tokenizer.__call__,
                                     executor=self._tokenizer_executor)
@@ -186,6 +213,28 @@ async def _cross_encoding_score(
 
             request_prompts.append(request_prompt)
             engine_prompts.append(engine_prompt)
+        return request_prompts, engine_prompts
+
+    async def _cross_encoding_score(
+        self,
+        tokenizer: AnyTokenizer,
+        texts_1: list[str],
+        texts_2: list[str],
+        request: Union[RerankRequest, ScoreRequest],
+        request_id: str,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
+        lora_request: Optional[Union[LoRARequest, None]] = None,
+        prompt_adapter_request: Optional[Union[PromptAdapterRequest,
+                                               None]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+    ) -> list[PoolingRequestOutput]:
+
+        if isinstance(tokenizer, MistralTokenizer):
+            raise ValueError(
+                "MistralTokenizer not supported for cross-encoding")
+
+        request_prompts, engine_prompts = await self._preprocess_score(
+            request, tokenizer, texts_1, texts_2, tokenization_kwargs)
 
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
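
Note (not part of the commit): a standalone sketch of the rendering behaviour that template_processor above implements, shown with an illustrative template rather than the Qwen3 chat prefix/suffix. Per-request kwargs are layered over the template's default_context, and the query and document are substituted into their respective templates, falling back to the raw text when a template half is absent.

# Sketch only: template rendering outside the server, mirroring
# _preprocess_score's template_processor.
from typing import Any, Optional


def render_pair(template_config: dict[str, Any],
                template_kwargs: Optional[dict[str, Any]],
                query: str, document: str) -> tuple[str, str]:
    # Start from default_context, then layer per-request kwargs on top.
    context = dict(template_config.get("default_context", {}))
    if template_kwargs:
        context.update(template_kwargs)
    context["query"] = query
    context["document"] = document

    # Fall back to the raw text when a template half is not provided.
    formatted_query = (template_config["query_template"].format(**context)
                       if "query_template" in template_config else query)
    formatted_doc = (template_config["document_template"].format(**context)
                     if "document_template" in template_config else document)
    return formatted_query, formatted_doc


config = {
    "query_template": "<Instruct>: {instruction}\n<Query>: {query}\n",
    "document_template": "<Document>: {document}",
    "default_context": {"instruction": "Judge relevance."},
}
print(render_pair(config, {"instruction": "Find the capital."},
                  "What is the capital of China?",
                  "The capital of China is Beijing."))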
