diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 422c406d5f31..f427968c8258 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -481,11 +481,19 @@ Specified using `--task score`.
 | Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
 |--------------|--------|-------------------|---------------------|
 | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
+| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | |
 | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ |
 | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
 | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
 | `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |
 
+!!! note
+    Load the official original `BAAI/bge-reranker-v2-gemma` by using the following command.
+
+    ```bash
+    vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
+    ```
+
 !!! note
     Load the official original `mxbai-rerank-v2` by using the following command.
 
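For reference, here is a minimal client-side sketch for the server started with the command above. The `/score` route and the `text_1`/`text_2` payload fields follow vLLM's score API; treat the exact shapes as version-dependent assumptions.

```python
# Minimal client sketch for the `vllm serve` command shown above.
# Endpoint path and request fields follow vLLM's score API and may
# differ across versions.
import requests

response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-reranker-v2-gemma",
        "text_1": "What is panda?",
        "text_2": [
            "The giant panda is a bear species endemic to China.",
            "Paris is the capital of France.",
        ],
    },
)
print(response.json())
```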
diff --git a/examples/offline_inference/convert_model_to_seq_cls.py b/examples/offline_inference/convert_model_to_seq_cls.py
new file mode 100644
index 000000000000..72356020330f
--- /dev/null
+++ b/examples/offline_inference/convert_model_to_seq_cls.py
@@ -0,0 +1,134 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa: E501
+
+import argparse
+import json
+
+import torch
+import transformers
+
+# Usage:
+# for BAAI/bge-reranker-v2-gemma
+# Caution: "Yes" and "yes" are two different tokens
+# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
+# for mxbai-rerank-v2
+# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
+# for Qwen3-Reranker
+# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls
+
+
+def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
+    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
+    assert len(tokens) == 2
+
+    lm_head_weights = causal_lm.lm_head.weight
+
+    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
+    true_id = tokenizer.convert_tokens_to_ids(tokens[1])
+
+    score_weight = lm_head_weights[true_id].to(device).to(
+        torch.float32
+    ) - lm_head_weights[false_id].to(device).to(torch.float32)
+
+    with torch.no_grad():
+        seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
+        if seq_cls_model.score.bias is not None:
+            seq_cls_model.score.bias.zero_()
+
+
+def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device):
+    lm_head_weights = causal_lm.lm_head.weight
+
+    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
+
+    score_weight = lm_head_weights[token_ids].to(device)
+
+    with torch.no_grad():
+        seq_cls_model.score.weight.copy_(score_weight)
+        if seq_cls_model.score.bias is not None:
+            seq_cls_model.score.bias.zero_()
+
+
+method_map = {
+    function.__name__: function for function in [from_2_way_softmax, no_post_processing]
+}
+
+
+def converting(
+    model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu"
+):
+    assert method in method_map
+
+    if method == "from_2_way_softmax":
+        assert len(classifier_from_tokens) == 2
+        num_labels = 1
+    else:
+        num_labels = len(classifier_from_tokens)
+
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
+        model_name, device_map=device
+    )
+
+    seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
+        model_name,
+        num_labels=num_labels,
+        ignore_mismatched_sizes=True,
+        device_map=device,
+    )
+
+    method_map[method](
+        causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
+    )
+
+    # `llm as reranker` models default to not using pad_token
+    seq_cls_model.config.use_pad_token = use_pad_token
+    seq_cls_model.config.pad_token_id = tokenizer.pad_token_id
+
+    seq_cls_model.save_pretrained(path)
+    tokenizer.save_pretrained(path)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Convert *ForCausalLM models to "
+        "*ForSequenceClassification models."
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="BAAI/bge-reranker-v2-gemma",
+        help="Model name",
+    )
+    parser.add_argument(
+        "--classifier_from_tokens",
+        type=str,
+        default='["Yes"]',
+        help="Tokens to build the classifier from (JSON list)",
+    )
+    parser.add_argument(
+        "--method", type=str, default="no_post_processing", help="Conversion method"
+    )
+    parser.add_argument(
+        "--use-pad-token", action="store_true", help="Whether to use pad_token"
+    )
+    parser.add_argument(
+        "--path",
+        type=str,
+        default="./bge-reranker-v2-gemma-seq-cls",
+        help="Path to save converted model",
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    converting(
+        model_name=args.model_name,
+        classifier_from_tokens=json.loads(args.classifier_from_tokens),
+        method=args.method,
+        use_pad_token=args.use_pad_token,
+        path=args.path,
+    )
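A checkpoint produced by this script can then be scored offline without any `hf_overrides`. The sketch below assumes the script's default `--path` and vLLM's `LLM.score` API; the output attribute layout may differ between versions.

```python
# Offline scoring sketch using a checkpoint produced by
# convert_model_to_seq_cls.py (saved to its default --path).
from vllm import LLM

llm = LLM(model="./bge-reranker-v2-gemma-seq-cls", task="score")

query = "What is panda?"
documents = [
    "The giant panda is a bear species endemic to China.",
    "Paris is the capital of France.",
]

outputs = llm.score(query, documents)
for doc, output in zip(documents, outputs):
    print(f"{output.outputs.score:.4f}  {doc}")
```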
diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py
index a83d25818584..59336c1f7906 100644
--- a/tests/models/language/pooling/mteb_utils.py
+++ b/tests/models/language/pooling/mteb_utils.py
@@ -267,7 +267,8 @@ def mteb_test_rerank_models(hf_runner,
                             vllm_runner,
                             model_info: RerankModelInfo,
                             vllm_extra_kwargs=None,
-                            hf_model_callback=None):
+                            hf_model_callback=None,
+                            vllm_mteb_encoder=VllmMtebEncoder):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
@@ -288,7 +289,7 @@ def mteb_test_rerank_models(hf_runner,
         assert (model_info.architecture in model_config.architectures)
         assert model_config.hf_config.num_labels == 1
 
-        vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model),
+        vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model),
                                           tasks=MTEB_RERANK_TASKS,
                                           languages=MTEB_RERANK_LANGS)
         vllm_dtype = model_config.dtype
diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py
new file mode 100644
index 000000000000..7fa9485dbc7f
--- /dev/null
+++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py
@@ -0,0 +1,140 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional
+
+import numpy as np
+import pytest
+import torch
+
+from tests.conftest import HfRunner
+
+from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
+                         mteb_test_rerank_models)
+
+RERANK_MODELS = [
+    RerankModelInfo("BAAI/bge-reranker-v2-gemma",
+                    architecture="GemmaForSequenceClassification"),
+]
+
+PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."  # noqa: E501
+
+
+class GemmaRerankerHfRunner(HfRunner):
+
+    def __init__(self,
+                 model_name: str,
+                 dtype: str = "auto",
+                 *args: Any,
+                 **kwargs: Any) -> None:
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                                       padding_side='left')
+        self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes")
+
+    @torch.no_grad()
+    def predict(self, prompts: list[list[str]], *args,
+                **kwargs) -> torch.Tensor:
+
+        def get_inputs(pairs, tokenizer, prompt=None):
+            if prompt is None:
+                prompt = PROMPT
+
+            sep = "\n"
+            prompt_inputs = tokenizer(prompt,
+                                      return_tensors=None,
+                                      add_special_tokens=False)["input_ids"]
+            sep_inputs = tokenizer(sep,
+                                   return_tensors=None,
+                                   add_special_tokens=False)["input_ids"]
+            inputs = []
+            for query, passage in pairs:
+                query_inputs = tokenizer(
+                    f"A: {query}",
+                    return_tensors=None,
+                    add_special_tokens=False,
+                    truncation=True,
+                )
+                passage_inputs = tokenizer(
+                    f"B: {passage}",
+                    return_tensors=None,
+                    add_special_tokens=False,
+                    truncation=True,
+                )
+                item = tokenizer.prepare_for_model(
+                    [tokenizer.bos_token_id] + query_inputs["input_ids"],
+                    sep_inputs + passage_inputs["input_ids"],
+                    truncation="only_second",
+                    padding=False,
+                    return_attention_mask=False,
+                    return_token_type_ids=False,
+                    add_special_tokens=False,
+                )
+                item["input_ids"] = item[
+                    "input_ids"] + sep_inputs + prompt_inputs
+                item["attention_mask"] = [1] * len(item["input_ids"])
+                inputs.append(item)
+            return tokenizer.pad(
+                inputs,
+                padding=True,
+                return_tensors="pt",
+            )
+
+        scores = []
+        for query, doc, *_ in prompts:
+            pairs = [(query, doc)]
+            inputs = get_inputs(pairs, self.tokenizer)
+            inputs = inputs.to(self.model.device)
+            _n_tokens = inputs["input_ids"].shape[1]
+            logits = self.model(**inputs, return_dict=True).logits
+            _scores = (logits[:, -1,
+                              self.yes_loc].view(-1, ).float().sigmoid())
+            scores.append(_scores[0].item())
+        return torch.Tensor(scores)
+
+
+class GemmaMtebEncoder(VllmMtebEncoder):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.prompt = PROMPT
+        self.query_template = "A: {query}\n"
+        self.document_template = "B: {doc}\n{prompt}"
+
+    def predict(
+        self,
+        sentences: list[tuple[str, str,
+                              Optional[str]]],  # query, corpus, prompt
+        *args,
+        **kwargs,
+    ) -> np.ndarray:
+
+        _sentences = []
+        for query, corpus, prompt in sentences:
+            query = self.query_template.format(query=query)
+            corpus = self.document_template.format(doc=corpus, prompt=prompt)
+            _sentences.append((query, corpus, prompt))
+
+        return super().predict(_sentences, *args, **kwargs)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
+                            monkeypatch) -> None:
+    monkeypatch.setenv("VLLM_USE_V1", "0")
+
+    assert model_info.architecture == "GemmaForSequenceClassification"
+
+    vllm_extra_kwargs: dict[str, Any] = {
+        "hf_overrides": {
+            "architectures": ["GemmaForSequenceClassification"],
+            "classifier_from_token": ["Yes"],
+            "method": "no_post_processing",
+        }
+    }
+
+    mteb_test_rerank_models(GemmaRerankerHfRunner,
+                            vllm_runner,
+                            model_info,
+                            vllm_extra_kwargs,
+                            vllm_mteb_encoder=GemmaMtebEncoder)
diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py
index a1293a95bfd5..e74c58744dd2 100644
--- a/tests/models/language/pooling/test_mxbai_rerank.py
+++ b/tests/models/language/pooling/test_mxbai_rerank.py
@@ -12,11 +12,9 @@
 RERANK_MODELS = [
     RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
                     architecture="Qwen2ForSequenceClassification",
-                    dtype="float32",
                     enable_test=True),
     RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
                     architecture="Qwen2ForSequenceClassification",
-                    dtype="float32",
                     enable_test=False)
 ]
 
diff --git a/tests/models/registry.py b/tests/models/registry.py
index aba01cefe993..48302f9d6648 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -319,9 +319,14 @@ def check_available_online(
 _CROSS_ENCODER_EXAMPLE_MODELS = {
     # [Text-only]
     "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True),  # noqa: E501
+    "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma",  # noqa: E501
+                                                      v0_only=True,
+                                                      hf_overrides={"architectures": ["GemmaForSequenceClassification"],  # noqa: E501
+                                                                    "classifier_from_token": ["Yes"],  # noqa: E501
+                                                                    "method": "no_post_processing"}),  # noqa: E501
+    "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True),  # noqa: E501
     "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True),  # noqa: E501
     "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True),  # noqa: E501
-    "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True),  # noqa: E501
 }
 
 _MULTIMODAL_EXAMPLE_MODELS = {
diff --git a/vllm/config.py b/vllm/config.py
index 724f69a3887f..b7ba434db917 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1449,6 +1449,12 @@ def is_matryoshka(self) -> bool:
     def matryoshka_dimensions(self):
         return getattr(self.hf_config, "matryoshka_dimensions", None)
 
+    @property
+    def use_pad_token(self) -> bool:
+        # cross_encoder models default to using pad_token.
+        # `llm as reranker` models default to not using pad_token.
+        return getattr(self.hf_config, "use_pad_token", True)
+
     def get_and_verify_max_len(self, max_model_len: int):
         # For pooling models, the tokenizer's `model_max_length` is often a
         # reliable source for the maximum sequence length. However, for
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 6357c2a37c8f..16c051d61de3 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1205,7 +1205,6 @@ def _cross_encoding_score(
         input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
 
         pooling_params = PoolingParams(use_cross_encoder=True)
-
         tokenization_kwargs: dict[str, Any] = {}
         _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                   truncate_prompt_tokens, tokenization_kwargs)
@@ -1213,9 +1212,14 @@ def _cross_encoding_score(
         parsed_prompts = []
 
         for q, t in input_pairs:
-            prompt_inputs = tokenizer(text=q,
-                                      text_pair=t,
-                                      **tokenization_kwargs)
+            if self.llm_engine.model_config.use_pad_token:
+                # cross_encoder models default to using pad_token.
+                prompt_inputs = tokenizer(text=q,
+                                          text_pair=t,
+                                          **tokenization_kwargs)
+            else:
+                # `llm as reranker` models default to not using pad_token.
+                prompt_inputs = tokenizer(text=q + t, **tokenization_kwargs)
             engine_prompt = TokensPrompt(
                 prompt_token_ids=prompt_inputs["input_ids"],
                 token_type_ids=prompt_inputs.get("token_type_ids"))
diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index 328d4ff0e6c0..8b2e3e507c4d 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -167,12 +167,22 @@ async def _cross_encoding_score(
                                     executor=self._tokenizer_executor)
 
         tokenization_kwargs = tokenization_kwargs or {}
-        tokenized_prompts = await asyncio.gather(
-            *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs)
-              for t1, t2 in input_pairs))
+        use_pad_token = self.model_config.use_pad_token
+
+        if use_pad_token:
+            # cross_encoder models default to using pad_token.
+            tokenized_prompts = await asyncio.gather(
+                *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs)
+                  for t1, t2 in input_pairs))
+        else:
+            # `llm as reranker` models default to not using pad_token.
+            tokenized_prompts = await asyncio.gather(
+                *(tokenize_async(text=t1 + t2, **tokenization_kwargs)
+                  for t1, t2 in input_pairs))
 
         for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
-            sep_token = tokenizer.sep_token if tokenizer.sep_token else ''
+            sep_token = tokenizer.sep_token if (tokenizer.sep_token
+                                                and use_pad_token) else ''
             request_prompt = f"{t1}{sep_token}{t2}"
 
             input_ids = prompt_inputs["input_ids"]
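The `use_pad_token` branches above capture the whole behavioral difference: classic cross-encoders tokenize the query/document pair as two segments, while `llm as reranker` models simply concatenate the two strings. A standalone illustration with a Hugging Face tokenizer (the model name is only an example):

```python
# Illustration of the two tokenization paths guarded by use_pad_token.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")

q, d = "what is panda?", "The giant panda is a bear."

# use_pad_token=True (classic cross-encoder): two segments with special tokens
pair_ids = tok(text=q, text_pair=d)["input_ids"]

# use_pad_token=False (`llm as reranker`): plain concatenation of the strings
concat_ids = tok(text=q + d)["input_ids"]

print(len(pair_ids), len(concat_ids))
```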
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 78d86f6f2044..6584c84436c2 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -312,6 +312,10 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         else:
             config.num_labels = len(tokens)
 
+        # `llm as reranker` models default to not using pad_token
+        use_pad_token = getattr(config, "use_pad_token", False)
+        config.use_pad_token = use_pad_token
+
 
 def load_weights_using_from_2_way_softmax(
         model, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -356,8 +360,49 @@ def load_weights_using_from_2_way_softmax(
     return loaded_weights
 
 
+def load_weights_no_post_processing(model,
+                                    weights: Iterable[tuple[str,
+                                                            torch.Tensor]]):
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        ParallelLMHead)
+    from vllm.model_executor.models.utils import AutoWeightsLoader
+
+    model_config = model.vllm_config.model_config
+    tokens = getattr(model.config, "classifier_from_token", [])
+    tokens = cast(list[int], tokens)
+    assert len(tokens) > 0
+
+    device = model.score.weight.device
+
+    if model.config.tie_word_embeddings:
+        model.lm_head = model.model.embed_tokens
+    else:
+        model.lm_head = ParallelLMHead(model.config.vocab_size,
+                                       model.config.hidden_size,
+                                       quant_config=model.quant_config)
+
+    loader = AutoWeightsLoader(model)
+    loaded_weights = loader.load_weights(weights)
+
+    from vllm.transformers_utils.tokenizer import get_tokenizer
+    tokenizer = get_tokenizer(model_config.tokenizer,
+                              revision=model_config.tokenizer_revision,
+                              tokenizer_mode=model_config.tokenizer_mode,
+                              trust_remote_code=model_config.trust_remote_code)
+
+    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
+    score_weight = model.lm_head.weight.data[token_ids].to(device)
+    model.score.weight.data.copy_(score_weight)
+
+    del model.lm_head
+    loaded_weights.add("score.weight")
+    loaded_weights.discard("lm_head.weight")
+    return loaded_weights
+
+
 SEQ_CLS_LOAD_METHODS = {
     "from_2_way_softmax": load_weights_using_from_2_way_softmax,
+    "no_post_processing": load_weights_no_post_processing,
 }
 
 
@@ -368,6 +413,9 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
     #   - Qwen3-Reranker
     #   - Qwen2ForCausalLM
     #   - mxbai-rerank-v2
+    # - no_post_processing:
+    #   - GemmaForCausalLM
+    #   - bge-reranker-v2-gemma
     config = model.vllm_config.model_config.hf_config
     method = getattr(config, "method", None)
 
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 59c3102add4c..bc8179f886fd 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -43,6 +43,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from .adapters import as_seq_cls_model
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -425,3 +426,6 @@ def load_weights(self, weights: Iterable[tuple[str,
                            if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
+
+
+GemmaForSequenceClassification = as_seq_cls_model(GemmaForCausalLM)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index b100fe77e377..27d476929855 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -179,8 +179,9 @@
     "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"),
     # [Auto-converted (see adapters.py)]
+    "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"),  # noqa: E501
     "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),  # noqa: E501
-    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
+    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
 }
 
 _MULTIMODAL_MODELS = {
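With `GemmaForSequenceClassification` registered, the original repository can also be loaded directly through the Python API using the same overrides as the docs and tests. A sketch under those assumptions; note the registry entry is `v0_only=True`, so V0 is forced here just as the test does:

```python
# Sketch of loading the original HF repo with the same hf_overrides used in
# test_bge_reranker_v2_gemma.py. V1 support is not claimed for this
# architecture, so VLLM_USE_V1=0 is set as in the test.
import os

os.environ["VLLM_USE_V1"] = "0"

from vllm import LLM  # noqa: E402

llm = LLM(
    model="BAAI/bge-reranker-v2-gemma",
    task="score",
    hf_overrides={
        "architectures": ["GemmaForSequenceClassification"],
        "classifier_from_token": ["Yes"],
        "method": "no_post_processing",
    },
)

print(llm.score("What is panda?",
                ["The giant panda is a bear species endemic to China."]))
```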