diff --git a/requirements/test.in b/requirements/test.in index 55978fb10d58..c97a8abbe6ad 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -33,7 +33,7 @@ num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test -mteb>=1.38.11, <2 # required for mteb test +mteb[bm25s]>=1.38.11, <2 # required for mteb test transformers==4.52.4 tokenizers==0.21.1 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. diff --git a/requirements/test.txt b/requirements/test.txt index 8cd218d44ed4..8f8edb48bdb0 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -51,6 +51,8 @@ black==24.10.0 # via datamodel-code-generator blobfile==3.0.0 # via -r requirements/test.in +bm25s==0.2.13 + # via mteb boto3==1.35.57 # via tensorizer botocore==1.35.57 @@ -344,6 +346,7 @@ numpy==1.26.4 # -r requirements/test.in # accelerate # bitsandbytes + # bm25s # contourpy # cupy-cuda12x # datasets @@ -534,6 +537,8 @@ pyparsing==3.2.0 # via matplotlib pyrate-limiter==3.7.0 # via schemathesis +pystemmer==3.0.0 + # via mteb pytablewriter==1.2.0 # via lm-eval pytest==8.3.3 @@ -668,6 +673,7 @@ scikit-learn==1.5.2 # sentence-transformers scipy==1.13.1 # via + # bm25s # librosa # mteb # scikit-learn diff --git a/tests/conftest.py b/tests/conftest.py index 5ec3926bd31f..294805a8164f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -727,8 +727,12 @@ def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]: return self.model.encode(prompts, *args, **kwargs) - def predict(self, prompts: list[list[str]]) -> torch.Tensor: - return self.model.predict(prompts, convert_to_tensor=True) + def predict(self, prompts: list[list[str]], *args, + **kwargs) -> torch.Tensor: + return self.model.predict(prompts, + *args, + convert_to_tensor=True, + **kwargs) def __enter__(self): return self @@ -1037,8 +1041,10 @@ def score( self, text_1: Union[str, list[str]], text_2: Union[str, list[str]], + *args, + **kwargs, ) -> list[float]: - req_outputs = self.model.score(text_1, text_2) + req_outputs = self.model.score(text_1, text_2, *args, **kwargs) return [req_output.outputs.score for req_output in req_outputs] def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: diff --git a/tests/entrypoints/openai/correctness/test_mteb.py b/tests/entrypoints/openai/correctness/test_mteb_embed.py similarity index 73% rename from tests/entrypoints/openai/correctness/test_mteb.py rename to tests/entrypoints/openai/correctness/test_mteb_embed.py index 437c48511352..12a86f9bdd59 100644 --- a/tests/entrypoints/openai/correctness/test_mteb.py +++ b/tests/entrypoints/openai/correctness/test_mteb_embed.py @@ -7,34 +7,30 @@ from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS, MTEB_EMBED_TOL, OpenAIClientMtebEncoder, - run_mteb_embed_task, - run_mteb_embed_task_st) + run_mteb_embed_task) from tests.utils import RemoteOpenAIServer os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" -MODEL_NAME = "BAAI/bge-m3" -DTYPE = "float16" -MAIN_SCORE = 0.7873427091972599 +MODEL_NAME = "intfloat/e5-small" +MAIN_SCORE = 0.7422994752439667 @pytest.fixture(scope="module") def server(): args = [ - "--task", "embed", "--dtype", DTYPE, "--enforce-eager", - "--max-model-len", "512" + "--task", "embed", "--enforce-eager", "--disable-uvicorn-access-log" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server -def test_mteb(server): +def 
test_mteb_embed(server): client = server.get_client() encoder = OpenAIClientMtebEncoder(MODEL_NAME, client) vllm_main_score = run_mteb_embed_task(encoder, MTEB_EMBED_TASKS) - st_main_score = MAIN_SCORE or run_mteb_embed_task_st( - MODEL_NAME, MTEB_EMBED_TASKS) + st_main_score = MAIN_SCORE print("VLLM main score: ", vllm_main_score) print("SentenceTransformer main score: ", st_main_score) diff --git a/tests/entrypoints/openai/correctness/test_mteb_score.py b/tests/entrypoints/openai/correctness/test_mteb_score.py new file mode 100644 index 000000000000..f90fc0b9be00 --- /dev/null +++ b/tests/entrypoints/openai/correctness/test_mteb_score.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + +import pytest + +# yapf conflicts with isort for this block +# yapf: disable +from tests.models.language.pooling.mteb_utils import (MTEB_RERANK_LANGS, + MTEB_RERANK_TASKS, + MTEB_RERANK_TOL, + RerankClientMtebEncoder, + ScoreClientMtebEncoder, + run_mteb_rerank) +# yapf: enable +from tests.utils import RemoteOpenAIServer + +os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" + +MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" +MAIN_SCORE = 0.33437 + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", "score", "--enforce-eager", "--disable-uvicorn-access-log" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +def test_mteb_score(server): + url = server.url_for("score") + encoder = ScoreClientMtebEncoder(MODEL_NAME, url) + vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, + MTEB_RERANK_LANGS) + st_main_score = MAIN_SCORE + + print("VLLM main score: ", vllm_main_score) + print("SentenceTransformer main score: ", st_main_score) + print("Difference: ", st_main_score - vllm_main_score) + + assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) + + +def test_mteb_rerank(server): + url = server.url_for("rerank") + encoder = RerankClientMtebEncoder(MODEL_NAME, url) + vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, + MTEB_RERANK_LANGS) + st_main_score = MAIN_SCORE + + print("VLLM main score: ", vllm_main_score) + print("SentenceTransformer main score: ", st_main_score) + print("Difference: ", st_main_score - vllm_main_score) + + assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 0a047951db44..21d55c418c36 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -1,14 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import tempfile from collections.abc import Sequence +from typing import Optional import mteb import numpy as np import pytest +import requests -from tests.models.utils import EmbedModelInfo +from tests.models.utils import EmbedModelInfo, RerankModelInfo -# Most models on the STS12 task (See #17175): +# Most embedding models on the STS12 task (See #17175): # - Model implementation and minor changes in tensor dtype # results in differences less than 1e-4 # - Different model results in differences more than 1e-3 @@ -16,6 +20,11 @@ MTEB_EMBED_TASKS = ["STS12"] MTEB_EMBED_TOL = 1e-4 +# See #19344 +MTEB_RERANK_TASKS = ["NFCorpus"] +MTEB_RERANK_LANGS = ["en"] +MTEB_RERANK_TOL = 1e-3 + class VllmMtebEncoder(mteb.Encoder): @@ -39,6 +48,27 @@ def encode( embeds = 
embeds[np.argsort(r)] return embeds + def predict( + self, + sentences: list[tuple[str, str, + Optional[str]]], # query, corpus, prompt + *args, + **kwargs, + ) -> np.ndarray: + r = self.rng.permutation(len(sentences)) + sentences = [sentences[i] for i in r] + + queries = [s[0] for s in sentences] + corpus = [s[1] for s in sentences] + + outputs = self.model.score(queries, + corpus, + truncate_prompt_tokens=-1, + use_tqdm=False) + scores = np.array(outputs) + scores = scores[np.argsort(r)] + return scores + class OpenAIClientMtebEncoder(mteb.Encoder): @@ -62,21 +92,72 @@ def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray: return embeds +class ScoreClientMtebEncoder(mteb.Encoder): + + def __init__(self, model_name: str, url): + super().__init__() + self.model_name = model_name + self.url = url + self.rng = np.random.default_rng(seed=42) + + def predict( + self, + sentences: list[tuple[str, str, + Optional[str]]], # query, corpus, prompt + *args, + **kwargs, + ) -> np.ndarray: + r = self.rng.permutation(len(sentences)) + sentences = [sentences[i] for i in r] + + outputs = [] + for query, corpus, prompt in sentences: + outputs.append(self.get_score(query, corpus)) + + scores = np.array(outputs) + scores = scores[np.argsort(r)] + return scores + + def get_score(self, query, corpus): + response = requests.post(self.url, + json={ + "model": self.model_name, + "text_1": query, + "text_2": corpus, + "truncate_prompt_tokens": -1, + }).json() + return response['data'][0]["score"] + + +class RerankClientMtebEncoder(ScoreClientMtebEncoder): + + def get_score(self, query, corpus): + response = requests.post(self.url, + json={ + "model": self.model_name, + "query": query, + "documents": [corpus], + "truncate_prompt_tokens": -1, + }).json() + return response['results'][0]["relevance_score"] + + def run_mteb_embed_task(encoder, tasks): tasks = mteb.get_tasks(tasks=tasks) evaluation = mteb.MTEB(tasks=tasks) - results = evaluation.run(encoder, verbosity=0, output_folder=None) + results = evaluation.run( + encoder, + verbosity=0, + output_folder=None, + encode_kwargs={ + "show_progress_bar": False, + }, + ) main_score = results[0].scores["test"][0]["main_score"] return main_score -def run_mteb_embed_task_st(model_name, tasks): - from sentence_transformers import SentenceTransformer - model = SentenceTransformer(model_name) - return run_mteb_embed_task(model, tasks) - - def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo, @@ -118,3 +199,96 @@ def mteb_test_embed_models(hf_runner, print("Difference:", st_main_score - vllm_main_score) assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL) + + +def run_mteb_rerank(cross_encoder, tasks, languages): + with tempfile.TemporaryDirectory() as results_folder: + bm25s = mteb.get_model("bm25s") + tasks = mteb.get_tasks(tasks=tasks, languages=languages) + + subset = "default" + eval_splits = ["test"] + + evaluation = mteb.MTEB(tasks=tasks) + evaluation.run( + bm25s, + verbosity=0, + eval_splits=eval_splits, + save_predictions=True, + output_folder=f"{results_folder}/stage1", + encode_kwargs={"show_progress_bar": False}, + ) + + results = evaluation.run( + cross_encoder, + verbosity=0, + eval_splits=eval_splits, + top_k=10, + save_predictions=True, + output_folder=f"{results_folder}/stage2", + previous_results= + f"{results_folder}/stage1/NFCorpus_{subset}_predictions.json", + encode_kwargs={"show_progress_bar": False}, + ) + main_score = results[0].scores["test"][0]["main_score"] + return main_score + + 
+def mteb_test_rerank_models(hf_runner, + vllm_runner, + model_info: RerankModelInfo, + vllm_extra_kwargs=None, + hf_model_callback=None): + if not model_info.enable_test: + # A model family has many models with the same architecture, + # and we don't need to test each one. + pytest.skip("Skipping test.") + + vllm_extra_kwargs = vllm_extra_kwargs or {} + vllm_extra_kwargs["dtype"] = model_info.dtype + + with vllm_runner(model_info.name, + task="score", + max_model_len=None, + **vllm_extra_kwargs) as vllm_model: + + if model_info.architecture: + assert (model_info.architecture + in vllm_model.model.llm_engine.model_config.architectures) + + vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model), + tasks=MTEB_RERANK_TASKS, + languages=MTEB_RERANK_LANGS) + vllm_dtype = vllm_model.model.llm_engine.model_config.dtype + + with hf_runner(model_info.name, is_cross_encoder=True, + dtype="float32") as hf_model: + + original_predict = hf_model.predict + + def _predict( + sentences: list[tuple[str, str, + Optional[str]]], # query, corpus, prompt + *args, + **kwargs, + ): + # vllm and st both remove the prompt, fair comparison. + prompts = [(s[0], s[1]) for s in sentences] + return original_predict(prompts, *args, **kwargs, batch_size=8) + + hf_model.predict = _predict + hf_model.original_predict = original_predict + + if hf_model_callback is not None: + hf_model_callback(hf_model) + + st_main_score = run_mteb_rerank(hf_model, + tasks=MTEB_RERANK_TASKS, + languages=MTEB_RERANK_LANGS) + st_dtype = next(hf_model.model.model.parameters()).dtype + + print("VLLM:", vllm_dtype, vllm_main_score) + print("SentenceTransformers:", st_dtype, st_main_score) + print("Difference:", st_main_score - vllm_main_score) + + assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling/test_baai.py index 1af3c05d3d90..3990e8ea92c8 100644 --- a/tests/models/language/pooling/test_baai.py +++ b/tests/models/language/pooling/test_baai.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from .embed_utils import EmbedModelInfo, correctness_test_embed_models -from .mteb_utils import mteb_test_embed_models +from ...utils import EmbedModelInfo, RerankModelInfo +from .embed_utils import correctness_test_embed_models +from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models MODELS = [ ########## BertModel @@ -57,6 +58,20 @@ enable_test=True), ] +RERANK_MODELS = [ + ########## XLMRobertaForSequenceClassification + RerankModelInfo("BAAI/bge-reranker-base", + architecture="XLMRobertaForSequenceClassification", + enable_test=True), + RerankModelInfo("BAAI/bge-reranker-large", + architecture="XLMRobertaForSequenceClassification", + enable_test=False), + RerankModelInfo("BAAI/bge-reranker-v2-m3", + architecture="XLMRobertaForSequenceClassification", + dtype="float32", + enable_test=False) +] + @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, @@ -70,3 +85,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner, example_prompts) -> None: correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb(hf_runner, vllm_runner, + model_info: RerankModelInfo) -> None: + mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_cross_encoder.py 
b/tests/models/language/pooling/test_cross_encoder.py new file mode 100644 index 000000000000..9a33063d7b46 --- /dev/null +++ b/tests/models/language/pooling/test_cross_encoder.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from .mteb_utils import RerankModelInfo, mteb_test_rerank_models + +RERANK_MODELS = [ + RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2", + architecture="BertForSequenceClassification"), + RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", + architecture="Qwen3ForSequenceClassification") +] + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb(hf_runner, vllm_runner, + model_info: RerankModelInfo) -> None: + mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 33255021ad6a..0c44683e7486 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -6,28 +6,10 @@ from vllm import PoolingParams -from .embed_utils import (EmbedModelInfo, check_embeddings_close, +from ...utils import EmbedModelInfo, RerankModelInfo +from .embed_utils import (check_embeddings_close, correctness_test_embed_models, matryoshka_fy) -from .mteb_utils import mteb_test_embed_models - -SCORING_MODELS = [ - "jinaai/jina-reranker-v2-base-multilingual", # Roberta -] - -TEXTS_1 = ["Organic skincare products for sensitive skin"] - -TEXTS_2 = [ - "Organic skincare for sensitive skin with aloe vera and chamomile.", - "New makeup trends focus on bold colors and innovative techniques", - "Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille", - "Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken", # noqa: E501 - "Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla", # noqa: E501 - "Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras", # noqa: E501 - "针对敏感肌专门设计的天然有机护肤产品", - "新的化妆趋势注重鲜艳的颜色和创新的技巧", - "敏感肌のために特別に設計された天然有機スキンケア製品", - "新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています", -] +from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models EMBEDDING_MODELS = [ EmbedModelInfo("jinaai/jina-embeddings-v3", @@ -35,47 +17,13 @@ is_matryoshka=True) ] - -@pytest.fixture(scope="module", params=SCORING_MODELS) -def model_name(request): - yield request.param - - -@pytest.mark.parametrize("dtype", ["half"]) -def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str): - - text_pair = [TEXTS_1[0], TEXTS_2[0]] - - with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: - hf_outputs = hf_model.predict([text_pair]).tolist() - - with vllm_runner(model_name, task="score", dtype=dtype, - max_model_len=None) as vllm_model: - vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) - - assert len(vllm_outputs) == 1 - assert len(hf_outputs) == 1 - - assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) - - -@pytest.mark.parametrize("dtype", ["half"]) -def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str): - - text_pairs = [[TEXTS_1[0], text] for text in TEXTS_2] - - with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: - hf_outputs = hf_model.predict(text_pairs).tolist() - - with vllm_runner(model_name, task="score", dtype=dtype, - max_model_len=None) as vllm_model: - vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) - - assert len(vllm_outputs) == 10 - assert 
len(hf_outputs) == 10 - - assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) - assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) +RERANK_MODELS = [ + RerankModelInfo( + "jinaai/jina-reranker-v2-base-multilingual", + architecture="XLMRobertaForSequenceClassification", + dtype="float32", + ) +] @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) @@ -106,6 +54,12 @@ def hf_model_callback(model): hf_model_callback=hf_model_callback) +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb(hf_runner, vllm_runner, + model_info: RerankModelInfo) -> None: + mteb_test_rerank_models(hf_runner, vllm_runner, model_info) + + @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dimensions", [16, 32]) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index 63b37d9a077d..b1e8fd6294ca 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -1,87 +1,91 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + import pytest +import torch + +from tests.conftest import HfRunner -model_name = "Qwen/Qwen3-Reranker-4B" +from .mteb_utils import RerankModelInfo, mteb_test_rerank_models -text_1 = "What is the capital of France?" -texts_2 = [ - "The capital of Brazil is Brasilia.", - "The capital of France is Paris.", +RERANK_MODELS = [ + RerankModelInfo("Qwen/Qwen3-Reranker-0.6B", + architecture="Qwen3ForSequenceClassification", + dtype="float32", + enable_test=True), + RerankModelInfo("Qwen/Qwen3-Reranker-4B", + architecture="Qwen3ForSequenceClassification", + dtype="float32", + enable_test=False) ] -def vllm_reranker(model_name): - from vllm import LLM - - model = LLM(model=model_name, - task="score", - hf_overrides={ - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - }, - dtype="float32") - - text_1 = "What is the capital of France?" 
- texts_2 = [ - "The capital of Brazil is Brasilia.", - "The capital of France is Paris.", - ] - - outputs = model.score(text_1, texts_2) - - return [output.outputs.score for output in outputs] - - -def hf_reranker(model_name): - import torch - from transformers import AutoModelForCausalLM, AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left') - model = AutoModelForCausalLM.from_pretrained(model_name).eval() - - token_false_id = tokenizer.convert_tokens_to_ids("no") - token_true_id = tokenizer.convert_tokens_to_ids("yes") - - max_length = 8192 - - def process_inputs(pairs): - inputs = tokenizer(pairs, - padding=False, - truncation='longest_first', - return_attention_mask=False, - max_length=max_length) - for i, ele in enumerate(inputs['input_ids']): - inputs['input_ids'][i] = ele - inputs = tokenizer.pad(inputs, - padding=True, - return_tensors="pt", - max_length=max_length) - for key in inputs: - inputs[key] = inputs[key].to(model.device) - return inputs - - @torch.no_grad() - def compute_logits(inputs, **kwargs): - batch_scores = model(**inputs).logits[:, -1, :] - true_vector = batch_scores[:, token_true_id] - false_vector = batch_scores[:, token_false_id] - batch_scores = torch.stack([false_vector, true_vector], dim=1) - batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) - scores = batch_scores[:, 1].exp().tolist() - return scores - - pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])] - inputs = process_inputs(pairs) - scores = compute_logits(inputs) - - return scores - - -@pytest.mark.parametrize("model_name", [model_name]) -def test_model(model_name): - hf_outputs = hf_reranker(model_name) - vllm_outputs = vllm_reranker(model_name) - - assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) - assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) +class Qwen3RerankerHfRunner(HfRunner): + + def __init__(self, + model_name: str, + dtype: str = "auto", + *args: Any, + **kwargs: Any) -> None: + from transformers import AutoModelForCausalLM, AutoTokenizer + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) + + self.tokenizer = AutoTokenizer.from_pretrained(model_name, + padding_side='left') + self.token_false_id = self.tokenizer.convert_tokens_to_ids("no") + self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes") + + def predict(self, prompts: list[list[str]], *args, + **kwargs) -> torch.Tensor: + + def process_inputs(pairs): + inputs = self.tokenizer(pairs, + padding=False, + truncation='longest_first', + return_attention_mask=False) + for i, ele in enumerate(inputs['input_ids']): + inputs['input_ids'][i] = ele + inputs = self.tokenizer.pad(inputs, + padding=True, + return_tensors="pt") + for key in inputs: + inputs[key] = inputs[key].to(self.model.device) + return inputs + + @torch.no_grad() + def compute_logits(inputs): + batch_scores = self.model(**inputs).logits[:, -1, :] + true_vector = batch_scores[:, self.token_true_id] + false_vector = batch_scores[:, self.token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp() + return scores + + scores = [] + for prompt in prompts: + inputs = process_inputs([prompt]) + score = compute_logits(inputs) + scores.append(score[0].item()) + return torch.Tensor(scores) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: + + assert 
model_info.architecture == "Qwen3ForSequenceClassification" + + vllm_extra_kwargs: dict[str, Any] = { + "hf_overrides": { + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, + } + } + + if model_info.name == "Qwen/Qwen3-Reranker-4B": + vllm_extra_kwargs["max_num_seqs"] = 1 + + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, + vllm_extra_kwargs) diff --git a/tests/models/language/pooling/test_qwen3_reranker_seq_cls.py b/tests/models/language/pooling/test_qwen3_reranker_seq_cls.py deleted file mode 100644 index ee07f6ff9dca..000000000000 --- a/tests/models/language/pooling/test_qwen3_reranker_seq_cls.py +++ /dev/null @@ -1,73 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import pytest - -model_name = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" - -text_1 = "What is the capital of France?" -texts_2 = [ - "The capital of Brazil is Brasilia.", - "The capital of France is Paris.", -] - - -def vllm_reranker(model_name): - from vllm import LLM - - model = LLM(model=model_name, task="score") - outputs = model.score(text_1, texts_2) - - return [output.outputs.score for output in outputs] - - -def hf_reranker(model_name): - import torch - from transformers import AutoModelForCausalLM, AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left') - model = AutoModelForCausalLM.from_pretrained(model_name).eval() - - token_false_id = tokenizer.convert_tokens_to_ids("no") - token_true_id = tokenizer.convert_tokens_to_ids("yes") - - max_length = 8192 - - def process_inputs(pairs): - inputs = tokenizer(pairs, - padding=False, - truncation='longest_first', - return_attention_mask=False, - max_length=max_length) - for i, ele in enumerate(inputs['input_ids']): - inputs['input_ids'][i] = ele - inputs = tokenizer.pad(inputs, - padding=True, - return_tensors="pt", - max_length=max_length) - for key in inputs: - inputs[key] = inputs[key].to(model.device) - return inputs - - @torch.no_grad() - def compute_logits(inputs, **kwargs): - batch_scores = model(**inputs).logits[:, -1, :] - true_vector = batch_scores[:, token_true_id] - false_vector = batch_scores[:, token_false_id] - batch_scores = torch.stack([false_vector, true_vector], dim=1) - batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) - scores = batch_scores[:, 1].exp().tolist() - return scores - - pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])] - inputs = process_inputs(pairs) - scores = compute_logits(inputs) - - return scores - - -@pytest.mark.parametrize("model_name", [model_name]) -def test_model(model_name): - hf_outputs = hf_reranker(model_name) - vllm_outputs = vllm_reranker(model_name) - - assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) - assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) diff --git a/tests/models/utils.py b/tests/models/utils.py index 943b4f570446..cdf8d02df73c 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -336,3 +336,10 @@ class EmbedModelInfo(NamedTuple): architecture: str = "" dtype: str = "auto" enable_test: bool = True + + +class RerankModelInfo(NamedTuple): + name: str + architecture: str = "" + dtype: str = "auto" + enable_test: bool = True diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 258038bed40b..6829d93d2d6c 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -156,7 +156,10 @@ def extract_states( ) -> Union[list[torch.Tensor], 
torch.Tensor]: prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) - cumsum = torch.cumsum(hidden_states, dim=0) + # Use float32 for torch.cumsum in MeanPool, + # otherwise precision will be lost significantly. + cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) + start_indices = torch.cat([ torch.tensor([0], device=hidden_states.device), torch.cumsum(prompt_lens[:-1], dim=0) @@ -220,6 +223,13 @@ def __init__(self, *, normalize: bool, softmax: bool) -> None: def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): + # Using float32 in PoolerHead + if isinstance(pooled_data, list): + for i in range(len(pooled_data)): + pooled_data[i] = pooled_data[i].to(torch.float32) + else: + pooled_data = pooled_data.to(torch.float32) + dimensions_list = [ pooling_param.dimensions for _, pooling_param in pooling_metadata.seq_groups diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index cacec7342ac2..389393987c81 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -414,15 +414,10 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids=input_ids, - position_ids=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) - - # convert the embedding output to float32, - # otherwise precision will be lost significantly - hidden_states = hidden_states.to(torch.float32) - return hidden_states + return self.model(input_ids=input_ids, + position_ids=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) def pooler( self, diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index d1b84a9f04fa..0f22393c79d9 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -432,12 +432,7 @@ def forward( else: hidden_states = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids) - hidden_states = self.encoder(positions, hidden_states) - - # convert the embedding output to float32, - # otherwise precision will be lost significantly - hidden_states = hidden_states.to(torch.float32) - return hidden_states + return self.encoder(positions, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
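
Editor's note: the MeanPool hunk above switches torch.cumsum to float32 accumulation because a half-precision running sum silently stops growing once it is large relative to the per-token values. A minimal standalone sketch (not part of the patch, plain torch only) that reproduces the effect:

import torch

# Simulate pooling 4096 tokens whose pooled feature is ~0.01 each.
# Keeping the running sum in float16 forces a round-to-float16 after every
# addition; once the sum reaches 32, adding 0.01 is smaller than half the
# float16 spacing (0.03125) and is rounded away, so the sum stalls at 32
# instead of reaching the true total of ~40.97.
half_sum = torch.tensor(0.0, dtype=torch.float16)
for _ in range(4096):
    half_sum = half_sum + torch.tensor(0.01, dtype=torch.float16)

# Accumulating in float32, as the patched MeanPool now does via
# torch.cumsum(..., dtype=torch.float32), preserves the full total.
x = torch.full((4096,), 0.01, dtype=torch.float16)
float_sum = torch.cumsum(x, dim=0, dtype=torch.float32)[-1]

print(half_sum.item(), float_sum.item())  # ~32.0 vs ~40.97

The same rounding argument is why the BertModel/BertWithRope float32 casts become unnecessary once PoolerHead and MeanPool promote to float32 themselves.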