[Model][3/N] Automatic conversion of CrossEncoding model #20168

Merged · 11 commits · Jul 4, 2025
21 changes: 15 additions & 6 deletions docs/models/supported_models.md
@@ -477,19 +477,28 @@ If your model is not in the above list, we will try to automatically convert the

Specified using `--task score`.

| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|-----------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |
| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|---------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |

!!! note
Load the official original `mxbai-rerank-v2` checkpoints with the following command.

```bash
vllm serve mixedbread-ai/mxbai-rerank-base-v2 --hf_overrides '{"architectures": ["Qwen2ForSequenceClassification"],"classifier_from_token": ["0", "1"], "method": "from_2_way_softmax"}'
```
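
    The same override can also be passed through the offline API. Below is a minimal sketch, assuming the offline `LLM.score` entry point; the `hf_overrides` dict mirrors the serve command above.

    ```python
    from vllm import LLM

    # Minimal offline sketch (assumes vLLM's LLM.score API); the hf_overrides
    # dict mirrors the `vllm serve` command above.
    llm = LLM(
        model="mixedbread-ai/mxbai-rerank-base-v2",
        task="score",
        hf_overrides={
            "architectures": ["Qwen2ForSequenceClassification"],
            "classifier_from_token": ["0", "1"],
            "method": "from_2_way_softmax",
        },
    )

    query = "What is the capital of France?"
    documents = ["Paris is the capital of France.", "Berlin is a city in Germany."]
    outputs = llm.score(query, documents)
    for doc, output in zip(documents, outputs):
        print(f"{output.outputs.score:.4f}  {doc}")
    ```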

!!! note
Load the official original `Qwen3-Reranker` with the following command. More information can be found at <gh-file:examples/offline_inference/qwen3_reranker.py>.

```bash
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
```
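
    The pre-converted checkpoint `tomaarsen/Qwen3-Reranker-0.6B-seq-cls` listed above already declares `Qwen3ForSequenceClassification`, so it loads without any overrides. A minimal offline sketch, again assuming the `LLM.score` API (the query/document instruction formatting recommended for this reranker lives in the linked example file):

    ```python
    from vllm import LLM

    # Minimal sketch: the pre-converted seq-cls checkpoint loads without
    # hf_overrides, since its config already declares
    # Qwen3ForSequenceClassification (assumes the offline LLM.score API).
    llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")

    query = "What is the capital of China?"
    documents = ["The capital of China is Beijing.", "Gravity causes apples to fall."]
    for doc, out in zip(documents, llm.score(query, documents)):
        print(f"{out.outputs.score:.4f}  {doc}")
    ```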

[](){ #supported-mm-models }

## List of Multimodal Language Models
10 changes: 9 additions & 1 deletion tests/models/language/pooling/test_embedding.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional

import pytest

@@ -74,6 +75,13 @@ def test_models(
vllm_extra_kwargs["override_pooler_config"] = \
PoolerConfig(pooling_type="MEAN", normalize=False)

max_model_len: Optional[int] = 512
if model in [
"sentence-transformers/all-MiniLM-L12-v2",
"sentence-transformers/stsb-roberta-base-v2"
]:
max_model_len = None

# The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
@@ -87,7 +95,7 @@

with vllm_runner(model,
task="embed",
max_model_len=512,
max_model_len=max_model_len,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.embed(example_prompts)

16 changes: 12 additions & 4 deletions tests/models/language/pooling/test_gte.py
@@ -56,10 +56,16 @@
enable_test=False),
]

V1FlashAttentionImpNotSupported = [
"Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-modernbert-base"
]


@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo,
monkeypatch) -> None:
if model_info.name in V1FlashAttentionImpNotSupported:
monkeypatch.setenv("VLLM_USE_V1", "0")

vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
@@ -71,8 +77,10 @@ def test_embed_models_mteb(hf_runner, vllm_runner,

@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_correctness(hf_runner, vllm_runner,
model_info: EmbedModelInfo,
example_prompts) -> None:
model_info: EmbedModelInfo, example_prompts,
monkeypatch) -> None:
if model_info.name in V1FlashAttentionImpNotSupported:
monkeypatch.setenv("VLLM_USE_V1", "0")

vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
84 changes: 84 additions & 0 deletions tests/models/language/pooling/test_mxbai_rerank.py
@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import pytest
import torch

from tests.conftest import HfRunner

from .mteb_utils import RerankModelInfo, mteb_test_rerank_models

RERANK_MODELS = [
RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
architecture="Qwen2ForSequenceClassification",
dtype="float32",
enable_test=True),
RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
architecture="Qwen2ForSequenceClassification",
dtype="float32",
enable_test=False)
]


class MxbaiRerankerHfRunner(HfRunner):

def __init__(self,
model_name: str,
dtype: str = "auto",
*args: Any,
**kwargs: Any) -> None:
from transformers import AutoModelForCausalLM, AutoTokenizer
super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)

self.tokenizer = AutoTokenizer.from_pretrained(model_name,
padding_side='left')
self.yes_loc = self.tokenizer.convert_tokens_to_ids("1")
self.no_loc = self.tokenizer.convert_tokens_to_ids("0")

def predict(self, prompts: list[list[str]], *args,
**kwargs) -> torch.Tensor:

def process_inputs(pairs):
inputs = self.tokenizer(pairs,
padding=False,
truncation='longest_first',
return_attention_mask=False)
inputs = self.tokenizer.pad(inputs,
padding=True,
return_tensors="pt")
for key in inputs:
inputs[key] = inputs[key].to(self.model.device)
return inputs

@torch.no_grad()
def compute_logits(inputs):
logits = self.model(**inputs).logits[:, -1, :]
yes_logits = logits[:, self.yes_loc]
no_logits = logits[:, self.no_loc]
logits = yes_logits - no_logits
scores = logits.float().sigmoid()
return scores

scores = []
for prompt in prompts:
inputs = process_inputs([prompt])
score = compute_logits(inputs)
scores.append(score[0].item())
return torch.Tensor(scores)


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "Qwen2ForSequenceClassification":
vllm_extra_kwargs["hf_overrides"] = {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
}

mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)
13 changes: 10 additions & 3 deletions vllm/config.py
@@ -466,6 +466,9 @@ def __post_init__(self) -> None:
"affect the random state of the Python process that "
"launched vLLM.", self.seed)

# Set served_model_name before maybe_model_redirect() rewrites self.model.
self.served_model_name = get_served_model_name(self.model,
self.served_model_name)
self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default.
if self.tokenizer is None:
@@ -609,8 +612,6 @@ def __post_init__(self) -> None:

self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
self.served_model_name = get_served_model_name(self.model,
self.served_model_name)
self.multimodal_config = self._init_multimodal_config()
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
@@ -1420,7 +1421,7 @@ def is_multimodal_model(self) -> bool:

@property
def is_cross_encoder(self) -> bool:
return self.registry.is_cross_encoder_model(self.architectures)
return self.task == "classify"

@property
def use_mla(self) -> bool:
@@ -4762,6 +4763,12 @@ def try_verify_and_update_config(self):
if cls is not None:
cls.verify_and_update_config(self)

if self.model_config.task == "classify":
# Maybe convert ForCausalLM into ForSequenceClassification model.
from vllm.model_executor.models.adapters import (
SequenceClassificationConfig)
SequenceClassificationConfig.verify_and_update_config(self)

def __str__(self):
return (
f"model={self.model_config.model!r},"
102 changes: 99 additions & 3 deletions vllm/model_executor/models/adapters.py
@@ -2,14 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast

import torch
import torch.nn as nn

from vllm.model_executor.models.config import VerifyAndUpdateConfig

from .interfaces_base import VllmModelForPooling, is_pooling_model

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolingType

_T = TypeVar("_T", bound=type[nn.Module])
@@ -39,7 +42,6 @@ def _create_pooling_model_cls(
default_softmax: bool,
) -> _T:
# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata

@@ -162,7 +164,6 @@ def as_seq_cls_model(cls: _T) -> _T:
return cls

# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
@@ -193,6 +194,7 @@ def __init__(
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config

self.vllm_config = vllm_config
self.task = vllm_config.model_config.task
self.pooling_type = (
vllm_config.model_config.pooler_config.pooling_type)
@@ -242,6 +244,17 @@ def get_logits(hidden_states):
]
return PoolerOutput(outputs=pooled_outputs)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)
method = getattr(self.config, "method", None)

if tokens is None and method is None:
return super().load_weights(weights)
else:
# Convert the ForCausalLM checkpoint into a
# ForSequenceClassification model at load time.
return seq_cls_model_loader(self, weights)


ModelForSequenceClassification.__name__ = \
_get_pooling_model_name(cls.__name__, "ForSequenceClassification")
@@ -277,3 +290,86 @@ def as_reward_model(cls: _T) -> _T:
_get_pooling_model_name(cls.__name__, "ForReward")

return ModelForReward # type: ignore


class SequenceClassificationConfig(VerifyAndUpdateConfig):

@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
method = getattr(config, "method", None)
tokens = getattr(config, "classifier_from_token", None)

if method is None:
return

assert tokens is not None
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

if method == "from_2_way_softmax":
assert len(tokens) == 2
config.num_labels = 1
else:
config.num_labels = len(tokens)


def load_weights_using_from_2_way_softmax(
model, weights: Iterable[tuple[str, torch.Tensor]]):
# refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead)
from vllm.model_executor.models.utils import AutoWeightsLoader

model_config = model.vllm_config.model_config
tokens = getattr(model.config, "classifier_from_token", [])
tokens = cast(list[str], tokens)
assert len(tokens) == 2

device = model.score.weight.device

if model.config.tie_word_embeddings:
model.lm_head = model.model.embed_tokens
else:
model.lm_head = ParallelLMHead(model.config.vocab_size,
model.config.hidden_size,
quant_config=model.quant_config)

loader = AutoWeightsLoader(model)
loaded_weights = loader.load_weights(weights)

from vllm.transformers_utils.tokenizer import get_tokenizer
tokenizer = get_tokenizer(model_config.tokenizer,
revision=model_config.tokenizer_revision,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code)

false_id = tokenizer.convert_tokens_to_ids(tokens[0])
true_id = tokenizer.convert_tokens_to_ids(tokens[1])
weight = model.lm_head.weight.data[true_id].to(device).to(
torch.float32) - model.lm_head.weight.data[false_id].to(device).to(
torch.float32)
model.score.weight.data.copy_(weight)

del model.lm_head
loaded_weights.add("score.weight")
loaded_weights.discard("lm_head.weight")
return loaded_weights


SEQ_CLS_LOAD_METHODS = {
"from_2_way_softmax": load_weights_using_from_2_way_softmax,
}


def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
# Convert a ForCausalLM checkpoint into a ForSequenceClassification model at load time.
# - from_2_way_softmax:
# - Qwen3ForCausalLM
# - Qwen3-Reranker
# - Qwen2ForCausalLM
# - mxbai-rerank-v2

config = model.vllm_config.model_config.hf_config
method = getattr(config, "method", None)
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
return SEQ_CLS_LOAD_METHODS[method](model, weights)
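
The `from_2_way_softmax` folding relies on a simple identity: subtracting the two LM-head rows gives a single score row whose sigmoid reproduces the two-way softmax probability, since `softmax([a, b])[1] == sigmoid(b - a)`. A minimal self-contained check with toy shapes (hypothetical names, not vLLM code):

```python
import torch

# Toy check of the "from_2_way_softmax" folding used above (hypothetical
# shapes; not vLLM code). W plays the role of lm_head.weight, h of the
# last-token hidden state.
torch.manual_seed(0)
vocab_size, hidden_size = 8, 4
false_id, true_id = 2, 5  # ids of the "0"/"no" and "1"/"yes" tokens

W = torch.randn(vocab_size, hidden_size)
h = torch.randn(hidden_size)

# Original causal-LM path: two logits, 2-way softmax, take P("yes").
logits = W @ h
p_yes_softmax = torch.softmax(
    torch.stack([logits[false_id], logits[true_id]]), dim=0)[1]

# Converted path: single score row = W[yes] - W[no], then sigmoid.
score_weight = W[true_id] - W[false_id]
p_yes_sigmoid = torch.sigmoid(score_weight @ h)

assert torch.allclose(p_yes_softmax, p_yes_sigmoid)
print(float(p_yes_softmax), float(p_yes_sigmoid))
```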
2 changes: 1 addition & 1 deletion vllm/model_executor/models/config.py
@@ -167,7 +167,7 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
assert tokens is not None and len(tokens) == 2, \
("Try loading the original Qwen3 Reranker?, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
config.num_labels = 1
vllm_config.model_config.hf_config.method = "from_2_way_softmax"


class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):