
Commit dcb47a2

noooop authored and Zebing Lin committed
[Model][Last/4] Automatic conversion of CrossEncoding model (vllm-project#19675)
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 1ad69e8 commit dcb47a2

File tree

13 files changed: +376 −23 lines changed


docs/models/supported_models.md

Lines changed: 8 additions & 0 deletions
@@ -481,11 +481,19 @@ Specified using `--task score`.
  | Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
  |--------------|--------|-------------------|---------------------|
  | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
+ | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | |
  | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ |
  | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
  | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
  | `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |

+ !!! note
+     Load the official original `BAAI/bge-reranker-v2-gemma` by using the following command.
+
+     ```bash
+     vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
+     ```
+
  !!! note
      Load the official original `mxbai-rerank-v2` by using the following command.

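The same override also works offline. A minimal sketch, assuming vLLM with this commit; the model name, task, and override keys come from the docs note above, while the query/passage texts are illustrative:

```python
# Minimal offline sketch of the hf_overrides path documented above.
from vllm import LLM

llm = LLM(
    model="BAAI/bge-reranker-v2-gemma",
    task="score",
    hf_overrides={
        "architectures": ["GemmaForSequenceClassification"],
        "classifier_from_token": ["Yes"],
        "method": "no_post_processing",
    },
)

# LLM.score pairs the first text with each candidate and returns one relevance
# score per pair.
outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "Berlin is the capital of Germany."],
)
for out in outputs:
    print(out.outputs.score)
```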
convert_model_to_seq_cls.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import argparse
import json

import torch
import transformers

# Usage:
# for BAAI/bge-reranker-v2-gemma
# Caution: "Yes" and "yes" are two different tokens
# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
# for mxbai-rerank-v2
# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
# for Qwen3-Reranker
# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls


def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    assert len(tokens) == 2

    lm_head_weights = causal_lm.lm_head.weight

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    score_weight = lm_head_weights[true_id].to(device).to(
        torch.float32
    ) - lm_head_weights[false_id].to(device).to(torch.float32)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device):
    lm_head_weights = causal_lm.lm_head.weight

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]

    score_weight = lm_head_weights[token_ids].to(device)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight)
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


method_map = {
    function.__name__: function for function in [from_2_way_softmax, no_post_processing]
}


def converting(
    model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu"
):
    assert method in method_map

    if method == "from_2_way_softmax":
        assert len(classifier_from_tokens) == 2
        num_labels = 1
    else:
        num_labels = len(classifier_from_tokens)

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device
    )

    seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
        device_map=device,
    )

    method_map[method](
        causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
    )

    # `llm as reranker` defaults to not using pad_token
    seq_cls_model.config.use_pad_token = use_pad_token
    seq_cls_model.config.pad_token_id = tokenizer.pad_token_id

    seq_cls_model.save_pretrained(path)
    tokenizer.save_pretrained(path)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Converting *ForCausalLM models to "
        "*ForSequenceClassification models."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="BAAI/bge-reranker-v2-gemma",
        help="Model name",
    )
    parser.add_argument(
        "--classifier_from_tokens",
        type=str,
        default='["Yes"]',
        help="classifier from tokens",
    )
    parser.add_argument(
        "--method", type=str, default="no_post_processing", help="Conversion method"
    )
    parser.add_argument(
        "--use-pad-token", action="store_true", help="Whether to use pad_token"
    )
    parser.add_argument(
        "--path",
        type=str,
        default="./bge-reranker-v2-gemma-seq-cls",
        help="Path to save converted model",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    converting(
        model_name=args.model_name,
        classifier_from_tokens=json.loads(args.classifier_from_tokens),
        method=args.method,
        use_pad_token=args.use_pad_token,
        path=args.path,
    )
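A converted checkpoint can be sanity-checked directly with Transformers. A hedged sketch, assuming the Qwen3-Reranker conversion command above was run with its default `--path`; the query/passage text is illustrative and does not apply the model's full prompt template:

```python
# Hedged sketch: load a checkpoint produced by convert_model_to_seq_cls.py
# and score one query/passage pair.
import torch
import transformers

path = "./Qwen3-Reranker-0.6B-seq-cls"  # --path used in the conversion command
tokenizer = transformers.AutoTokenizer.from_pretrained(path)
model = transformers.AutoModelForSequenceClassification.from_pretrained(path)

inputs = tokenizer(
    "what is panda? The giant panda is a bear species endemic to China.",
    return_tensors="pt",
)
with torch.no_grad():
    # from_2_way_softmax keeps a single label whose weight is
    # lm_head[yes] - lm_head[no], so sigmoid of the lone logit equals the
    # softmax probability of "yes" under the original causal LM head.
    score = model(**inputs).logits.squeeze().sigmoid().item()
print(score)
```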

tests/models/language/pooling/mteb_utils.py

Lines changed: 3 additions & 2 deletions
@@ -267,7 +267,8 @@ def mteb_test_rerank_models(hf_runner,
                            vllm_runner,
                            model_info: RerankModelInfo,
                            vllm_extra_kwargs=None,
-                           hf_model_callback=None):
+                           hf_model_callback=None,
+                           vllm_mteb_encoder=VllmMtebEncoder):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
@@ -288,7 +289,7 @@ def mteb_test_rerank_models(hf_runner,
     assert (model_info.architecture in model_config.architectures)
     assert model_config.hf_config.num_labels == 1

-    vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model),
+    vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model),
                                      tasks=MTEB_RERANK_TASKS,
                                      languages=MTEB_RERANK_LANGS)
     vllm_dtype = model_config.dtype
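The new `vllm_mteb_encoder` parameter lets a test inject a model-specific encoder, as the Gemma reranker test below does. A hedged sketch of the hook from another test module in the same package; the subclass and model name are illustrative, not part of this commit:

```python
# Hedged sketch of the vllm_mteb_encoder hook; names marked Custom*/my-org/*
# are illustrative.
from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
                         mteb_test_rerank_models)


class CustomMtebEncoder(VllmMtebEncoder):

    def predict(self, sentences, *args, **kwargs):
        # Rewrite the (query, corpus, prompt) tuples here, e.g. to apply a
        # model-specific template, then defer to the base implementation.
        return super().predict(sentences, *args, **kwargs)


def test_custom_rerank_mteb(hf_runner, vllm_runner):
    model_info = RerankModelInfo("my-org/my-reranker",
                                 architecture="Qwen2ForSequenceClassification")
    mteb_test_rerank_models(hf_runner, vllm_runner, model_info,
                            vllm_mteb_encoder=CustomMtebEncoder)
```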
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import numpy as np
import pytest
import torch

from tests.conftest import HfRunner

from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
                         mteb_test_rerank_models)

RERANK_MODELS = [
    RerankModelInfo("BAAI/bge-reranker-v2-gemma",
                    architecture="GemmaForSequenceClassification"),
]

PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."  # noqa: E501


class GemmaRerankerHfRunner(HfRunner):

    def __init__(self,
                 model_name: str,
                 dtype: str = "auto",
                 *args: Any,
                 **kwargs: Any) -> None:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                       padding_side='left')
        self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes")

    @torch.no_grad()
    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:

        def get_inputs(pairs, tokenizer, prompt=None):
            if prompt is None:
                prompt = PROMPT

            sep = "\n"
            prompt_inputs = tokenizer(prompt,
                                      return_tensors=None,
                                      add_special_tokens=False)["input_ids"]
            sep_inputs = tokenizer(sep,
                                   return_tensors=None,
                                   add_special_tokens=False)["input_ids"]
            inputs = []
            for query, passage in pairs:
                query_inputs = tokenizer(
                    f"A: {query}",
                    return_tensors=None,
                    add_special_tokens=False,
                    truncation=True,
                )
                passage_inputs = tokenizer(
                    f"B: {passage}",
                    return_tensors=None,
                    add_special_tokens=False,
                    truncation=True,
                )
                item = tokenizer.prepare_for_model(
                    [tokenizer.bos_token_id] + query_inputs["input_ids"],
                    sep_inputs + passage_inputs["input_ids"],
                    truncation="only_second",
                    padding=False,
                    return_attention_mask=False,
                    return_token_type_ids=False,
                    add_special_tokens=False,
                )
                item["input_ids"] = item[
                    "input_ids"] + sep_inputs + prompt_inputs
                item["attention_mask"] = [1] * len(item["input_ids"])
                inputs.append(item)
            return tokenizer.pad(
                inputs,
                padding=True,
                return_tensors="pt",
            )

        scores = []
        for query, doc, *_ in prompts:
            pairs = [(query, doc)]
            inputs = get_inputs(pairs, self.tokenizer)
            inputs = inputs.to(self.model.device)
            _n_tokens = inputs["input_ids"].shape[1]
            logits = self.model(**inputs, return_dict=True).logits
            _scores = (logits[:, -1,
                              self.yes_loc].view(-1, ).float().sigmoid())
            scores.append(_scores[0].item())
        return torch.Tensor(scores)


class GemmaMtebEncoder(VllmMtebEncoder):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = PROMPT
        self.query_template = "A: {query}\n"
        self.document_template = "B: {doc}\n{prompt}"

    def predict(
        self,
        sentences: list[tuple[str, str,
                              Optional[str]]],  # query, corpus, prompt
        *args,
        **kwargs,
    ) -> np.ndarray:

        _sentences = []
        for query, corpus, prompt in sentences:
            query = self.query_template.format(query=query)
            corpus = self.document_template.format(doc=corpus, prompt=prompt)
            _sentences.append((query, corpus, prompt))

        return super().predict(_sentences, *args, **kwargs)


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
                            monkeypatch) -> None:
    monkeypatch.setenv("VLLM_USE_V1", "0")

    assert model_info.architecture == "GemmaForSequenceClassification"

    vllm_extra_kwargs: dict[str, Any] = {
        "hf_overrides": {
            "architectures": ["GemmaForSequenceClassification"],
            "classifier_from_token": ["Yes"],
            "method": "no_post_processing",
        }
    }

    mteb_test_rerank_models(GemmaRerankerHfRunner,
                            vllm_runner,
                            model_info,
                            vllm_extra_kwargs,
                            vllm_mteb_encoder=GemmaMtebEncoder)

tests/models/language/pooling/test_mxbai_rerank.py

Lines changed: 0 additions & 2 deletions
@@ -12,11 +12,9 @@
  RERANK_MODELS = [
      RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
                      architecture="Qwen2ForSequenceClassification",
-                     dtype="float32",
                      enable_test=True),
      RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
                      architecture="Qwen2ForSequenceClassification",
-                     dtype="float32",
                      enable_test=False)
  ]

tests/models/registry.py

Lines changed: 6 additions & 1 deletion
@@ -319,9 +319,14 @@ def check_available_online(
  _CROSS_ENCODER_EXAMPLE_MODELS = {
      # [Text-only]
      "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True),  # noqa: E501
+     "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma",  # noqa: E501
+                                                       v0_only=True,
+                                                       hf_overrides={"architectures": ["GemmaForSequenceClassification"],  # noqa: E501
+                                                                     "classifier_from_token": ["Yes"],  # noqa: E501
+                                                                     "method": "no_post_processing"}),  # noqa: E501
+     "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True),  # noqa: E501
      "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True),  # noqa: E501
      "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True),  # noqa: E501
-     "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True),  # noqa: E501
  }

  _MULTIMODAL_EXAMPLE_MODELS = {

vllm/config.py

Lines changed: 6 additions & 0 deletions
@@ -1449,6 +1449,12 @@ def is_matryoshka(self) -> bool:
      def matryoshka_dimensions(self):
          return getattr(self.hf_config, "matryoshka_dimensions", None)

+     @property
+     def use_pad_token(self) -> bool:
+         # cross_encoder models defaults to using pad_token.
+         # `llm as reranker` models defaults to not using pad_token.
+         return getattr(self.hf_config, "use_pad_token", True)
+
      def get_and_verify_max_len(self, max_model_len: int):
          # For pooling models, the tokenizer's `model_max_length` is often a
          # reliable source for the maximum sequence length. However, for
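The property is a one-line fallback on the HF config. A hedged sketch of the behavior it encodes; the standalone helper and `SimpleNamespace` configs below are illustrative, not part of the commit:

```python
# Illustrative stand-in for ModelConfig.use_pad_token: configs without a
# use_pad_token field (classic cross-encoders) report True, while configs
# written by convert_model_to_seq_cls.py carry use_pad_token=False.
from types import SimpleNamespace


def use_pad_token(hf_config) -> bool:
    return getattr(hf_config, "use_pad_token", True)


print(use_pad_token(SimpleNamespace()))                     # True
print(use_pad_token(SimpleNamespace(use_pad_token=False)))  # False
```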

vllm/entrypoints/llm.py

Lines changed: 8 additions & 4 deletions
@@ -1205,17 +1205,21 @@ def _cross_encoding_score(
          input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]

          pooling_params = PoolingParams(use_cross_encoder=True)
-
          tokenization_kwargs: dict[str, Any] = {}
          _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                    truncate_prompt_tokens, tokenization_kwargs)

          parsed_prompts = []

          for q, t in input_pairs:
-             prompt_inputs = tokenizer(text=q,
-                                       text_pair=t,
-                                       **tokenization_kwargs)
+             if self.llm_engine.model_config.use_pad_token:
+                 # cross_encoder models defaults to using pad_token.
+                 prompt_inputs = tokenizer(text=q,
+                                           text_pair=t,
+                                           **tokenization_kwargs)
+             else:
+                 # `llm as reranker` models defaults to not using pad_token.
+                 prompt_inputs = tokenizer(text=q + t, **tokenization_kwargs)
              engine_prompt = TokensPrompt(
                  prompt_token_ids=prompt_inputs["input_ids"],
                  token_type_ids=prompt_inputs.get("token_type_ids"))
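For readers comparing the two branches, a hedged sketch of what each tokenization path produces; the tokenizer name and texts are illustrative:

```python
# Hedged sketch of the two tokenization paths in _cross_encoding_score above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")
q = "what is panda?"
t = "The giant panda is a bear species endemic to China."

# use_pad_token=True: classic cross-encoder path, query/passage as a text pair
# (special tokens, and for some models token_type_ids, separate the segments).
pair_inputs = tokenizer(text=q, text_pair=t)

# use_pad_token=False: "LLM as reranker" path, query and passage are simply
# concatenated into one sequence before tokenization.
concat_inputs = tokenizer(text=q + t)

print(len(pair_inputs["input_ids"]), len(concat_inputs["input_ids"]))
```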
