
Commit 2c9b4a0

noooop authored and Chen-zexi committed
[Model][3/N] Automatic conversion of CrossEncoding model (vllm-project#20168)
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 532ae1f commit 2c9b4a0

8 files changed: +234 additions, -133 deletions


docs/models/supported_models.md

Lines changed: 15 additions & 6 deletions
@@ -477,19 +477,28 @@ If your model is not in the above list, we will try to automatically convert the
 
 Specified using `--task score`.
 
-| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
-|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|-----------------------|
-| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
-| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
-| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
-| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |
+| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
+|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|---------------------|
+| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
+| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ |
+| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
+| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
+| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |
+
+!!! note
+    Load the official original `mxbai-rerank-v2` by using the following command.
+
+    ```bash
+    vllm serve mixedbread-ai/mxbai-rerank-base-v2 --hf_overrides '{"architectures": ["Qwen2ForSequenceClassification"],"classifier_from_token": ["0", "1"], "method": "from_2_way_softmax"}'
+    ```
 
 !!! note
     Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: <gh-file:examples/offline_inference/qwen3_reranker.py>.
 
     ```bash
     vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
     ```
+
 [](){ #supported-mm-models }
 
 ## List of Multimodal Language Models
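For readers who want to try the converted model without starting a server, a minimal offline sketch follows. It is not part of this commit; it assumes vLLM's offline `LLM(..., task="score")` constructor, `hf_overrides` argument, and `LLM.score()` API behave as in the `vllm serve` example above, and the query and documents are made up.

```python
# Illustrative sketch only (not from this commit). Assumes the offline
# LLM(..., task="score") API accepts the same hf_overrides as `vllm serve`.
from vllm import LLM

llm = LLM(
    model="mixedbread-ai/mxbai-rerank-base-v2",
    task="score",
    hf_overrides={
        "architectures": ["Qwen2ForSequenceClassification"],
        "classifier_from_token": ["0", "1"],
        "method": "from_2_way_softmax",
    },
)

query = "What is the capital of France?"
documents = [
    "Paris is the capital of France.",
    "The Nile is the longest river in Africa.",
]

# score() cross-encodes the query against each document and returns one
# relevance score per pair.
outputs = llm.score(query, documents)
for doc, out in zip(documents, outputs):
    print(f"{out.outputs.score:.4f}  {doc}")
```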

tests/models/language/pooling/test_embedding.py

Lines changed: 9 additions & 1 deletion
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import os
+from typing import Optional
 
 import pytest
 
@@ -74,6 +75,13 @@ def test_models(
         vllm_extra_kwargs["override_pooler_config"] = \
             PoolerConfig(pooling_type="MEAN", normalize=False)
 
+    max_model_len: Optional[int] = 512
+    if model in [
+            "sentence-transformers/all-MiniLM-L12-v2",
+            "sentence-transformers/stsb-roberta-base-v2"
+    ]:
+        max_model_len = None
+
     # The example_prompts has ending "\n", for example:
     # "Write a short story about a robot that dreams for the first time.\n"
     # sentence_transformers will strip the input texts, see:
@@ -87,7 +95,7 @@ def test_models(
 
     with vllm_runner(model,
                      task="embed",
-                     max_model_len=512,
+                     max_model_len=max_model_len,
                      **vllm_extra_kwargs) as vllm_model:
         vllm_outputs = vllm_model.embed(example_prompts)
 
tests/models/language/pooling/test_gte.py

Lines changed: 12 additions & 4 deletions
@@ -56,10 +56,16 @@
             enable_test=False),
 ]
 
+V1FlashAttentionImpNotSupported = [
+    "Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-modernbert-base"
+]
+
 
 @pytest.mark.parametrize("model_info", MODELS)
-def test_embed_models_mteb(hf_runner, vllm_runner,
-                           model_info: EmbedModelInfo) -> None:
+def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo,
+                           monkeypatch) -> None:
+    if model_info.name in V1FlashAttentionImpNotSupported:
+        monkeypatch.setenv("VLLM_USE_V1", "0")
 
     vllm_extra_kwargs: dict[str, Any] = {}
     if model_info.architecture == "GteNewModel":
@@ -71,8 +77,10 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
 
 @pytest.mark.parametrize("model_info", MODELS)
 def test_embed_models_correctness(hf_runner, vllm_runner,
-                                  model_info: EmbedModelInfo,
-                                  example_prompts) -> None:
+                                  model_info: EmbedModelInfo, example_prompts,
+                                  monkeypatch) -> None:
+    if model_info.name in V1FlashAttentionImpNotSupported:
+        monkeypatch.setenv("VLLM_USE_V1", "0")
 
     vllm_extra_kwargs: dict[str, Any] = {}
     if model_info.architecture == "GteNewModel":
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+import pytest
+import torch
+
+from tests.conftest import HfRunner
+
+from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
+
+RERANK_MODELS = [
+    RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
+                    architecture="Qwen2ForSequenceClassification",
+                    dtype="float32",
+                    enable_test=True),
+    RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
+                    architecture="Qwen2ForSequenceClassification",
+                    dtype="float32",
+                    enable_test=False)
+]
+
+
+class MxbaiRerankerHfRunner(HfRunner):
+
+    def __init__(self,
+                 model_name: str,
+                 dtype: str = "auto",
+                 *args: Any,
+                 **kwargs: Any) -> None:
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                                       padding_side='left')
+        self.yes_loc = self.tokenizer.convert_tokens_to_ids("1")
+        self.no_loc = self.tokenizer.convert_tokens_to_ids("0")
+
+    def predict(self, prompts: list[list[str]], *args,
+                **kwargs) -> torch.Tensor:
+
+        def process_inputs(pairs):
+            inputs = self.tokenizer(pairs,
+                                    padding=False,
+                                    truncation='longest_first',
+                                    return_attention_mask=False)
+            for i, ele in enumerate(inputs['input_ids']):
+                inputs['input_ids'][i] = ele
+            inputs = self.tokenizer.pad(inputs,
+                                        padding=True,
+                                        return_tensors="pt")
+            for key in inputs:
+                inputs[key] = inputs[key].to(self.model.device)
+            return inputs
+
+        @torch.no_grad()
+        def compute_logits(inputs):
+            logits = self.model(**inputs).logits[:, -1, :]
+            yes_logits = logits[:, self.yes_loc]
+            no_logits = logits[:, self.no_loc]
+            logits = yes_logits - no_logits
+            scores = logits.float().sigmoid()
+            return scores
+
+        scores = []
+        for prompt in prompts:
+            inputs = process_inputs([prompt])
+            score = compute_logits(inputs)
+            scores.append(score[0].item())
+        return torch.Tensor(scores)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
+    vllm_extra_kwargs: dict[str, Any] = {}
+    if model_info.architecture == "Qwen2ForSequenceClassification":
+        vllm_extra_kwargs["hf_overrides"] = {
+            "architectures": ["Qwen2ForSequenceClassification"],
+            "classifier_from_token": ["0", "1"],
+            "method": "from_2_way_softmax",
+        }
+
+    mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
+                            vllm_extra_kwargs)

vllm/config.py

Lines changed: 10 additions & 3 deletions
@@ -466,6 +466,9 @@ def __post_init__(self) -> None:
                 "affect the random state of the Python process that "
                 "launched vLLM.", self.seed)
 
+        # Keep set served_model_name before maybe_model_redirect(self.model)
+        self.served_model_name = get_served_model_name(self.model,
+                                                       self.served_model_name)
         self.model = maybe_model_redirect(self.model)
         # The tokenizer is consistent with the model by default.
         if self.tokenizer is None:
@@ -609,8 +612,6 @@ def __post_init__(self) -> None:
 
         self.original_max_model_len = self.max_model_len
         self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
-        self.served_model_name = get_served_model_name(self.model,
-                                                       self.served_model_name)
         self.multimodal_config = self._init_multimodal_config()
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
@@ -1420,7 +1421,7 @@ def is_multimodal_model(self) -> bool:
 
     @property
     def is_cross_encoder(self) -> bool:
-        return self.registry.is_cross_encoder_model(self.architectures)
+        return self.task == "classify"
 
     @property
     def use_mla(self) -> bool:
@@ -4762,6 +4763,12 @@ def try_verify_and_update_config(self):
         if cls is not None:
            cls.verify_and_update_config(self)
 
+        if self.model_config.task == "classify":
+            # Maybe convert ForCausalLM into ForSequenceClassification model.
+            from vllm.model_executor.models.adapters import (
+                SequenceClassificationConfig)
+            SequenceClassificationConfig.verify_and_update_config(self)
+
     def __str__(self):
         return (
             f"model={self.model_config.model!r},"

vllm/model_executor/models/adapters.py

Lines changed: 99 additions & 3 deletions
@@ -2,14 +2,17 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast
 
 import torch
 import torch.nn as nn
 
+from vllm.model_executor.models.config import VerifyAndUpdateConfig
+
 from .interfaces_base import VllmModelForPooling, is_pooling_model
 
 if TYPE_CHECKING:
+    from vllm.config import VllmConfig
     from vllm.model_executor.layers.pooler import PoolingType
 
 _T = TypeVar("_T", bound=type[nn.Module])
@@ -39,7 +42,6 @@ def _create_pooling_model_cls(
     default_softmax: bool,
 ) -> _T:
     # Lazy import
-    from vllm.config import VllmConfig
     from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
     from vllm.model_executor.pooling_metadata import PoolingMetadata
 
@@ -162,7 +164,6 @@ def as_seq_cls_model(cls: _T) -> _T:
         return cls
 
     # Lazy import
-    from vllm.config import VllmConfig
     from vllm.model_executor.layers.linear import RowParallelLinear
     from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
     from vllm.model_executor.models.interfaces import SupportsCrossEncoding
@@ -193,6 +194,7 @@ def __init__(
             config = vllm_config.model_config.hf_config
             quant_config = vllm_config.quant_config
 
+            self.vllm_config = vllm_config
             self.task = vllm_config.model_config.task
             self.pooling_type = (
                 vllm_config.model_config.pooler_config.pooling_type)
@@ -242,6 +244,17 @@ def get_logits(hidden_states):
             ]
             return PoolerOutput(outputs=pooled_outputs)
 
+        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+            tokens = getattr(self.config, "classifier_from_token", None)
+            method = getattr(self.config, "method", None)
+
+            if tokens is None and method is None:
+                return super().load_weights(weights)
+            else:
+                # Online convert ForCausalLM into
+                # ForSequenceClassification model.
+                return seq_cls_model_loader(self, weights)
+
 
     ModelForSequenceClassification.__name__ = \
         _get_pooling_model_name(cls.__name__, "ForSequenceClassification")
@@ -277,3 +290,86 @@ def as_reward_model(cls: _T) -> _T:
         _get_pooling_model_name(cls.__name__, "ForReward")
 
     return ModelForReward  # type: ignore
+
+
+class SequenceClassificationConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+        method = getattr(config, "method", None)
+        tokens = getattr(config, "classifier_from_token", None)
+
+        if method is None:
+            return
+
+        assert tokens is not None
+        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
+
+        if method == "from_2_way_softmax":
+            assert len(tokens) == 2
+            config.num_labels = 1
+        else:
+            config.num_labels = len(tokens)
+
+
+def load_weights_using_from_2_way_softmax(
+        model, weights: Iterable[tuple[str, torch.Tensor]]):
+    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        ParallelLMHead)
+    from vllm.model_executor.models.utils import AutoWeightsLoader
+
+    model_config = model.vllm_config.model_config
+    tokens = getattr(model.config, "classifier_from_token", [])
+    tokens = cast(list[int], tokens)
+    assert len(tokens) == 2
+
+    device = model.score.weight.device
+
+    if model.config.tie_word_embeddings:
+        model.lm_head = model.model.embed_tokens
+    else:
+        model.lm_head = ParallelLMHead(model.config.vocab_size,
+                                       model.config.hidden_size,
+                                       quant_config=model.quant_config)
+
+    loader = AutoWeightsLoader(model)
+    loaded_weights = loader.load_weights(weights)
+
+    from vllm.transformers_utils.tokenizer import get_tokenizer
+    tokenizer = get_tokenizer(model_config.tokenizer,
+                              revision=model_config.tokenizer_revision,
+                              tokenizer_mode=model_config.tokenizer_mode,
+                              trust_remote_code=model_config.trust_remote_code)
+
+    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
+    true_id = tokenizer.convert_tokens_to_ids(tokens[1])
+    weight = model.lm_head.weight.data[true_id].to(device).to(
+        torch.float32) - model.lm_head.weight.data[false_id].to(device).to(
+            torch.float32)
+    model.score.weight.data.copy_(weight)
+
+    del model.lm_head
+    loaded_weights.add("score.weight")
+    loaded_weights.discard("lm_head.weight")
+    return loaded_weights
+
+
+SEQ_CLS_LOAD_METHODS = {
+    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
+}
+
+
+def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
+    # Online convert ForCausalLM into ForSequenceClassification model.
+    # - from_2_way_softmax:
+    #   - Qwen3ForCausalLM
+    #     - Qwen3-Reranker
+    #   - Qwen2ForCausalLM
+    #     - mxbai-rerank-v2
+
+    config = model.vllm_config.model_config.hf_config
+    method = getattr(config, "method", None)
+    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
+    return SEQ_CLS_LOAD_METHODS[method](model, weights)
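The `from_2_way_softmax` loader rests on a simple identity: for two classes, the softmax probability of the "yes" (or "1") token equals a sigmoid over the difference of the two logits, so the two `lm_head` rows can be folded into the single `score` row copied above. The snippet below is a self-contained numeric check of that equivalence using random toy tensors, not real model weights; the vocabulary size and token ids are made up.

```python
# Toy check (not from this commit): softmax over (no, yes) logits equals
# sigmoid(yes_logit - no_logit), so W_score = W_lm_head[yes] - W_lm_head[no].
import torch

hidden = torch.randn(4, 8)      # pooled hidden states (batch, hidden)
lm_head = torch.randn(100, 8)   # toy vocab projection (vocab, hidden)
no_id, yes_id = 10, 11          # illustrative token ids

# Two-way softmax over the causal-LM logits for the two tokens.
logits = hidden @ lm_head.T
two_way = torch.softmax(logits[:, [no_id, yes_id]], dim=-1)[:, 1]

# Single-row classifier built by subtracting the two lm_head rows.
score_w = lm_head[yes_id] - lm_head[no_id]
one_way = torch.sigmoid(hidden @ score_w)

assert torch.allclose(two_way, one_way, atol=1e-5)
print("two-way softmax and folded sigmoid head agree")
```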

vllm/model_executor/models/config.py

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         assert tokens is not None and len(tokens) == 2, \
             ("Try loading the original Qwen3 Reranker?, see: "
              "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
-        config.num_labels = 1
+        vllm_config.model_config.hf_config.method = "from_2_way_softmax"
 
 
 class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
