From 0379f79e921862d8d6c4174fc91f13a4b2f7a247 Mon Sep 17 00:00:00 2001
From: BetterAndBetterII
Date: Mon, 30 Jun 2025 13:18:11 +0800
Subject: [PATCH] [Model] Qwen3 OpenAI API Server reranking and scoring:
 backward compatible, with instruction support

Signed-off-by: BetterAndBetterII
---
 tests/entrypoints/openai/test_score.py   | 142 +++++++++++++++++++++--
 vllm/entrypoints/openai/protocol.py      |  22 +++-
 vllm/entrypoints/openai/serving_score.py |  79 ++++++++++---
 3 files changed, 218 insertions(+), 25 deletions(-)

diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py
index 8927fe771809..188febe81c7c 100644
--- a/tests/entrypoints/openai/test_score.py
+++ b/tests/entrypoints/openai/test_score.py
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from torch import tensor
 
-from vllm.entrypoints.openai.protocol import ScoreResponse
+from vllm.entrypoints.openai.protocol import RerankResponse, ScoreResponse
 
 from ...utils import RemoteOpenAIServer
@@ -29,11 +29,35 @@ def v1(run_with_both_engines):
         "name": "BAAI/bge-base-en-v1.5",
         "is_cross_encoder": False
     },
+    {
+        "name": "Qwen/Qwen3-Reranker-0.6B",
+        "is_cross_encoder": True,
+        "is_qwen3_reranker": True,
+    },
 ]
 
 DTYPE = "half"
 
 
+def _run_qwen3_reranker_hf(hf_model, text_pairs, instruction):
+    """Helper to run the Qwen3 reranker with HF, applying the template."""
+    prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+    suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+
+    formatted_pairs = []
+    for query, doc in text_pairs:
+        q_formatted = f"{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
+        d_formatted = f"<Document>: {doc}{suffix}"
+        formatted_pairs.append([q_formatted, d_formatted])
+
+    return hf_model.predict(formatted_pairs).tolist()
+
+
 def run_transformers(hf_model, model, text_pairs):
+    if model.get("is_qwen3_reranker"):
+        # The default instruction used in the server fixture.
+        default_instruction = "Given a web search query, retrieve relevant passages that answer the query"
+        return _run_qwen3_reranker_hf(hf_model, text_pairs,
+                                      default_instruction)
     if model["is_cross_encoder"]:
         return hf_model.predict(text_pairs).tolist()
     else:
@@ -53,7 +77,27 @@ def model(request):
 
 @pytest.fixture(scope="class")
 def server(model: dict[str, Any]):
-    args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
+    args = ["--enforce-eager", "--max-model-len", "256", "--dtype", DTYPE]
+    if model.get("is_qwen3_reranker"):
+        import json
+        prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+        default_instruction = "Given a web search query, retrieve relevant passages that answer the query"
+
+        hf_overrides = {
+            "architectures": ["Qwen3ForSequenceClassification"],
+            "classifier_from_token": ["no", "yes"],
+            "is_original_qwen3_reranker": True,
+            "score_template": {
+                "query_template":
+                f"{prefix}<Instruct>: {{instruction}}\n<Query>: {{query}}\n",
+                "document_template": f"<Document>: {{document}}{suffix}",
+                "default_context": {
+                    "instruction": default_instruction
+                }
+            }
+        }
+        args.extend(["--hf-overrides", json.dumps(hf_overrides)])
     with RemoteOpenAIServer(model["name"], args) as remote_server:
         yield remote_server
 
@@ -61,13 +105,23 @@ def server(model: dict[str, Any]):
 
 @pytest.fixture(scope="class")
 def runner(model: dict[str, Any], hf_runner):
-    kwargs = {
-        "dtype": DTYPE,
-        "is_cross_encoder" if model["is_cross_encoder"]\
-            else "is_sentence_transformer": True
-    }
+    model_name = model["name"]
+    kwargs = {"dtype": DTYPE}
+    if model.get("is_qwen3_reranker"):
+        # For the HF reference, use the pre-converted sequence-classification
+        # model to simplify the runner logic.
+        model_name = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
+        hf_runner_kwargs = {
+            "dtype": DTYPE,
+            "is_cross_encoder": True,
+            "trust_remote_code": True,
+        }
+    elif model["is_cross_encoder"]:
+        hf_runner_kwargs = {"dtype": DTYPE, "is_cross_encoder": True}
+    else:
+        hf_runner_kwargs = {"dtype": DTYPE, "is_sentence_transformer": True}
 
-    with hf_runner(model["name"], **kwargs) as hf_model:
+    with hf_runner(model_name, **hf_runner_kwargs) as hf_model:
         yield hf_model
@@ -191,3 +245,75 @@ def test_score_max_model_len(self, server: RemoteOpenAIServer,
         assert score_response.status_code == 400
         assert "Please, select a smaller truncation size." in \
             score_response.text
+
+    def test_rerank_with_template(self, server: RemoteOpenAIServer,
+                                  model: dict[str, Any], runner):
+        if not model.get("is_qwen3_reranker"):
+            pytest.skip("Test only for Qwen3 Reranker with template support.")
+
+        instruction = "Find the document that is most relevant to the query about national capitals."
+        query = "What is the capital of China?"
+        documents = [
+            "The capital of France is Paris.",
+            "The capital of China is Beijing."
+        ]
+
+        # vLLM run with a custom instruction passed via template kwargs
+        rerank_response = requests.post(
+            server.url_for("rerank"),
+            json={
+                "model": model["name"],
+                "query": query,
+                "documents": documents,
+                "rerank_template_kwargs": {
+                    "instruction": instruction
+                }
+            })
+        rerank_response.raise_for_status()
+        response_data = RerankResponse.model_validate(rerank_response.json())
+        vllm_outputs = {
+            res.document.text: res.relevance_score
+            for res in response_data.results
+        }
+
+        # HF reference run with the same custom instruction
+        text_pairs = [[query, doc] for doc in documents]
+        hf_outputs = _run_qwen3_reranker_hf(runner, text_pairs, instruction)
+
+        for i, doc in enumerate(documents):
+            assert vllm_outputs[doc] == pytest.approx(hf_outputs[i],
+                                                      rel=0.01)
+
+    def test_score_with_template(self, server: RemoteOpenAIServer,
+                                 model: dict[str, Any], runner):
+        if not model.get("is_qwen3_reranker"):
+            pytest.skip("Test only for Qwen3 Reranker with template support.")
+
+        instruction = "Find the document that is most relevant to the query about national capitals."
+        text_1 = "What is the capital of China?"
+        text_2 = [
+            "The capital of France is Paris.",
+            "The capital of China is Beijing."
+        ]
+
+        # vLLM run with custom instruction via kwargs
+        score_response = requests.post(
+            server.url_for("score"),
+            json={
+                "model": model["name"],
+                "text_1": text_1,
+                "text_2": text_2,
+                "score_template_kwargs": {
+                    "instruction": instruction
+                }
+            })
+        score_response.raise_for_status()
+        response_data = ScoreResponse.model_validate(score_response.json())
+        vllm_outputs = [res.score for res in response_data.data]
+
+        # HF reference run with the same custom instruction
+        text_pairs = [[text_1, doc] for doc in text_2]
+        hf_outputs = _run_qwen3_reranker_hf(runner, text_pairs, instruction)
+
+        for i in range(len(vllm_outputs)):
+            assert vllm_outputs[i] == pytest.approx(hf_outputs[i], rel=0.01)
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 3b5281962b2d..357d731f8a13 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1194,7 +1194,16 @@ class ScoreRequest(OpenAIBaseModel):
             "default: 0). Any priority other than 0 will raise an error "
             "if the served model does not use priority scheduling."),
     )
-
+    score_template: Optional[dict[str, str]] = Field(
+        default=None,
+        description=("A dictionary containing query_template and "
+                     "document_template to format the scorer input."))
+    score_template_kwargs: Optional[dict[str, Any]] = Field(
+        default=None,
+        description=(
+            "Additional keyword args to pass to the template renderer. "
+            "Will be accessible by the score template."),
+    )
     # --8<-- [end:score-extra-params]
 
     def to_pooling_params(self):
@@ -1220,7 +1229,16 @@ class RerankRequest(OpenAIBaseModel):
             "default: 0). Any priority other than 0 will raise an error "
             "if the served model does not use priority scheduling."),
     )
-
+    rerank_template: Optional[dict[str, str]] = Field(
+        default=None,
+        description=("A dictionary containing query_template and "
+                     "document_template to format the reranker input."))
+    rerank_template_kwargs: Optional[dict[str, Any]] = Field(
+        default=None,
+        description=("A dictionary of key-value pairs to be formatted into "
+                     "the rerank model's template."))
     # --8<-- [end:rerank-extra-params]
 
     def to_pooling_params(self):
diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index 9f333c02ab52..502b82e18f0b 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -54,7 +54,7 @@ async def _embedding_score(
         texts_1: list[str],
         texts_2: list[str],
         request: Union[RerankRequest, ScoreRequest],
-        request_id=str,
+        request_id: str,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[Union[LoRARequest, None]] = None,
         prompt_adapter_request: Optional[Union[PromptAdapterRequest,
                                                None]] = None,
@@ -139,31 +139,58 @@ async def _embedding_score(
 
         return final_res_batch
 
-    async def _cross_encoding_score(
+    async def _preprocess_score(
         self,
-        tokenizer: Union[AnyTokenizer],
+        request: Union[RerankRequest, ScoreRequest],
+        tokenizer: AnyTokenizer,
         texts_1: list[str],
         texts_2: list[str],
-        request: Union[RerankRequest, ScoreRequest],
-        request_id=str,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
-        lora_request: Optional[Union[LoRARequest, None]] = None,
-        prompt_adapter_request: Optional[Union[PromptAdapterRequest,
-                                               None]] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-    ) -> list[PoolingRequestOutput]:
-
+    ) -> tuple[list[str], list[TokensPrompt]]:
         request_prompts: list[str] = []
         engine_prompts: list[TokensPrompt] = []
 
         if len(texts_1) == 1:
             texts_1 = texts_1 * len(texts_2)
-        input_pairs = [(t1, t2) for t1, t2 in zip(texts_1, texts_2)]
+        def identity_processor(t1: str, t2: str) -> tuple[str, str]:
+            return t1, t2
 
-        if isinstance(tokenizer, MistralTokenizer):
-            raise ValueError(
-                "MistralTokenizer not supported for cross-encoding")
+        pair_processor = identity_processor
+
+        # A template supplied on the request takes precedence over a
+        # model-level template configured via --hf-overrides. Both the score
+        # and rerank field names are accepted here.
+        template_config = (getattr(request, "score_template", None)
+                           or getattr(request, "rerank_template", None)
+                           or getattr(self.model_config.hf_config,
+                                      "score_template", None))
+        template_kwargs = (getattr(request, "score_template_kwargs", None)
+                           or getattr(request, "rerank_template_kwargs",
+                                      None))
+
+        if isinstance(template_config, dict):
+
+            def template_processor(t1: str, t2: str) -> tuple[str, str]:
+                default_context = template_config.get("default_context", {})
+                context = default_context.copy() if isinstance(
+                    default_context, dict) else {}
+                if template_kwargs:
+                    context.update(template_kwargs)
+
+                context['query'] = t1
+                context['document'] = t2
+
+                query_template = template_config.get("query_template",
+                                                     "{query}")
+                doc_template = template_config.get("document_template",
+                                                   "{document}")
+
+                formatted_t1 = query_template.format(
+                    **context) if "query_template" in template_config else t1
+                formatted_t2 = doc_template.format(
+                    **context
+                ) if "document_template" in template_config else t2
+                return formatted_t1, formatted_t2
+
+            pair_processor = template_processor
+
+        input_pairs = [
+            pair_processor(t1, t2) for t1, t2 in zip(texts_1, texts_2)
+        ]
 
         tokenize_async = make_async(tokenizer.__call__,
                                     executor=self._tokenizer_executor)
@@ -186,6 +213,28 @@
             request_prompts.append(request_prompt)
             engine_prompts.append(engine_prompt)
 
+        return request_prompts, engine_prompts
+
+    async def _cross_encoding_score(
+        self,
+        tokenizer: AnyTokenizer,
+        texts_1: list[str],
+        texts_2: list[str],
+        request: Union[RerankRequest, ScoreRequest],
+        request_id: str,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
+        lora_request: Optional[Union[LoRARequest, None]] = None,
+        prompt_adapter_request: Optional[Union[PromptAdapterRequest,
+                                               None]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+    ) -> list[PoolingRequestOutput]:
+
+        if isinstance(tokenizer, MistralTokenizer):
+            raise ValueError(
+                "MistralTokenizer not supported for cross-encoding")
+
+        request_prompts, engine_prompts = await self._preprocess_score(
+            request, tokenizer, texts_1, texts_2, tokenization_kwargs)
 
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
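
Usage note (illustrative, not part of the patch): the new request fields are exercised end-to-end by the tests above. The sketch below shows how a client could pass a per-request instruction to the score endpoint once a server is running with the score_template hf-override from the fixture; the host, port, and instruction string are assumptions for the example.

# Illustrative client-side sketch. Assumes a vLLM OpenAI-compatible server on
# localhost:8000 serving Qwen/Qwen3-Reranker-0.6B with the score_template
# hf-override shown in the test fixture.
import requests

response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "Qwen/Qwen3-Reranker-0.6B",
        "text_1": "What is the capital of China?",
        "text_2": [
            "The capital of France is Paris.",
            "The capital of China is Beijing.",
        ],
        # Rendered into the model-level score_template as {instruction}.
        "score_template_kwargs": {
            "instruction": "Given a web search query, retrieve relevant "
                           "passages that answer the query"
        },
    },
)
response.raise_for_status()
# ScoreResponse.data holds one relevance score per (text_1, text_2) pair.
print([item["score"] for item in response.json()["data"]])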