-
-
Notifications
You must be signed in to change notification settings - Fork 8.8k
[Model][3/N] Automatic conversion of CrossEncoding model #20168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
33d1b15
fix
noooop c0487a3
fix NotImplementedError
noooop b4e5c5b
+ seq_cls_models_loader.py
noooop e37bf8e
+ test_mxbai_rerank.py
noooop 260566d
fix
noooop e601234
fix
noooop a2586b4
fix
noooop 9b51b17
fix
noooop de88e58
fix
noooop 971e413
fix
noooop d81c475
fix
noooop File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
from typing import Any | ||
|
||
import pytest | ||
import torch | ||
|
||
from tests.conftest import HfRunner | ||
|
||
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models | ||
|
||
RERANK_MODELS = [ | ||
RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", | ||
architecture="Qwen2ForSequenceClassification", | ||
dtype="float32", | ||
enable_test=True), | ||
RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", | ||
architecture="Qwen2ForSequenceClassification", | ||
dtype="float32", | ||
enable_test=False) | ||
] | ||
|
||
|
||
class MxbaiRerankerHfRunner(HfRunner): | ||
|
||
def __init__(self, | ||
model_name: str, | ||
dtype: str = "auto", | ||
*args: Any, | ||
**kwargs: Any) -> None: | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) | ||
|
||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, | ||
padding_side='left') | ||
self.yes_loc = self.tokenizer.convert_tokens_to_ids("1") | ||
self.no_loc = self.tokenizer.convert_tokens_to_ids("0") | ||
|
||
def predict(self, prompts: list[list[str]], *args, | ||
**kwargs) -> torch.Tensor: | ||
|
||
def process_inputs(pairs): | ||
inputs = self.tokenizer(pairs, | ||
padding=False, | ||
truncation='longest_first', | ||
return_attention_mask=False) | ||
for i, ele in enumerate(inputs['input_ids']): | ||
inputs['input_ids'][i] = ele | ||
inputs = self.tokenizer.pad(inputs, | ||
padding=True, | ||
return_tensors="pt") | ||
for key in inputs: | ||
inputs[key] = inputs[key].to(self.model.device) | ||
return inputs | ||
|
||
@torch.no_grad() | ||
def compute_logits(inputs): | ||
logits = self.model(**inputs).logits[:, -1, :] | ||
yes_logits = logits[:, self.yes_loc] | ||
no_logits = logits[:, self.no_loc] | ||
logits = yes_logits - no_logits | ||
scores = logits.float().sigmoid() | ||
return scores | ||
|
||
scores = [] | ||
for prompt in prompts: | ||
inputs = process_inputs([prompt]) | ||
score = compute_logits(inputs) | ||
scores.append(score[0].item()) | ||
return torch.Tensor(scores) | ||
|
||
|
||
@pytest.mark.parametrize("model_info", RERANK_MODELS) | ||
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: | ||
vllm_extra_kwargs: dict[str, Any] = {} | ||
if model_info.architecture == "Qwen2ForSequenceClassification": | ||
vllm_extra_kwargs["hf_overrides"] = { | ||
"architectures": ["Qwen2ForSequenceClassification"], | ||
"classifier_from_token": ["0", "1"], | ||
"method": "from_2_way_softmax", | ||
DarkLight1337 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info, | ||
vllm_extra_kwargs) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.