Skip to content

Commit adf3d36

Browse files
committed
+ converting2seq_cls_models.py
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 4ff22ba commit adf3d36

File tree

5 files changed

+110
-32
lines changed

5 files changed

+110
-32
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
4+
5+
import argparse
6+
import json
7+
8+
import torch
9+
import transformers
10+
11+
12+
def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer,
                       classifier_from_tokens, device):
    """Initialize a 1-label classification head from a causal LM's lm_head.

    For 2-way-softmax rerankers (e.g. Qwen3-Reranker) the relevance score is
    softmax([logit_neg, logit_pos])[pos]; the equivalent single-logit linear
    head is ``lm_head[pos] - lm_head[neg]``, which this copies into
    ``seq_cls_model.score`` in place.

    Args:
        causal_lm: source ``*ForCausalLM`` model providing ``lm_head.weight``.
        seq_cls_model: target ``*ForSequenceClassification`` model whose
            ``score`` linear layer (num_labels == 1) is overwritten.
        tokenizer: tokenizer used to resolve the classifier tokens to ids.
        classifier_from_tokens: exactly two tokens, ``[negative, positive]``
            (e.g. ``["no", "yes"]``).
        device: device on which the new weight is computed.

    Raises:
        ValueError: if ``classifier_from_tokens`` does not hold two tokens.
    """
    # Real exception instead of ``assert``: asserts are stripped under -O.
    if len(classifier_from_tokens) != 2:
        raise ValueError(
            "from_2_way_softmax expects exactly two classifier tokens, "
            f"got {classifier_from_tokens!r}")

    lm_head_weights = causal_lm.lm_head.weight

    negative_id = tokenizer.convert_tokens_to_ids(classifier_from_tokens[0])
    positive_id = tokenizer.convert_tokens_to_ids(classifier_from_tokens[1])

    # Subtract in float32 on both sides. The original code cast only the
    # positive row (and cast it to float32 twice); casting both keeps the
    # subtraction symmetric and in full precision for bf16/fp16 checkpoints.
    score_weight = (lm_head_weights[positive_id].to(device).to(torch.float32) -
                    lm_head_weights[negative_id].to(device).to(torch.float32))

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()
29+
30+
31+
method_map = {function.__name__: function for function in [from_2_way_softmax]}
32+
33+
34+
def converting(model_name, classifier_from_tokens, path, method, device="cpu"):
    """Convert a ``*ForCausalLM`` checkpoint into a
    ``*ForSequenceClassification`` checkpoint and save it to ``path``.

    Args:
        model_name: HF hub id or local path of the source causal LM.
        classifier_from_tokens: tokens whose lm_head rows define the
            classification head (interpretation depends on ``method``).
        path: output directory for the converted model and tokenizer.
        method: key into ``method_map`` selecting the conversion strategy.
        device: device used while converting (default ``"cpu"``).

    Raises:
        ValueError: if ``method`` is not a known conversion method.
    """
    # Validate with a real exception: ``assert`` is stripped under -O.
    if method not in method_map:
        raise ValueError(f"Unknown conversion method {method!r}; "
                         f"available: {sorted(method_map)}")

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device)

    # num_labels=1: rerankers emit a single relevance logit.
    # ignore_mismatched_sizes: the fresh score head has no counterpart in the
    # causal-LM checkpoint, so size mismatches on it are expected.
    seq_cls_model = (
        transformers.AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=1,
            ignore_mismatched_sizes=True,
            device_map=device))

    method_map[method](causal_lm, seq_cls_model, tokenizer,
                       classifier_from_tokens, device)

    # Sequence-classification inference needs a pad token id for batching.
    seq_cls_model.config.pad_token_id = tokenizer.pad_token_id

    seq_cls_model.save_pretrained(path)
    tokenizer.save_pretrained(path)
