[Model][Last/4] Automatic conversion of CrossEncoding model #19675
Merged
Changes from all commits (12 commits, all by noooop):
65d3099  support BAAI/bge-reranker-v2-gemma
711069f  fix
2353f2a  fix
f398e0c  Merge branch 'vllm-project:main' into as_score_model
401b737  + use_pad_token
33a3bb3  fix
49e0e4c  fix
9d45ec6  Merge branch 'vllm-project:main' into as_score_model
095659a  + GemmaForSequenceClassification
c917a6d  fix
59c1c22  fix
17e4ac8  fix
134 changes: 134 additions & 0 deletions
convert_model_to_seq_cls.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import argparse
import json

import torch
import transformers

# Usage:
# for BAAI/bge-reranker-v2-gemma
# Caution: "Yes" and "yes" are two different tokens
# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
# for mxbai-rerank-v2
# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
# for Qwen3-Reranker
# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls


def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    assert len(tokens) == 2

    lm_head_weights = causal_lm.lm_head.weight

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    score_weight = lm_head_weights[true_id].to(device).to(
        torch.float32
    ) - lm_head_weights[false_id].to(device).to(torch.float32)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device):
    lm_head_weights = causal_lm.lm_head.weight

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]

    score_weight = lm_head_weights[token_ids].to(device)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight)
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


method_map = {
    function.__name__: function for function in [from_2_way_softmax, no_post_processing]
}


def converting(
    model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu"
):
    assert method in method_map

    if method == "from_2_way_softmax":
        assert len(classifier_from_tokens) == 2
        num_labels = 1
    else:
        num_labels = len(classifier_from_tokens)

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device
    )

    seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
        device_map=device,
    )

    method_map[method](
        causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
    )

    # `llm as reranker` defaults to not using pad_token
    seq_cls_model.config.use_pad_token = use_pad_token
    seq_cls_model.config.pad_token_id = tokenizer.pad_token_id

    seq_cls_model.save_pretrained(path)
    tokenizer.save_pretrained(path)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Converting *ForCausalLM models to "
        "*ForSequenceClassification models."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="BAAI/bge-reranker-v2-gemma",
        help="Model name",
    )
    parser.add_argument(
        "--classifier_from_tokens",
        type=str,
        default='["Yes"]',
        help="classifier from tokens",
    )
    parser.add_argument(
        "--method", type=str, default="no_post_processing", help="Conversion method"
    )
    parser.add_argument(
        "--use-pad-token", action="store_true", help="Whether to use pad_token"
    )
    parser.add_argument(
        "--path",
        type=str,
        default="./bge-reranker-v2-gemma-seq-cls",
        help="Path to save converted model",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    converting(
        model_name=args.model_name,
        classifier_from_tokens=json.loads(args.classifier_from_tokens),
        method=args.method,
        use_pad_token=args.use_pad_token,
        path=args.path,
    )
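
The from_2_way_softmax conversion above works because, for a two-token head, the softmax probability of the "yes" token equals the sigmoid of the difference between the "yes" and "no" logits, so a single score row built from lm_head[yes] - lm_head[no] preserves the ranking. A minimal sketch (not part of the diff) checking that identity with made-up logits:

import torch

# hypothetical (no, yes) logits for a batch of query/passage pairs
z = torch.randn(8, 2)
p_yes_softmax = torch.softmax(z, dim=-1)[:, 1]
p_yes_sigmoid = torch.sigmoid(z[:, 1] - z[:, 0])
# equal up to floating-point error, so either score ranks passages identically
assert torch.allclose(p_yes_softmax, p_yes_sigmoid, atol=1e-6)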
140 changes: 140 additions & 0 deletions
tests/models/language/pooling/test_bge_reranker_v2_gemma.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import numpy as np
import pytest
import torch

from tests.conftest import HfRunner

from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
                         mteb_test_rerank_models)

RERANK_MODELS = [
    RerankModelInfo("BAAI/bge-reranker-v2-gemma",
                    architecture="GemmaForSequenceClassification"),
]

PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."  # noqa: E501


class GemmaRerankerHfRunner(HfRunner):

    def __init__(self,
                 model_name: str,
                 dtype: str = "auto",
                 *args: Any,
                 **kwargs: Any) -> None:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                       padding_side='left')
        self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes")

    @torch.no_grad()
    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:

        def get_inputs(pairs, tokenizer, prompt=None):
            if prompt is None:
                prompt = PROMPT

            sep = "\n"
            prompt_inputs = tokenizer(prompt,
                                      return_tensors=None,
                                      add_special_tokens=False)["input_ids"]
            sep_inputs = tokenizer(sep,
                                   return_tensors=None,
                                   add_special_tokens=False)["input_ids"]
            inputs = []
            for query, passage in pairs:
                query_inputs = tokenizer(
                    f"A: {query}",
                    return_tensors=None,
                    add_special_tokens=False,
                    truncation=True,
                )
                passage_inputs = tokenizer(
                    f"B: {passage}",
                    return_tensors=None,
                    add_special_tokens=False,
                    truncation=True,
                )
                item = tokenizer.prepare_for_model(
                    [tokenizer.bos_token_id] + query_inputs["input_ids"],
                    sep_inputs + passage_inputs["input_ids"],
                    truncation="only_second",
                    padding=False,
                    return_attention_mask=False,
                    return_token_type_ids=False,
                    add_special_tokens=False,
                )
                item["input_ids"] = item[
                    "input_ids"] + sep_inputs + prompt_inputs
                item["attention_mask"] = [1] * len(item["input_ids"])
                inputs.append(item)
            return tokenizer.pad(
                inputs,
                padding=True,
                return_tensors="pt",
            )

        scores = []
        for query, doc, *_ in prompts:
            pairs = [(query, doc)]
            inputs = get_inputs(pairs, self.tokenizer)
            inputs = inputs.to(self.model.device)
            _n_tokens = inputs["input_ids"].shape[1]
            logits = self.model(**inputs, return_dict=True).logits
            _scores = (logits[:, -1,
                              self.yes_loc].view(-1, ).float().sigmoid())
            scores.append(_scores[0].item())
        return torch.Tensor(scores)


class GemmaMtebEncoder(VllmMtebEncoder):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = PROMPT
        self.query_template = "A: {query}\n"
        self.document_template = "B: {doc}\n{prompt}"

    def predict(
        self,
        sentences: list[tuple[str, str,
                              Optional[str]]],  # query, corpus, prompt
        *args,
        **kwargs,
    ) -> np.ndarray:

        _sentences = []
        for query, corpus, prompt in sentences:
            query = self.query_template.format(query=query)
            corpus = self.document_template.format(doc=corpus, prompt=prompt)
            _sentences.append((query, corpus, prompt))

        return super().predict(_sentences, *args, **kwargs)


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
                            monkeypatch) -> None:
    monkeypatch.setenv("VLLM_USE_V1", "0")

    assert model_info.architecture == "GemmaForSequenceClassification"

    vllm_extra_kwargs: dict[str, Any] = {
        "hf_overrides": {
            "architectures": ["GemmaForSequenceClassification"],
            "classifier_from_token": ["Yes"],
            "method": "no_post_processing",
        }
    }

    mteb_test_rerank_models(GemmaRerankerHfRunner,
                            vllm_runner,
                            model_info,
                            vllm_extra_kwargs,
                            vllm_mteb_encoder=GemmaMtebEncoder)
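
The test exercises the automatic conversion path by passing hf_overrides to the vLLM runner instead of converting the checkpoint offline with the script above. Below is a hedged sketch of the same idea with the offline LLM entry point; it assumes that task="score", the hf_overrides argument, and LLM.score are available in your vLLM build, and the query/passages are made up for illustration:

from vllm import LLM

# hf_overrides mirrors vllm_extra_kwargs in the test above (assumed to be
# accepted by the offline LLM constructor as well)
llm = LLM(
    model="BAAI/bge-reranker-v2-gemma",
    task="score",
    hf_overrides={
        "architectures": ["GemmaForSequenceClassification"],
        "classifier_from_token": ["Yes"],
        "method": "no_post_processing",
    },
)

# For faithful bge-reranker scores, the query and documents should be wrapped
# in the "A: {query}\n" / "B: {doc}\n{prompt}" template used by GemmaMtebEncoder.
outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "The Eiffel Tower is in Paris."],
)
print([output.outputs.score for output in outputs])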