[Model][Last/4] Automatic conversion of CrossEncoding model #19675

Merged · 12 commits · Jul 7, 2025
8 changes: 8 additions & 0 deletions docs/models/supported_models.md
@@ -481,11 +481,19 @@ Specified using `--task score`.
| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
|--------------|--------|-------------------|---------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |

!!! note
    Load the original `BAAI/bge-reranker-v2-gemma` checkpoint with the following command:

    ```bash
    vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
    ```
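
    For offline scoring, here is a minimal sketch using the same overrides (this assumes vLLM's `LLM.score` API; adjust to your installed version):

    ```python
    from vllm import LLM

    # Apply the same overrides as the serve command above.
    llm = LLM(
        model="BAAI/bge-reranker-v2-gemma",
        task="score",
        hf_overrides={
            "architectures": ["GemmaForSequenceClassification"],
            "classifier_from_token": ["Yes"],
            "method": "no_post_processing",
        },
    )

    # Score one query against two candidate passages.
    outputs = llm.score("What is the capital of France?",
                        ["Paris is the capital of France.",
                         "The Eiffel Tower is in Paris."])
    for output in outputs:
        print(output.outputs.score)
    ```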

!!! note
    Load the original `mxbai-rerank-v2` models with the following command.

134 changes: 134 additions & 0 deletions examples/offline_inference/convert_model_to_seq_cls.py
@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import argparse
import json

import torch
import transformers

# Usage:
# for BAAI/bge-reranker-v2-gemma
# Caution: "Yes" and "yes" are two different tokens
# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
# for mxbai-rerank-v2
# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
# for Qwen3-Reranker
# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls


def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    # For two label tokens, softmax([z_false, z_true])[true] equals
    # sigmoid(z_true - z_false), so a single-logit classifier with weight
    # (w_true - w_false) reproduces the causal LM's P(true).
    assert len(tokens) == 2

    lm_head_weights = causal_lm.lm_head.weight

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    score_weight = lm_head_weights[true_id].to(device).to(
        torch.float32
    ) - lm_head_weights[false_id].to(device).to(torch.float32)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device):
    # Copy the lm_head rows for the label tokens directly into the
    # classifier head, one row per label.
    lm_head_weights = causal_lm.lm_head.weight

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]

    score_weight = lm_head_weights[token_ids].to(device)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight)
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


method_map = {
    function.__name__: function for function in [from_2_way_softmax, no_post_processing]
}


def converting(
    model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu"
):
    assert method in method_map

    if method == "from_2_way_softmax":
        assert len(classifier_from_tokens) == 2
        num_labels = 1
    else:
        num_labels = len(classifier_from_tokens)

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device
    )

    seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
        device_map=device,
    )

    method_map[method](
        causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
    )

    # `llm as reranker` models default to not using pad_token
    seq_cls_model.config.use_pad_token = use_pad_token
    seq_cls_model.config.pad_token_id = tokenizer.pad_token_id

    seq_cls_model.save_pretrained(path)
    tokenizer.save_pretrained(path)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Convert *ForCausalLM models to "
        "*ForSequenceClassification models."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="BAAI/bge-reranker-v2-gemma",
        help="Model name",
    )
    parser.add_argument(
        "--classifier_from_tokens",
        type=str,
        default='["Yes"]',
        help="JSON list of label tokens to build the classifier from",
    )
    parser.add_argument(
        "--method",
        type=str,
        default="no_post_processing",
        help="Conversion method: from_2_way_softmax or no_post_processing",
    )
    parser.add_argument(
        "--use-pad-token", action="store_true", help="Whether to use pad_token"
    )
    parser.add_argument(
        "--path",
        type=str,
        default="./bge-reranker-v2-gemma-seq-cls",
        help="Path to save the converted model",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    converting(
        model_name=args.model_name,
        classifier_from_tokens=json.loads(args.classifier_from_tokens),
        method=args.method,
        use_pad_token=args.use_pad_token,
        path=args.path,
    )
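
A quick way to validate a converted checkpoint is to compare it against the original causal LM. The sketch below is not part of this PR; it assumes the Qwen3-Reranker conversion command above was run and its output saved to `./Qwen3-Reranker-0.6B-seq-cls` (hypothetical local path). Because `from_2_way_softmax` sets the classifier weight to `w_yes - w_no`, the converted model's single logit should match `z_yes - z_no` from the causal LM at the last position:

```python
# Sanity-check sketch (hypothetical paths): compare the converted
# sequence-classification model against the original causal LM.
import torch
import transformers

model_name = "Qwen/Qwen3-Reranker-0.6B"
converted_path = "./Qwen3-Reranker-0.6B-seq-cls"  # output of the command above

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
causal_lm = transformers.AutoModelForCausalLM.from_pretrained(model_name)
seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
    converted_path
)

# Any text works here; we only check that the two models agree.
inputs = tokenizer("Query: what is a panda? Document: The giant panda is a "
                   "bear species endemic to China.", return_tensors="pt")

with torch.no_grad():
    # P(yes) from the causal LM: softmax over the two label tokens at the
    # last position reduces to sigmoid(z_yes - z_no).
    logits = causal_lm(**inputs).logits[0, -1]
    no_id = tokenizer.convert_tokens_to_ids("no")
    yes_id = tokenizer.convert_tokens_to_ids("yes")
    p_causal = (logits[yes_id] - logits[no_id]).sigmoid().item()

    # The converted model emits the same quantity as its single logit.
    p_seq_cls = seq_cls_model(**inputs).logits[0, 0].sigmoid().item()

print(f"causal LM: {p_causal:.6f}  seq-cls: {p_seq_cls:.6f}")  # should closely match
```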
5 changes: 3 additions & 2 deletions tests/models/language/pooling/mteb_utils.py
@@ -267,7 +267,8 @@ def mteb_test_rerank_models(hf_runner,
                           vllm_runner,
                           model_info: RerankModelInfo,
                           vllm_extra_kwargs=None,
                           hf_model_callback=None,
                           vllm_mteb_encoder=VllmMtebEncoder):
    if not model_info.enable_test:
        # A model family has many models with the same architecture,
        # and we don't need to test each one.
@@ -288,7 +289,7 @@ def mteb_test_rerank_models(hf_runner,
        assert (model_info.architecture in model_config.architectures)
        assert model_config.hf_config.num_labels == 1

        vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model),
                                          tasks=MTEB_RERANK_TASKS,
                                          languages=MTEB_RERANK_LANGS)
        vllm_dtype = model_config.dtype
140 changes: 140 additions & 0 deletions tests/models/language/pooling/test_bge_reranker_v2_gemma.py
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import numpy as np
import pytest
import torch

from tests.conftest import HfRunner

from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
                         mteb_test_rerank_models)

RERANK_MODELS = [
    RerankModelInfo("BAAI/bge-reranker-v2-gemma",
                    architecture="GemmaForSequenceClassification"),
]

PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501


class GemmaRerankerHfRunner(HfRunner):

    def __init__(self,
                 model_name: str,
                 dtype: str = "auto",
                 *args: Any,
                 **kwargs: Any) -> None:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                       padding_side='left')
        self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes")

    @torch.no_grad()
    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:

        def get_inputs(pairs, tokenizer, prompt=None):
            if prompt is None:
                prompt = PROMPT

            sep = "\n"
            prompt_inputs = tokenizer(prompt,
                                      return_tensors=None,
                                      add_special_tokens=False)["input_ids"]
            sep_inputs = tokenizer(sep,
                                   return_tensors=None,
                                   add_special_tokens=False)["input_ids"]
            inputs = []
            for query, passage in pairs:
                query_inputs = tokenizer(
                    f"A: {query}",
                    return_tensors=None,
                    add_special_tokens=False,
                    truncation=True,
                )
                passage_inputs = tokenizer(
                    f"B: {passage}",
                    return_tensors=None,
                    add_special_tokens=False,
                    truncation=True,
                )
                item = tokenizer.prepare_for_model(
                    [tokenizer.bos_token_id] + query_inputs["input_ids"],
                    sep_inputs + passage_inputs["input_ids"],
                    truncation="only_second",
                    padding=False,
                    return_attention_mask=False,
                    return_token_type_ids=False,
                    add_special_tokens=False,
                )
                # Append the instruction prompt after the (query, passage)
                # pair, matching the official usage of this model.
                item["input_ids"] = item[
                    "input_ids"] + sep_inputs + prompt_inputs
                item["attention_mask"] = [1] * len(item["input_ids"])
                inputs.append(item)
            return tokenizer.pad(
                inputs,
                padding=True,
                return_tensors="pt",
            )

        scores = []
        for query, doc, *_ in prompts:
            pairs = [(query, doc)]
            inputs = get_inputs(pairs, self.tokenizer)
            inputs = inputs.to(self.model.device)
            logits = self.model(**inputs, return_dict=True).logits
            # The relevance score is P("Yes") at the final position.
            _scores = logits[:, -1, self.yes_loc].view(-1).float().sigmoid()
            scores.append(_scores[0].item())
        return torch.Tensor(scores)


class GemmaMtebEncoder(VllmMtebEncoder):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = PROMPT
        self.query_template = "A: {query}\n"
        self.document_template = "B: {doc}\n{prompt}"

    def predict(
        self,
        sentences: list[tuple[str, str,
                              Optional[str]]],  # query, corpus, prompt
        *args,
        **kwargs,
    ) -> np.ndarray:

        _sentences = []
        for query, corpus, prompt in sentences:
            query = self.query_template.format(query=query)
            corpus = self.document_template.format(doc=corpus, prompt=prompt)
            _sentences.append((query, corpus, prompt))

        return super().predict(_sentences, *args, **kwargs)


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
                            monkeypatch) -> None:
    monkeypatch.setenv("VLLM_USE_V1", "0")

    assert model_info.architecture == "GemmaForSequenceClassification"

    vllm_extra_kwargs: dict[str, Any] = {
        "hf_overrides": {
            "architectures": ["GemmaForSequenceClassification"],
            "classifier_from_token": ["Yes"],
            "method": "no_post_processing",
        }
    }

    mteb_test_rerank_models(GemmaRerankerHfRunner,
                            vllm_runner,
                            model_info,
                            vllm_extra_kwargs,
                            vllm_mteb_encoder=GemmaMtebEncoder)
2 changes: 0 additions & 2 deletions tests/models/language/pooling/test_mxbai_rerank.py
@@ -12,11 +12,9 @@
RERANK_MODELS = [
    RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
                    architecture="Qwen2ForSequenceClassification",
                    enable_test=True),
    RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
                    architecture="Qwen2ForSequenceClassification",
                    enable_test=False)
]

7 changes: 6 additions & 1 deletion tests/models/registry.py
@@ -319,9 +319,14 @@ def check_available_online(
_CROSS_ENCODER_EXAMPLE_MODELS = {
    # [Text-only]
    "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True),  # noqa: E501
    "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma",  # noqa: E501
                                                      v0_only=True,
                                                      hf_overrides={"architectures": ["GemmaForSequenceClassification"],  # noqa: E501
                                                                    "classifier_from_token": ["Yes"],  # noqa: E501
                                                                    "method": "no_post_processing"}),  # noqa: E501
    "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True),  # noqa: E501
    "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True),  # noqa: E501
    "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True),  # noqa: E501
}

_MULTIMODAL_EXAMPLE_MODELS = {
6 changes: 6 additions & 0 deletions vllm/config.py
@@ -1449,6 +1449,12 @@ def is_matryoshka(self) -> bool:
    def matryoshka_dimensions(self):
        return getattr(self.hf_config, "matryoshka_dimensions", None)

    @property
    def use_pad_token(self) -> bool:
        # cross_encoder models default to using pad_token.
        # `llm as reranker` models default to not using pad_token.
        return getattr(self.hf_config, "use_pad_token", True)

    def get_and_verify_max_len(self, max_model_len: int):
        # For pooling models, the tokenizer's `model_max_length` is often a
        # reliable source for the maximum sequence length. However, for
12 changes: 8 additions & 4 deletions vllm/entrypoints/llm.py
@@ -1205,17 +1205,21 @@ def _cross_encoding_score(
        input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]

        pooling_params = PoolingParams(use_cross_encoder=True)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in input_pairs:
            if self.llm_engine.model_config.use_pad_token:
                # cross_encoder models default to using pad_token.
                prompt_inputs = tokenizer(text=q,
                                          text_pair=t,
                                          **tokenization_kwargs)
            else:
                # `llm as reranker` models default to not using pad_token.
                prompt_inputs = tokenizer(text=q + t, **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
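
For reference, here is a small standalone sketch of what the two tokenization branches above produce (the model names are just the examples from this PR's docs; the sketch itself is not part of the diff):

```python
from transformers import AutoTokenizer

q = "what is a panda?"
t = "The giant panda is a bear species endemic to China."

# use_pad_token=True: cross-encoder models encode (query, passage) as a
# sentence pair, e.g. [CLS] q [SEP] t [SEP] for BERT-style tokenizers.
bert_tok = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
paired = bert_tok(text=q, text_pair=t)

# use_pad_token=False: `llm as reranker` models see the plain concatenation,
# with no separator token between query and passage.
llm_tok = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-gemma")
concatenated = llm_tok(text=q + t)

print(paired["input_ids"][:5], concatenated["input_ids"][:5])
```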