54+
55+
56+
def parse_args(argv=None):
    """Build and parse the CLI arguments for the conversion script.

    Args:
        argv: optional argument list (defaults to ``sys.argv[1:]``); exposed
            so tests can drive the parser without touching ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Converting *ForCausalLM models to "
        "*ForSequenceClassification models.")
    parser.add_argument("--model_name",
                        type=str,
                        default="Qwen/Qwen3-Reranker-0.6B",
                        help="Model name")
    parser.add_argument("--classifier_from_tokens",
                        type=str,
                        default='["no", "yes"]',
                        help="JSON list of classifier tokens, "
                        "[negative, positive]")
    parser.add_argument("--method",
                        type=str,
                        default='from_2_way_softmax',
                        help="Conversion method "
                        "(e.g. from_2_way_softmax)")
    parser.add_argument("--path",
                        type=str,
                        default="./converted_model",
                        help="Path to save converted model")
    return parser.parse_args(argv)
77+
78+
79+
if __name__ == "__main__":
    # CLI entry point: parse arguments, decode the token list from JSON,
    # then run the conversion.
    cli_args = parse_args()
    converting(
        model_name=cli_args.model_name,
        classifier_from_tokens=json.loads(cli_args.classifier_from_tokens),
        method=cli_args.method,
        path=cli_args.path,
    )

examples/offline_inference/qwen3_reranker.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
# concise, for example.
2020
# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")
2121

22+
# To convert the official original model to a sequence-classification model
23+
# offline, please refer to: converting2seq_cls_models.py
24+
# The init parameters are as follows.
25+
# model = LLM(model="path_to/converted_model", task="score")
26+
2227
# If you want to load the official original version, the init parameters are
2328
# as follows.
2429

tests/models/language/pooling/test_gte.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
import pytest
66

7+
from ...utils import RerankModelInfo
78
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
8-
from .mteb_utils import mteb_test_embed_models
9+
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
910

1011
MODELS = [
1112
########## BertModel
@@ -56,6 +57,12 @@
5657
enable_test=False),
5758
]
5859

60+
RERANK_MODELS = [
61+
RerankModelInfo("Alibaba-NLP/gte-reranker-modernbert-base",
62+
architecture="ModernBertForSequenceClassification",
63+
enable_test=False),
64+
]
65+
5966

6067
@pytest.mark.parametrize("model_info", MODELS)
6168
def test_embed_models_mteb(hf_runner, vllm_runner,
@@ -80,3 +87,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
8087

8188
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
8289
example_prompts, vllm_extra_kwargs)
90+
91+
92+
@pytest.mark.parametrize("model_info", RERANK_MODELS)
93+
def test_rerank_models_mteb(hf_runner, vllm_runner,
94+
model_info: RerankModelInfo) -> None:
95+
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)

vllm/model_executor/models/qwen3.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@
4141
from vllm.model_executor.layers.quantization import QuantizationConfig
4242
from vllm.model_executor.layers.rotary_embedding import get_rope
4343
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
44-
from vllm.model_executor.pooling_metadata import PoolingMetadata
4544
from vllm.model_executor.sampling_metadata import SamplingMetadata
46-
from vllm.sequence import IntermediateTensors, PoolerOutput
45+
from vllm.sequence import IntermediateTensors
4746

4847
from .adapters import as_seq_cls_model
4948
from .interfaces import SupportsLoRA, SupportsPP
@@ -349,32 +348,6 @@ def config_verify(self, vllm_config: "VllmConfig"):
349348
config.num_labels = 1
350349
self.vllm_config = vllm_config
351350

352-
def forward(
353-
self,
354-
input_ids: torch.Tensor,
355-
positions: torch.Tensor,
356-
intermediate_tensors: Optional[IntermediateTensors] = None,
357-
inputs_embeds: Optional[torch.Tensor] = None,
358-
) -> torch.Tensor:
359-
return self.model(input_ids=input_ids,
360-
positions=positions,
361-
inputs_embeds=inputs_embeds,
362-
intermediate_tensors=intermediate_tensors)
363-
364-
def pooler(
365-
self,
366-
hidden_states: torch.Tensor,
367-
pooling_metadata: PoolingMetadata,
368-
) -> Optional[PoolerOutput]:
369-
hidden_states = self._pooler.extract_states(hidden_states,
370-
pooling_metadata)
371-
logits, _ = self.score(hidden_states)
372-
pooled_data = self._pooler.head(logits, pooling_metadata)
373-
pooled_outputs = [
374-
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
375-
]
376-
return PoolerOutput(outputs=pooled_outputs)
377-
378351
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
379352
is_original_qwen3_reranker = getattr(self.config,
380353
"is_original_qwen3_reranker",
@@ -419,5 +392,6 @@ def load_weights_from_original_qwen3_reranker(
419392
self.score.weight.data.copy_(weight)
420393

421394
del self.lm_head
422-
loaded_weights.add("classifier.weight")
395+
loaded_weights.add("score.weight")
423396
loaded_weights.discard("lm_head.weight")
397+
return loaded_weights

vllm/model_executor/models/registry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,6 @@
157157
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
158158
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
159159
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
160-
# [Auto-converted (see adapters.py)]
161-
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
162160
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
163161
# input and output. I am adding it here because it piggy-backs on embedding
164162
# models for the time being.
@@ -173,7 +171,10 @@
173171
"RobertaForSequenceClassification"),
174172
"ModernBertForSequenceClassification": ("modernbert",
175173
"ModernBertForSequenceClassification"),
174+
# [Auto-converted (see adapters.py)]
175+
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"), # noqa: E501
176176
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
177+
"GemmaForSequenceClassification": ("gemma", "GemmaForCausalLM"),
177178
}
178179

179180
_MULTIMODAL_MODELS = {

0 commit comments

Comments
 (0)