From 5c57bfa92d4ae59c04c95fd8eac9f34f70a5ff16 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 6 Jun 2025 15:19:17 +0800 Subject: [PATCH 01/25] very dirty fix --- vllm/model_executor/models/qwen3.py | 40 ++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 393ce41a91a0..bf650c7cca9b 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -38,13 +38,16 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.pooler import LastPool from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors +from vllm.sequence import (IntermediateTensors, PoolerOutput, + PoolingSequenceGroupOutput) -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2Model from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix @@ -245,7 +248,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): decoder_layer_type=Qwen3DecoderLayer) -class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, + SupportsCrossEncoding): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -288,6 +292,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + if vllm_config.model_config.task == "score": + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self._pooler = LastPool(normalize=False, softmax=False) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -311,6 +321,30 @@ def compute_logits( sampling_metadata) return logits + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + hidden_states = self._pooler.extract_states(hidden_states, + pooling_metadata) + logits = self.logits_processor._get_logits(hidden_states=hidden_states, + lm_head=self.lm_head, + embedding_bias=None) + + token_false_id = 2152 + token_true_id = 9693 + + true_vector = logits[:, token_true_id] + false_vector = logits[:, token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp() + + pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] + return PoolerOutput(outputs=pooled_outputs) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( From 4dddb99bcc7d0a863f2a2797e91684a55cb8a41d Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 6 Jun 2025 15:59:04 +0800 Subject: [PATCH 02/25] + Qwen3ForSequenceClassification --- vllm/model_executor/models/qwen3.py | 112 ++++++++++++++++++++----- vllm/model_executor/models/registry.py | 1 + 2 files changed, 92 
insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index bf650c7cca9b..b55534f1291f 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -248,8 +248,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): decoder_layer_type=Qwen3DecoderLayer) -class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, - SupportsCrossEncoding): +class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -292,12 +291,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - if vllm_config.model_config.task == "score": - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - - self._pooler = LastPool(normalize=False, softmax=False) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -321,6 +314,81 @@ def compute_logits( sampling_metadata) return logits + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + +class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP, + SupportsCrossEncoding): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.score = RowParallelLinear(config.hidden_size, + getattr(config, "num_labels", 2), + quant_config=quant_config, + input_is_parallel=False, + bias=False, + prefix=maybe_prefix(prefix, "score")) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self._pooler = LastPool(normalize=False, softmax=False) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + def pooler( self, hidden_states: torch.Tensor, @@ -328,18 +396,8 @@ def pooler( ) -> PoolerOutput: hidden_states = self._pooler.extract_states(hidden_states, pooling_metadata) - logits = self.logits_processor._get_logits(hidden_states=hidden_states, - lm_head=self.lm_head, - embedding_bias=None) - - token_false_id = 2152 - token_true_id = 9693 - - true_vector = logits[:, 
token_true_id] - false_vector = logits[:, token_false_id] - batch_scores = torch.stack([false_vector, true_vector], dim=1) - - batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + logits, _ = self.score(hidden_states) + batch_scores = torch.nn.functional.log_softmax(logits, dim=1) scores = batch_scores[:, 1].exp() pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] @@ -352,4 +410,16 @@ def load_weights(self, weights: Iterable[tuple[str, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + + loaded_weights = loader.load_weights(weights) + + TOKEN_FALSE_ID = 2152 + TOKEN_TRUE_ID = 9693 + + self.score.weight.data[0] = self.lm_head.weight.data[TOKEN_FALSE_ID] + self.score.weight.data[1] = self.lm_head.weight.data[TOKEN_TRUE_ID] + + del self.lm_head + + loaded_weights.add("score.weight") + return loaded_weights diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index e82e36638069..d28d2466bb6b 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -172,6 +172,7 @@ "RobertaForSequenceClassification"), "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"), + "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501 } _MULTIMODAL_MODELS = { From 65646a1316c4f28f7967e8dadf40b90292a447fb Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 6 Jun 2025 16:30:35 +0800 Subject: [PATCH 03/25] + classifier_from_token --- vllm/model_executor/models/qwen3.py | 74 +++++++++++++++++------------ 1 file changed, 43 insertions(+), 31 deletions(-) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index b55534f1291f..f76d769db29e 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -343,36 +343,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None self.config = config + self.model_config = vllm_config.model_config self.lora_config = lora_config - self.quant_config = quant_config self.model = Qwen3Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - - if get_pp_group().is_last_rank: - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) - else: - self.lm_head = PPMissingLayer() - - self.score = RowParallelLinear(config.hidden_size, - getattr(config, "num_labels", 2), - quant_config=quant_config, - input_is_parallel=False, - bias=False, - prefix=maybe_prefix(prefix, "score")) - - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - + self.prefix = prefix + self.num_labels = getattr(config, "num_labels", 2) + self.classifier = RowParallelLinear(config.hidden_size, + self.num_labels, + quant_config=quant_config, + input_is_parallel=False, + bias=False, + prefix=maybe_prefix( + prefix, "classifier")) self._pooler = LastPool(normalize=False, softmax=False) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -396,7 +384,7 @@ def pooler( ) -> PoolerOutput: hidden_states = 
self._pooler.extract_states(hidden_states, pooling_metadata) - logits, _ = self.score(hidden_states) + logits, _ = self.classifier(hidden_states) batch_scores = torch.nn.functional.log_softmax(logits, dim=1) scores = batch_scores[:, 1].exp() @@ -405,6 +393,20 @@ def pooler( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + tokens = getattr(self.config, "classifier_from_token", None) + if tokens is not None: + if get_pp_group().is_last_rank: + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix(self.prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] @@ -413,13 +415,23 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_weights = loader.load_weights(weights) - TOKEN_FALSE_ID = 2152 - TOKEN_TRUE_ID = 9693 + if tokens is None: + return loaded_weights - self.score.weight.data[0] = self.lm_head.weight.data[TOKEN_FALSE_ID] - self.score.weight.data[1] = self.lm_head.weight.data[TOKEN_TRUE_ID] + assert len(tokens) == self.num_labels - del self.lm_head + from vllm.transformers_utils.tokenizer import get_tokenizer + tokenizer = get_tokenizer( + self.model_config.tokenizer, + revision=self.model_config.tokenizer_revision, + tokenizer_mode=self.model_config.tokenizer_mode, + trust_remote_code=self.model_config.trust_remote_code) + + for i, token in enumerate(tokens): + token_id = tokenizer(token, add_special_tokens=False).input_ids[0] + print(token_id) + self.classifier.weight.data[i] = self.lm_head.weight.data[token_id] - loaded_weights.add("score.weight") + del self.lm_head + loaded_weights.add("classifier.weight") return loaded_weights From b36f788cca7a1d46cc8def500334911c261e5a93 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 6 Jun 2025 17:20:34 +0800 Subject: [PATCH 04/25] + tests --- .../language/pooling/test_qwen3_reranker.py | 86 +++++++++++++++++++ vllm/model_executor/models/qwen3.py | 3 +- 2 files changed, 87 insertions(+), 2 deletions(-) create mode 100644 tests/models/language/pooling/test_qwen3_reranker.py diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py new file mode 100644 index 000000000000..9829aa398a29 --- /dev/null +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +model_name = "Qwen/Qwen3-Reranker-4B" + +text_1 = "What is the capital of France?" +texts_2 = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", +] + + +def vllm_reranker(model_name): + from vllm import LLM + + model = LLM(model=model_name, + task="score", + hf_overrides={ + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"] + }, + dtype="float32") + + text_1 = "What is the capital of France?" 
+ texts_2 = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", + ] + + outputs = model.score(text_1, texts_2) + + return [output.outputs.score for output in outputs] + + +def hf_reranker(model_name): + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left') + model = AutoModelForCausalLM.from_pretrained(model_name).eval() + + token_false_id = tokenizer.convert_tokens_to_ids("no") + token_true_id = tokenizer.convert_tokens_to_ids("yes") + + max_length = 8192 + + def process_inputs(pairs): + inputs = tokenizer(pairs, + padding=False, + truncation='longest_first', + return_attention_mask=False, + max_length=max_length) + for i, ele in enumerate(inputs['input_ids']): + inputs['input_ids'][i] = ele + inputs = tokenizer.pad(inputs, + padding=True, + return_tensors="pt", + max_length=max_length) + for key in inputs: + inputs[key] = inputs[key].to(model.device) + return inputs + + @torch.no_grad() + def compute_logits(inputs, **kwargs): + batch_scores = model(**inputs).logits[:, -1, :] + true_vector = batch_scores[:, token_true_id] + false_vector = batch_scores[:, token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp().tolist() + return scores + + pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])] + inputs = process_inputs(pairs) + scores = compute_logits(inputs) + + return scores + + +@pytest.mark.parametrize("model_name", [model_name]) +def test_model(model_name): + hf_outputs = hf_reranker(model_name) + vllm_outputs = vllm_reranker(model_name) + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index f76d769db29e..5fdc728ec6ec 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -428,8 +428,7 @@ def load_weights(self, weights: Iterable[tuple[str, trust_remote_code=self.model_config.trust_remote_code) for i, token in enumerate(tokens): - token_id = tokenizer(token, add_special_tokens=False).input_ids[0] - print(token_id) + token_id = tokenizer.convert_tokens_to_ids(token) self.classifier.weight.data[i] = self.lm_head.weight.data[token_id] del self.lm_head From 90660fca5199f353e586cb8c683a450720f6c78b Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 6 Jun 2025 17:41:18 +0800 Subject: [PATCH 05/25] + Embedding tests --- tests/models/language/pooling/test_gte.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index 05bd479f42b9..c6c060519baa 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -45,6 +45,13 @@ EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", architecture="ModernBertModel", enable_test=True), + ########## Qwen3ForCausalLM + EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B", + architecture="Qwen3ForCausalLM", + enable_test=True), + EmbedModelInfo("Qwen/Qwen3-Embedding-4B", + architecture="Qwen3ForCausalLM", + enable_test=True), ] From a44fb4f07d3ae516783487dd11ad9cd8687c6d24 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 6 Jun 2025 18:00:41 +0800 Subject: [PATCH 06/25] Qwen/Qwen3-Embedding-0.6B needs float32 to pass the test --- tests/models/language/pooling/test_gte.py 
| 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py
index c6c060519baa..18f2022537a4 100644
--- a/tests/models/language/pooling/test_gte.py
+++ b/tests/models/language/pooling/test_gte.py
@@ -48,6 +48,7 @@
     ########## Qwen3ForCausalLM
     EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
                    architecture="Qwen3ForCausalLM",
+                   dtype="float32",
                    enable_test=True),
     EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
                    architecture="Qwen3ForCausalLM",

From 4ef23af1a1604af2b5061621ac6674d49c18be3f Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Fri, 6 Jun 2025 18:47:41 +0800
Subject: [PATCH 07/25] + process_inputs

---
 test/__init__.py                    |  0
 test/st.py                          | 80 +++++++++++++++++++++++++++++
 test/vllm.py                        | 53 +++++++++++++++++++
 vllm/entrypoints/llm.py             | 21 +++++---
 vllm/model_executor/models/qwen3.py |  4 +-
 5 files changed, 149 insertions(+), 9 deletions(-)
 create mode 100644 test/__init__.py
 create mode 100644 test/st.py
 create mode 100644 test/vllm.py

diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/test/st.py b/test/st.py
new file mode 100644
index 000000000000..d4ead1d76f4a
--- /dev/null
+++ b/test/st.py
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: Apache-2.0
+# ruff: noqa: E501
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def format_instruction(instruction, query, doc):
+    if instruction is None:
+        instruction = 'Given a web search query, retrieve relevant passages that answer the query'
+    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
+        instruction=instruction, query=query, doc=doc)
+    return output
+
+
+def process_inputs(pairs):
+    inputs = tokenizer(pairs,
+                       padding=False,
+                       truncation='longest_first',
+                       return_attention_mask=False,
+                       max_length=max_length - len(prefix_tokens) -
+                       len(suffix_tokens))
+    for i, ele in enumerate(inputs['input_ids']):
+        inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
+    inputs = tokenizer.pad(inputs,
+                           padding=True,
+                           return_tensors="pt",
+                           max_length=max_length)
+    for key in inputs:
+        inputs[key] = inputs[key].to(model.device)
+    return inputs
+
+
+def compute_logits(inputs, **kwargs):
+    batch_scores = model(**inputs).logits[:, -1, :]
+    true_vector = batch_scores[:, token_true_id]
+    false_vector = batch_scores[:, token_false_id]
+    batch_scores = torch.stack([false_vector, true_vector], dim=1)
+    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+    scores = batch_scores[:, 1].exp().tolist()
+    return scores
+
+
+model_name = "Qwen/Qwen3-Reranker-4B"
+
+tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
+model = AutoModelForCausalLM.from_pretrained(model_name).eval()
+
+token_false_id = tokenizer.convert_tokens_to_ids("no")
+token_true_id = tokenizer.convert_tokens_to_ids("yes")
+max_length = 8192
+
+prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
+suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
+
+if __name__ == '__main__':
+
+    task = 'Given a web search query, retrieve relevant passages that answer the query'
+
+    queries = [
+        "What is the capital of China?",
+        "Explain gravity",
+    ]
+
+    documents = [
+        "The capital of China is Beijing.",
+        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+    ]
+
+    pairs = [
+        format_instruction(task, query, doc)
+        for query, doc in zip(queries, documents)
+    ]
+
+    # Tokenize the input texts
+    inputs = process_inputs(pairs)
+    scores = compute_logits(inputs)
+
+    print("scores: ", scores)
diff --git a/test/vllm.py b/test/vllm.py
new file mode 100644
index 000000000000..25c1989a5574
--- /dev/null
+++ b/test/vllm.py
@@ -0,0 +1,53 @@
+# SPDX-License-Identifier: Apache-2.0
+# ruff: noqa: E501
+
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")
+
+
+def format_instruction(instruction, query, doc):
+    if instruction is None:
+        instruction = 'Given a web search query, retrieve relevant passages that answer the query'
+    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
+        instruction=instruction, query=query, doc=doc)
+    return output
+
+
+max_length = 8192
+
+prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+
+instruction = 'Given a web search query, retrieve relevant passages that answer the query'
+
+
+def process_inputs(query, doc):
+    messages = format_instruction(instruction, query, doc)
+    messages = prefix + messages + suffix
+    return messages
+
+
+if __name__ == '__main__':
+    from vllm import LLM
+
+    model = LLM(model="Qwen/Qwen3-Reranker-4B",
+                task="score",
+                hf_overrides={
+                    "architectures": ["Qwen3ForSequenceClassification"],
+                    "classifier_from_token": ["no", "yes"]
+                })
+
+    queries = [
+        "What is the capital of China?",
+        "Explain gravity",
+    ]
+
+    documents = [
+        "The capital of China is Beijing.",
+        "Gravity is a force that attracts two bodies towards each other.
It gives weight to physical objects and is responsible for the movement of planets around the sun.", + ] + + outputs = model.score(queries, documents, process_inputs=process_inputs) + + print([output.outputs.score for output in outputs]) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fd28bf39e2d5..8bdb9ae54ea2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1130,6 +1130,7 @@ def _cross_encoding_score( use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + process_inputs: Optional[Callable] = None, ) -> list[ScoringRequestOutput]: if isinstance(tokenizer, MistralTokenizer): @@ -1150,12 +1151,17 @@ def _cross_encoding_score( parsed_prompts = [] for q, t in input_pairs: - prompt_inputs = tokenizer(text=q, - text_pair=t, - **tokenization_kwargs) + if process_inputs is not None: + text = process_inputs(q, t) + prompt_inputs = tokenizer(text=text, **tokenization_kwargs) + else: + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) engine_prompt = TokensPrompt( prompt_token_ids=prompt_inputs["input_ids"], token_type_ids=prompt_inputs.get("token_type_ids")) + parsed_prompts.append(engine_prompt) self._validate_and_add_requests( @@ -1182,6 +1188,7 @@ def score( use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + process_inputs: Optional[Callable] = None, ) -> list[ScoringRequestOutput]: """Generate similarity scores for all pairs ``. @@ -1258,11 +1265,9 @@ def ensure_str(prompt: SingletonPrompt): _validate_score_input_lens(input_text_1, input_text_2) if self.llm_engine.model_config.is_cross_encoder: - return self._cross_encoding_score(tokenizer, input_text_1, - input_text_2, - truncate_prompt_tokens, use_tqdm, - lora_request, - prompt_adapter_request) + return self._cross_encoding_score( + tokenizer, input_text_1, input_text_2, truncate_prompt_tokens, + use_tqdm, lora_request, prompt_adapter_request, process_inputs) else: return self._embedding_score( tokenizer, diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 5fdc728ec6ec..1882d23b4c45 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -385,7 +385,9 @@ def pooler( hidden_states = self._pooler.extract_states(hidden_states, pooling_metadata) logits, _ = self.classifier(hidden_states) - batch_scores = torch.nn.functional.log_softmax(logits, dim=1) + batch_scores = torch.nn.functional.log_softmax(logits.to( + torch.float32), + dim=1) scores = batch_scores[:, 1].exp() pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] From 2fe1791a914ca5b4c308e365ef6e85fc71e89e81 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 6 Jun 2025 19:33:41 +0800 Subject: [PATCH 08/25] + 2-way -> 1-way --- test/{vllm.py => score.py} | 0 vllm/model_executor/models/qwen3.py | 31 +++++++++++++++++++++-------- 2 files changed, 23 insertions(+), 8 deletions(-) rename test/{vllm.py => score.py} (100%) diff --git a/test/vllm.py b/test/score.py similarity index 100% rename from test/vllm.py rename to test/score.py diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 1882d23b4c45..5cbb8d20cc7a 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -354,6 +354,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 
prefix=maybe_prefix(prefix, "model"))
         self.prefix = prefix
         self.num_labels = getattr(config, "num_labels", 2)
+        if self.num_labels == 2:
+            self.num_labels = 1
+
         self.classifier = RowParallelLinear(config.hidden_size,
                                             self.num_labels,
                                             quant_config=quant_config,
@@ -388,10 +391,14 @@ def pooler(
         hidden_states = self._pooler.extract_states(hidden_states,
                                                     pooling_metadata)
         logits, _ = self.classifier(hidden_states)
-        batch_scores = torch.nn.functional.log_softmax(logits.to(
-            torch.float32),
-                                                       dim=1)
-        scores = batch_scores[:, 1].exp()
+
+        if self.num_labels == 1:
+            scores = logits.squeeze(-1).to(torch.float32).sigmoid()
+        else:
+            batch_scores = torch.nn.functional.log_softmax(logits.to(
+                torch.float32),
+                                                           dim=1)
+            scores = batch_scores[:, 1].exp()
 
         pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
         return PoolerOutput(outputs=pooled_outputs)
@@ -420,7 +427,8 @@ def load_weights(self, weights: Iterable[tuple[str,
         if tokens is None:
             return loaded_weights
 
-        assert len(tokens) == self.num_labels
+        assert len(tokens) == self.num_labels or (len(tokens) == 2
+                                                  and self.num_labels == 1)
 
         from vllm.transformers_utils.tokenizer import get_tokenizer
         tokenizer = get_tokenizer(
@@ -429,9 +437,16 @@ def load_weights(self, weights: Iterable[tuple[str,
             tokenizer_mode=self.model_config.tokenizer_mode,
             trust_remote_code=self.model_config.trust_remote_code)
 
-        for i, token in enumerate(tokens):
-            token_id = tokenizer.convert_tokens_to_ids(token)
-            self.classifier.weight.data[i] = self.lm_head.weight.data[
-                token_id]
+        if len(tokens) == 2:
+            a = tokenizer.convert_tokens_to_ids(tokens[0])
+            b = tokenizer.convert_tokens_to_ids(tokens[1])
+            self.classifier.weight.data = self.lm_head.weight.data[
+                b] - self.lm_head.weight.data[a]
+        else:
+            for i, token in enumerate(tokens):
+                token_id = tokenizer.convert_tokens_to_ids(token)
+                self.classifier.weight.data[i] = self.lm_head.weight.data[
+                    token_id]
 
         del self.lm_head
         loaded_weights.add("classifier.weight")

From cb4c02322c9586693d51dacb2d372207ea68f5b2 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Fri, 6 Jun 2025 20:09:57 +0800
Subject: [PATCH 09/25] - process_inputs

---
 test/__init__.py        |  0
 test/score.py           | 53 ---------------------------
 test/st.py              | 80 -----------------------------------
 vllm/entrypoints/llm.py | 14 +++-----
 4 files changed, 4 insertions(+), 143 deletions(-)
 delete mode 100644 test/__init__.py
 delete mode 100644 test/score.py
 delete mode 100644 test/st.py

diff --git a/test/__init__.py b/test/__init__.py
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/test/score.py b/test/score.py
deleted file mode 100644
index 25c1989a5574..000000000000
--- a/test/score.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# ruff: noqa: E501
-
-from transformers import AutoTokenizer
-
-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")
-
-
-def format_instruction(instruction, query, doc):
-    if instruction is None:
-        instruction = 'Given a web search query, retrieve relevant passages that answer the query'
-    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
-        instruction=instruction, query=query, doc=doc)
-    return output
-
-
-max_length = 8192
-
-prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
-suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
-
-instruction = 'Given a web search query, retrieve relevant passages that answer the query'
-
-
-def process_inputs(query, doc):
-    messages = format_instruction(instruction, query, doc)
-    messages = prefix + messages + suffix
-    return messages
-
-
-if __name__ == '__main__':
-    from vllm import LLM
-
-    model = LLM(model="Qwen/Qwen3-Reranker-4B",
-                task="score",
-                hf_overrides={
-                    "architectures": ["Qwen3ForSequenceClassification"],
-                    "classifier_from_token": ["no", "yes"]
-                })
-
-    queries = [
-        "What is the capital of China?",
-        "Explain gravity",
-    ]
-
-    documents = [
-        "The capital of China is Beijing.",
-        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
-    ]
-
-    outputs = model.score(queries, documents, process_inputs=process_inputs)
-
-    print([output.outputs.score for output in outputs])
diff --git a/test/st.py b/test/st.py
deleted file mode 100644
index d4ead1d76f4a..000000000000
--- a/test/st.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# ruff: noqa: E501
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-
-def format_instruction(instruction, query, doc):
-    if instruction is None:
-        instruction = 'Given a web search query, retrieve relevant passages that answer the query'
-    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
-        instruction=instruction, query=query, doc=doc)
-    return output
-
-
-def process_inputs(pairs):
-    inputs = tokenizer(pairs,
-                       padding=False,
-                       truncation='longest_first',
-                       return_attention_mask=False,
-                       max_length=max_length - len(prefix_tokens) -
-                       len(suffix_tokens))
-    for i, ele in enumerate(inputs['input_ids']):
-        inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
-    inputs = tokenizer.pad(inputs,
-                           padding=True,
-                           return_tensors="pt",
-                           max_length=max_length)
-    for key in inputs:
-        inputs[key] = inputs[key].to(model.device)
-    return inputs
-
-
-def compute_logits(inputs, **kwargs):
-    batch_scores = model(**inputs).logits[:, -1, :]
-    true_vector = batch_scores[:, token_true_id]
-    false_vector = batch_scores[:, token_false_id]
-    batch_scores = torch.stack([false_vector, true_vector], dim=1)
-    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
-    scores = batch_scores[:, 1].exp().tolist()
-    return scores
-
-
-model_name = "Qwen/Qwen3-Reranker-4B"
-
-tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
-model = AutoModelForCausalLM.from_pretrained(model_name).eval()
-
-token_false_id = tokenizer.convert_tokens_to_ids("no")
-token_true_id = tokenizer.convert_tokens_to_ids("yes")
-max_length = 8192
-
-prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
-suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
-prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
-suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
-
-if __name__ == '__main__':
-
-    task = 'Given a web search query, retrieve relevant passages that answer the query'
-
-    queries = [
-        "What is the capital of China?",
-        "Explain gravity",
-    ]
-
-    documents = [
-        "The capital of China is Beijing.",
-        "Gravity is a force that attracts two bodies towards each other.
It gives weight to physical objects and is responsible for the movement of planets around the sun.", - ] - - pairs = [ - format_instruction(task, query, doc) - for query, doc in zip(queries, documents) - ] - - # Tokenize the input texts - inputs = process_inputs(pairs) - scores = compute_logits(inputs) - - print("scores: ", scores) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8bdb9ae54ea2..af80cb2ef464 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1130,7 +1130,6 @@ def _cross_encoding_score( use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - process_inputs: Optional[Callable] = None, ) -> list[ScoringRequestOutput]: if isinstance(tokenizer, MistralTokenizer): @@ -1151,13 +1150,9 @@ def _cross_encoding_score( parsed_prompts = [] for q, t in input_pairs: - if process_inputs is not None: - text = process_inputs(q, t) - prompt_inputs = tokenizer(text=text, **tokenization_kwargs) - else: - prompt_inputs = tokenizer(text=q, - text_pair=t, - **tokenization_kwargs) + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) engine_prompt = TokensPrompt( prompt_token_ids=prompt_inputs["input_ids"], token_type_ids=prompt_inputs.get("token_type_ids")) @@ -1188,7 +1183,6 @@ def score( use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - process_inputs: Optional[Callable] = None, ) -> list[ScoringRequestOutput]: """Generate similarity scores for all pairs ``. @@ -1267,7 +1261,7 @@ def ensure_str(prompt: SingletonPrompt): if self.llm_engine.model_config.is_cross_encoder: return self._cross_encoding_score( tokenizer, input_text_1, input_text_2, truncate_prompt_tokens, - use_tqdm, lora_request, prompt_adapter_request, process_inputs) + use_tqdm, lora_request, prompt_adapter_request) else: return self._embedding_score( tokenizer, From 2a32e0f1bf4ad4d35b5b5c988009b0e1825b6457 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 6 Jun 2025 20:11:16 +0800 Subject: [PATCH 10/25] - process_inputs --- vllm/entrypoints/llm.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index af80cb2ef464..2bf579e96d35 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1259,9 +1259,11 @@ def ensure_str(prompt: SingletonPrompt): _validate_score_input_lens(input_text_1, input_text_2) if self.llm_engine.model_config.is_cross_encoder: - return self._cross_encoding_score( - tokenizer, input_text_1, input_text_2, truncate_prompt_tokens, - use_tqdm, lora_request, prompt_adapter_request) + return self._cross_encoding_score(tokenizer, input_text_1, + input_text_2, + truncate_prompt_tokens, use_tqdm, + lora_request, + prompt_adapter_request) else: return self._embedding_score( tokenizer, From caa6675b2cafc2e78dbca6cc8cf9f8329b1365f1 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 6 Jun 2025 20:15:53 +0800 Subject: [PATCH 11/25] + registry --- tests/models/registry.py | 2 ++ vllm/entrypoints/llm.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index e6543c197348..edef55f15db4 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -238,6 +238,8 @@ def check_available_online( "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": 
_HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), + "Qwen3ForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-Reranker-0.6B", # noqa: E501 + hf_overrides={"architectures": ["Qwen3ForSequenceClassification"]}), # noqa: E501 "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501 v0_only=True), diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 2bf579e96d35..fd28bf39e2d5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1156,7 +1156,6 @@ def _cross_encoding_score( engine_prompt = TokensPrompt( prompt_token_ids=prompt_inputs["input_ids"], token_type_ids=prompt_inputs.get("token_type_ids")) - parsed_prompts.append(engine_prompt) self._validate_and_add_requests( From fccddef3ef9b5733f728861e3262b90e3261f328 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 9 Jun 2025 15:49:05 +0800 Subject: [PATCH 12/25] refactor --- .../language/pooling/test_qwen3_reranker.py | 6 +- .../pooling/test_qwen3_reranker_seq_cls.py | 73 +++++++++ tests/models/registry.py | 3 +- vllm/model_executor/models/qwen3.py | 148 ++++++++---------- 4 files changed, 140 insertions(+), 90 deletions(-) create mode 100644 tests/models/language/pooling/test_qwen3_reranker_seq_cls.py diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index 9829aa398a29..8f2f73cf4f4f 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -17,9 +17,9 @@ def vllm_reranker(model_name): task="score", hf_overrides={ "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"] - }, - dtype="float32") + "classifier_from_token": ["no", "yes"], + "is_qwen3_rerank": True, + }, dtype="float32") text_1 = "What is the capital of France?" texts_2 = [ diff --git a/tests/models/language/pooling/test_qwen3_reranker_seq_cls.py b/tests/models/language/pooling/test_qwen3_reranker_seq_cls.py new file mode 100644 index 000000000000..ee07f6ff9dca --- /dev/null +++ b/tests/models/language/pooling/test_qwen3_reranker_seq_cls.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +model_name = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" + +text_1 = "What is the capital of France?" 
+texts_2 = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", +] + + +def vllm_reranker(model_name): + from vllm import LLM + + model = LLM(model=model_name, task="score") + outputs = model.score(text_1, texts_2) + + return [output.outputs.score for output in outputs] + + +def hf_reranker(model_name): + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left') + model = AutoModelForCausalLM.from_pretrained(model_name).eval() + + token_false_id = tokenizer.convert_tokens_to_ids("no") + token_true_id = tokenizer.convert_tokens_to_ids("yes") + + max_length = 8192 + + def process_inputs(pairs): + inputs = tokenizer(pairs, + padding=False, + truncation='longest_first', + return_attention_mask=False, + max_length=max_length) + for i, ele in enumerate(inputs['input_ids']): + inputs['input_ids'][i] = ele + inputs = tokenizer.pad(inputs, + padding=True, + return_tensors="pt", + max_length=max_length) + for key in inputs: + inputs[key] = inputs[key].to(model.device) + return inputs + + @torch.no_grad() + def compute_logits(inputs, **kwargs): + batch_scores = model(**inputs).logits[:, -1, :] + true_vector = batch_scores[:, token_true_id] + false_vector = batch_scores[:, token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp().tolist() + return scores + + pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])] + inputs = process_inputs(pairs) + scores = compute_logits(inputs) + + return scores + + +@pytest.mark.parametrize("model_name", [model_name]) +def test_model(model_name): + hf_outputs = hf_reranker(model_name) + vllm_outputs = vllm_reranker(model_name) + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) diff --git a/tests/models/registry.py b/tests/models/registry.py index edef55f15db4..ea1e4a1ad2fb 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -238,8 +238,7 @@ def check_available_online( "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), - "Qwen3ForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-Reranker-0.6B", # noqa: E501 - hf_overrides={"architectures": ["Qwen3ForSequenceClassification"]}), # noqa: E501 + "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501 v0_only=True), diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 5cbb8d20cc7a..1f51b5fc2c8a 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -38,14 +38,13 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import LastPool +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from 
vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (IntermediateTensors, PoolerOutput, - PoolingSequenceGroupOutput) +from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP from .qwen2 import Qwen2MLP as Qwen3MLP @@ -326,48 +325,36 @@ def load_weights(self, weights: Iterable[tuple[str, class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP, SupportsCrossEncoding): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + vllm_config: "VllmConfig", + prefix: str = "", + ) -> None: super().__init__() + config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None + self.vllm_config = vllm_config self.config = config - self.model_config = vllm_config.model_config - self.lora_config = lora_config - self.quant_config = quant_config + self.maybe_is_qwen3_rerank() + self.model = Qwen3Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self.prefix = prefix - self.num_labels = getattr(config, "num_labels", 2) - if self.num_labels == 2: - self.num_labels = 1 - - self.classifier = RowParallelLinear(config.hidden_size, - self.num_labels, - quant_config=quant_config, - input_is_parallel=False, - bias=False, - prefix=maybe_prefix( - prefix, "classifier")) - self._pooler = LastPool(normalize=False, softmax=False) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + self.score = RowParallelLinear(config.hidden_size, + config.num_labels, + quant_config=quant_config, + input_is_parallel=False, + bias=False, + prefix=maybe_prefix(prefix, "score")) + + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=False, + softmax=True) def forward( self, @@ -375,35 +362,43 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states + ) -> torch.Tensor: + return self.model(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) def pooler( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: + ) -> Optional[PoolerOutput]: hidden_states = self._pooler.extract_states(hidden_states, pooling_metadata) - logits, _ = self.classifier(hidden_states) + logits, _ = self.score(hidden_states) + pooled_data = self._pooler.head(logits, pooling_metadata) + pooled_outputs = [ + self._pooler.build_output(data.squeeze(-1)) for data in pooled_data + ] + return PoolerOutput(outputs=pooled_outputs) - if self.num_labels == 1: - scores = logits.squeeze(-1).to(torch.float32).sigmoid() - else: - batch_scores = torch.nn.functional.log_softmax(logits.to( - torch.float32), - dim=1) - scores = batch_scores[:, 1].exp() + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + return 
loader.load_weights(weights) - pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] - return PoolerOutput(outputs=pooled_outputs) + def maybe_is_qwen3_rerank(self): + if not getattr(self.config, "is_qwen3_rerank", False): + return - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: tokens = getattr(self.config, "classifier_from_token", None) - if tokens is not None: + assert tokens is not None and len(tokens) == 2 + + self.config.num_labels = 1 + model_config = self.vllm_config.model_config + + original_load_weights = self.load_weights + + def load_weights(weights: Iterable[tuple[str, torch.Tensor]]): if get_pp_group().is_last_rank: if self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens @@ -416,38 +411,21 @@ def load_weights(self, weights: Iterable[tuple[str, else: self.lm_head = PPMissingLayer() - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), - ) - - loaded_weights = loader.load_weights(weights) - - if tokens is None: - return loaded_weights - - assert len(tokens) == self.num_labels or (len(tokens) == 2 - and self.num_labels == 1) + loaded_weights = original_load_weights(weights) - from vllm.transformers_utils.tokenizer import get_tokenizer - tokenizer = get_tokenizer( - self.model_config.tokenizer, - revision=self.model_config.tokenizer_revision, - tokenizer_mode=self.model_config.tokenizer_mode, - trust_remote_code=self.model_config.trust_remote_code) + from vllm.transformers_utils.tokenizer import get_tokenizer + tokenizer = get_tokenizer( + model_config.tokenizer, + revision=model_config.tokenizer_revision, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code) - if len(tokens) == 2: a = tokenizer.convert_tokens_to_ids(tokens[0]) b = tokenizer.convert_tokens_to_ids(tokens[1]) - self.classifier.weight.data = self.lm_head.weight.data[ - b] - self.lm_head.weight.data[a] - else: - for i, token in enumerate(tokens): - token_id = tokenizer.convert_tokens_to_ids(token) - self.classifier.weight.data[i] = self.lm_head.weight.data[ - token_id] - - del self.lm_head - loaded_weights.add("classifier.weight") - return loaded_weights + weight = self.lm_head.weight.data[b] - self.lm_head.weight.data[a] + self.score.weight.data.copy_(weight) + + del self.lm_head + loaded_weights.add("classifier.weight") + + self.load_weights = load_weights From 72e1a2eb63b2ecffe80de59c39effacc63d0d280 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 9 Jun 2025 16:34:56 +0800 Subject: [PATCH 13/25] + examples --- examples/offline_inference/qwen3_reranker.py | 46 +++++++++++++++++++ tests/models/language/pooling/test_gte.py | 3 ++ .../language/pooling/test_qwen3_reranker.py | 3 +- 3 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 examples/offline_inference/qwen3_reranker.py diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py new file mode 100644 index 000000000000..ae719f80fa92 --- /dev/null +++ b/examples/offline_inference/qwen3_reranker.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +from vllm import LLM + +model_name = "Qwen/Qwen3-Reranker-0.6B" + +model = LLM(model=model_name, + task="score", + hf_overrides={ + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_qwen3_rerank": True, + }, + dtype="float32") + 
+prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+
+instruction = 'Given a web search query, retrieve relevant passages that answer the query'
+
+queries = [
+    "What is the capital of China?",
+    "Explain gravity",
+]
+
+documents = [
+    "The capital of China is Beijing.",
+    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+]
+
+query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
+document_template = "<Document>: {doc}{suffix}"
+
+queries = [
+    query_template.format(prefix=prefix, instruction=instruction, query=query)
+    for query in queries
+]
+documents = [
+    document_template.format(doc=doc, suffix=suffix) for doc in documents
+]
+
+outputs = model.score(queries, documents)
+
+print([output.outputs.score for output in outputs])
diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py
index 18f2022537a4..cc5e728989f6 100644
--- a/tests/models/language/pooling/test_gte.py
+++ b/tests/models/language/pooling/test_gte.py
@@ -53,6 +53,9 @@
     EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
                    architecture="Qwen3ForCausalLM",
                    enable_test=True),
+    EmbedModelInfo("Qwen/Qwen3-Embedding-8B",
+                   architecture="Qwen3ForCausalLM",
+                   enable_test=True),
 ]

diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py
index 8f2f73cf4f4f..87b1c9409b0b 100644
--- a/tests/models/language/pooling/test_qwen3_reranker.py
+++ b/tests/models/language/pooling/test_qwen3_reranker.py
@@ -19,7 +19,8 @@ def vllm_reranker(model_name):
                     "architectures": ["Qwen3ForSequenceClassification"],
                     "classifier_from_token": ["no", "yes"],
                     "is_qwen3_rerank": True,
-                }, dtype="float32")
+                },
+                dtype="float32")
 
     text_1 = "What is the capital of France?"
     texts_2 = [

From 0096e890257491ca0b2b9ee0110342041a680b26 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 9 Jun 2025 16:59:48 +0800
Subject: [PATCH 14/25] fix

---
 examples/offline_inference/qwen3_reranker.py | 14 ++++++++++++--
 vllm/model_executor/models/qwen3.py          |  4 +++-
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py
index ae719f80fa92..783400781c47 100644
--- a/examples/offline_inference/qwen3_reranker.py
+++ b/examples/offline_inference/qwen3_reranker.py
@@ -12,8 +12,18 @@
                 "architectures": ["Qwen3ForSequenceClassification"],
                 "classifier_from_token": ["no", "yes"],
                 "is_qwen3_rerank": True,
-            },
-            dtype="float32")
+            })
+
+# Why do we need hf_overrides:
+# - **Qwen3ForSequenceClassification**, Qwen3 Embedding & Reranker both
+# use the same architecture Qwen3ForCausalLM. We need to manually route
+# Reranker to Qwen3ForSequenceClassification.
+# - **classifier_from_token**, A more efficient approach is to extract
+# token_false_id = 2152 and token_true_id = 9693 into a 2-class
+# classification task rather than the current 151669-class classification task.
+# - **is_qwen3_rerank**, We need to convert the 2-way classifier into a
+# 1-way head classifier. This way, it will be completely consistent with
+# the Qwen3ForSequenceClassification format.
 
 prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
 suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 1f51b5fc2c8a..952f5b4c9a5b 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -391,7 +391,9 @@ def maybe_is_qwen3_rerank(self):
             return
 
         tokens = getattr(self.config, "classifier_from_token", None)
-        assert tokens is not None and len(tokens) == 2
+        assert tokens is not None and len(tokens) == 2, \
+            ("Load Qwen3 Reranker?, see: "
+             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
 
         self.config.num_labels = 1
         model_config = self.vllm_config.model_config

From 0c5d8b05206023e018a5c7430d5078daa4ba4f20 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 9 Jun 2025 17:28:36 +0800
Subject: [PATCH 15/25] is_qwen3_rerank -> is_qwen3_reranker

---
 examples/offline_inference/qwen3_reranker.py         | 4 ++--
 tests/models/language/pooling/test_qwen3_reranker.py | 2 +-
 vllm/model_executor/models/qwen3.py                  | 6 +++---
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py
index 783400781c47..896bbe666ef4 100644
--- a/examples/offline_inference/qwen3_reranker.py
+++ b/examples/offline_inference/qwen3_reranker.py
@@ -11,7 +11,7 @@
             hf_overrides={
                 "architectures": ["Qwen3ForSequenceClassification"],
                 "classifier_from_token": ["no", "yes"],
-                "is_qwen3_rerank": True,
+                "is_qwen3_reranker": True,
             })
 
 # Why do we need hf_overrides:
@@ -21,7 +21,7 @@
 # - **classifier_from_token**, A more efficient approach is to extract
 # token_false_id = 2152 and token_true_id = 9693 into a 2-class
 # classification task rather than the current 151669-class classification task.
-# - **is_qwen3_rerank**, We need to convert the 2-way classifier into a
+# - **is_qwen3_reranker**, We need to convert the 2-way classifier into a
 # 1-way head classifier. This way, it will be completely consistent with
 # the Qwen3ForSequenceClassification format.
diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py
index 87b1c9409b0b..2a669b3a5524 100644
--- a/tests/models/language/pooling/test_qwen3_reranker.py
+++ b/tests/models/language/pooling/test_qwen3_reranker.py
@@ -18,7 +18,7 @@ def vllm_reranker(model_name):
                 hf_overrides={
                     "architectures": ["Qwen3ForSequenceClassification"],
                     "classifier_from_token": ["no", "yes"],
-                    "is_qwen3_rerank": True,
+                    "is_qwen3_reranker": True,
                 },
                 dtype="float32")
 
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 952f5b4c9a5b..04b3747ba2e0 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -339,7 +339,7 @@ def __init__(
         self.vllm_config = vllm_config
         self.config = config
 
-        self.maybe_is_qwen3_rerank()
+        self.maybe_is_qwen3_reranker()
 
         self.model = Qwen3Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
@@ -386,8 +386,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
 
-    def maybe_is_qwen3_rerank(self):
-        if not getattr(self.config, "is_qwen3_rerank", False):
+    def maybe_is_qwen3_reranker(self):
+        if not getattr(self.config, "is_qwen3_reranker", False):
             return

From f0e3f13bb17628b790e76412378c1c88ad592adc Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 9 Jun 2025 17:49:03 +0800
Subject: [PATCH 16/25] update comments

---
 examples/offline_inference/qwen3_reranker.py | 22 +++++++++++---------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py
index 896bbe666ef4..969c5c1266e2 100644
--- a/examples/offline_inference/qwen3_reranker.py
+++ b/examples/offline_inference/qwen3_reranker.py
@@ -12,18 +12,20 @@
                 "architectures": ["Qwen3ForSequenceClassification"],
                 "classifier_from_token": ["no", "yes"],
                 "is_qwen3_reranker": True,
-            })
+            },
+            dtype="float32")
 
 # Why do we need hf_overrides:
-# - **Qwen3ForSequenceClassification**, Qwen3 Embedding & Reranker both
-# use the same architecture Qwen3ForCausalLM. We need to manually route
-# Reranker to Qwen3ForSequenceClassification.
-# - **classifier_from_token**, A more efficient approach is to extract
-# token_false_id = 2152 and token_true_id = 9693 into a 2-class
-# classification task rather than the current 151669-class classification task.
-# - **is_qwen3_reranker**, We need to convert the 2-way classifier into a
-# 1-way head classifier. This way, it will be completely consistent with
-# the Qwen3ForSequenceClassification format.
+# Qwen3-Reranker is a language model that performs reranking using the
+# logits of the "no" and "yes" tokens.
+# vLLM converts it to Qwen3ForSequenceClassification when loaded, for
+# better performance.
+# - First, we need `"architectures": ["Qwen3ForSequenceClassification"],`
+# to manually route the model to Qwen3ForSequenceClassification.
+# - Then, we will extract the vector corresponding to classifier_from_token
+# from lm_head using `"classifier_from_token": ["no", "yes"]`.
+# - Third, we will convert these two vectors into one vector. The use of
+# conversion logic is controlled by using `"is_qwen3_reranker": True`.
 
 prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
 suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"

From 65a7292a8fe7fdc588dc5b054e20a1da5c0fb678 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 9 Jun 2025 18:21:11 +0800
Subject: [PATCH 17/25] ruff format

---
 examples/offline_inference/qwen3_reranker.py | 28 +++++++++++---------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py
index 969c5c1266e2..8d4c0124280a 100644
--- a/examples/offline_inference/qwen3_reranker.py
+++ b/examples/offline_inference/qwen3_reranker.py
@@ -6,14 +6,16 @@
 
 model_name = "Qwen/Qwen3-Reranker-0.6B"
 
-model = LLM(model=model_name,
-            task="score",
-            hf_overrides={
-                "architectures": ["Qwen3ForSequenceClassification"],
-                "classifier_from_token": ["no", "yes"],
-                "is_qwen3_reranker": True,
-            },
-            dtype="float32")
+model = LLM(
+    model=model_name,
+    task="score",
+    hf_overrides={
+        "architectures": ["Qwen3ForSequenceClassification"],
+        "classifier_from_token": ["no", "yes"],
+        "is_qwen3_reranker": True,
+    },
+    dtype="float32",
+)
 
 # Why do we need hf_overrides:
 # Qwen3-Reranker is a language model that performs reranking using the
 # logits of the "no" and "yes" tokens.
 # vllm converts it to Qwen3ForSequenceClassification at load time for
 # better performance.
 # - First, we use `"architectures": ["Qwen3ForSequenceClassification"]`
 # to manually route the model to Qwen3ForSequenceClassification.
 # - Then, we extract the vectors corresponding to classifier_from_token
 # from lm_head, using `"classifier_from_token": ["no", "yes"]`.
 # - Third, we convert these two vectors into one vector; the conversion
 # logic is enabled by `"is_qwen3_reranker": True`.
 
-prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
 suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
 
-instruction = 'Given a web search query, retrieve relevant passages that answer the query'
+instruction = (
+    "Given a web search query, retrieve relevant passages that answer the query"
+)
 
 queries = [
     "What is the capital of China?",
     "Explain gravity",
 ]
 
 documents = [
     "The capital of China is Beijing.",
     "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
 ]
 
 query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
 document_template = "<Document>: {doc}{suffix}"
 
 queries = [
     query_template.format(prefix=prefix, instruction=instruction, query=query)
     for query in queries
 ]
-documents = [
-    document_template.format(doc=doc, suffix=suffix) for doc in documents
-]
+documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]
 
 outputs = model.score(queries, documents)

From 775c147d2f68a9893597584245249ffb209d8e0a Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 9 Jun 2025 18:47:37 +0800
Subject: [PATCH 18/25] is_qwen3_reranker -> is_original_qwen3_reranker

---
 examples/offline_inference/qwen3_reranker.py |  8 +-
 .../language/pooling/test_qwen3_reranker.py  |  2 +-
 vllm/model_executor/models/qwen3.py          | 89 +++++++++++--------
 3 files changed, 56 insertions(+), 43 deletions(-)

diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py
index 8d4c0124280a..e9207ec1ec51 100644
--- a/examples/offline_inference/qwen3_reranker.py
+++ b/examples/offline_inference/qwen3_reranker.py
@@ -12,7 +12,7 @@
     hf_overrides={
         "architectures": ["Qwen3ForSequenceClassification"],
         "classifier_from_token": ["no", "yes"],
-        "is_qwen3_reranker": True,
+        "is_original_qwen3_reranker": True,
     },
     dtype="float32",
 )
 
 # Why do we need hf_overrides:
@@ -27,7 +27,7 @@
 # - Then, we extract the vectors corresponding to classifier_from_token
 # from lm_head, using `"classifier_from_token": ["no", "yes"]`.
 # - Third, we convert these two vectors into one vector; the conversion
-# logic is enabled by `"is_qwen3_reranker": True`.
+# logic is enabled by `"is_original_qwen3_reranker": True`.
 
 prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
 suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"

diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py
index 2a669b3a5524..63b37d9a077d 100644
--- a/tests/models/language/pooling/test_qwen3_reranker.py
+++ b/tests/models/language/pooling/test_qwen3_reranker.py
@@ -18,7 +18,7 @@ def vllm_reranker(model_name):
         hf_overrides={
             "architectures": ["Qwen3ForSequenceClassification"],
             "classifier_from_token": ["no", "yes"],
-            "is_qwen3_reranker": True,
+            "is_original_qwen3_reranker": True,
         },
         dtype="float32")
 
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 04b3747ba2e0..f35e935ef8d0 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -339,8 +339,8 @@ def __init__(
         self.vllm_config = vllm_config
         self.config = config
 
-        self.maybe_is_qwen3_reranker()
-
+        self.quant_config = quant_config
+        self.prefix = prefix
         self.model = Qwen3Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
         self.score = RowParallelLinear(config.hidden_size,
@@ -383,51 +383,62 @@ def pooler(
         return PoolerOutput(outputs=pooled_outputs)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        is_original_qwen3_reranker = getattr(self.config,
+                                             "is_original_qwen3_reranker",
+                                             False)
+
+        if not is_original_qwen3_reranker:
+            loader = AutoWeightsLoader(self)
+            return loader.load_weights(weights)
 
-    def maybe_is_qwen3_reranker(self):
-        if not getattr(self.config, "is_qwen3_reranker", False):
-            return
+        return self.load_weights_from_original_qwen3_reranker(weights)
 
+    def load_weights_from_original_qwen3_reranker(
+        self, weights: Iterable[tuple[str, torch.Tensor]]):
         tokens = getattr(self.config, "classifier_from_token", None)
         assert tokens is not None and len(tokens) == 2, \
-            ("Trying to load the Qwen3 Reranker? See: "
+            ("Trying to load the original Qwen3 Reranker? See: "
             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
 
         self.config.num_labels = 1
         model_config = self.vllm_config.model_config
 
+        device = self.score.weight.device
+        self.score = RowParallelLinear(self.config.hidden_size,
+                                       self.config.num_labels,
+                                       quant_config=self.quant_config,
+                                       input_is_parallel=False,
+                                       bias=False,
+                                       prefix=maybe_prefix(
+                                           self.prefix, "score")).to(device)
 
+        if get_pp_group().is_last_rank:
+            if self.config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(self.config.vocab_size,
+                                              self.config.hidden_size,
+                                              quant_config=self.quant_config,
+                                              prefix=maybe_prefix(
+                                                  self.prefix, "lm_head"))
+        else:
+            self.lm_head = PPMissingLayer()
 
+        loader = AutoWeightsLoader(self)
+        loaded_weights = loader.load_weights(weights)
+
+        from vllm.transformers_utils.tokenizer import get_tokenizer
+        tokenizer = get_tokenizer(
+            model_config.tokenizer,
+            revision=model_config.tokenizer_revision,
+            tokenizer_mode=model_config.tokenizer_mode,
+            trust_remote_code=model_config.trust_remote_code)
+
+        a = tokenizer.convert_tokens_to_ids(tokens[0])
+        b = tokenizer.convert_tokens_to_ids(tokens[1])
+        weight = self.lm_head.weight.data[b].to(
+            device) - self.lm_head.weight.data[a].to(device)
+        self.score.weight.data.copy_(weight)
+
+        del self.lm_head
+        loaded_weights.add("classifier.weight")

From 2e88ff2dcafaba526820ec752596b6d0328dff49 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 9 Jun 2025 18:57:33 +0800
Subject: [PATCH 19/25] fix

---
 examples/offline_inference/qwen3_reranker.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py
index e9207ec1ec51..b1cd1b9874e3 100644
--- a/examples/offline_inference/qwen3_reranker.py
+++ b/examples/offline_inference/qwen3_reranker.py
@@ -53,9 +53,7 @@
     query_template.format(prefix=prefix, instruction=instruction, query=query)
     for query in queries
 ]
-documents = [
-    document_template.format(doc=doc, suffix=suffix) for doc in documents
-]
+documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]
 
 outputs = model.score(queries, documents)

From 628fa8ccb1faf0fea1691475d6673fb9b5e94195 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Tue, 10 Jun 2025 12:47:27 +0800
Subject: [PATCH 20/25] improve comments

---
 examples/offline_inference/qwen3_reranker.py | 65 ++++++++++++--------
 1 file changed, 41 insertions(+), 24 deletions(-)

diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py
index b1cd1b9874e3..27c4071bf094 100644
--- a/examples/offline_inference/qwen3_reranker.py
+++ b/examples/offline_inference/qwen3_reranker.py
@@ -6,6 +6,22 @@
 
 model_name = "Qwen/Qwen3-Reranker-0.6B"
 
+# What is the difference between the official original version and one
+# that has been converted into a sequence classification model?
+# Qwen3-Reranker is a language model that performs reranking using the
+# logits of the "no" and "yes" tokens.
+# It needs to compute the logits of all 151669 vocabulary tokens, which
+# makes this method extremely inefficient, not to mention incompatible
+# with the vllm score API.
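+#
+# (A rough sense of the cost difference, as an editor's sketch with assumed
+# names no_id / yes_id: for each pooled hidden state h the original path does
+#     logits = h @ lm_head.weight.T           # one logit per vocabulary token
+#     score = softmax(logits[[no_id, yes_id]])[-1]
+# over the full 151669-token vocabulary, while the converted model keeps
+# only the single vector lm_head[yes_id] - lm_head[no_id] and computes one
+# dot product per document.)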
+# A method for converting the original model into a sequence classification
+# model was proposed. See: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
+# Models converted offline using this method are not only more efficient
+# and compatible with the vllm score API, but also have more concise init
+# parameters. For example:
+# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")
+
+# If you want to load the official original version, the init parameters are
+# as follows.
+
 model = LLM(
     model=model_name,
     task="score",
@@ -14,12 +30,9 @@
     "classifier_from_token": ["no", "yes"],
     "is_original_qwen3_reranker": True,
     },
-    dtype="float32",
 )
 
-# Why do we need hf_overrides:
-# Qwen3-Reranker is a language model that performs reranking using the
-# logits of the "no" and "yes" tokens.
+# Why do we need hf_overrides for the official original version:
 # vllm converts it to Qwen3ForSequenceClassification at load time for
 # better performance.
 # - First, we use `"architectures": ["Qwen3ForSequenceClassification"]`
 # to manually route the model to Qwen3ForSequenceClassification.
 # - Then, we extract the vectors corresponding to classifier_from_token
 # from lm_head, using `"classifier_from_token": ["no", "yes"]`.
 # - Third, we convert these two vectors into one vector; the conversion
 # logic is enabled by `"is_original_qwen3_reranker": True`.
 
+# Please use the query_template and document_template to format the query and
+# document for better reranking results.
+
 prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
 suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
 
+query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
+document_template = "<Document>: {doc}{suffix}"
+
+if __name__ == "__main__":
+    instruction = (
+        "Given a web search query, retrieve relevant passages that answer the query"
+    )
+
+    queries = [
+        "What is the capital of China?",
+        "Explain gravity",
+    ]
+
+    documents = [
+        "The capital of China is Beijing.",
+        "Gravity is a force that attracts two bodies towards each other. 
It gives weight to physical objects and is responsible for the movement of planets around the sun.", + ] -queries = [ - query_template.format(prefix=prefix, instruction=instruction, query=query) - for query in queries -] -documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents] + queries = [ + query_template.format(prefix=prefix, instruction=instruction, query=query) + for query in queries + ] + documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents] -outputs = model.score(queries, documents) + outputs = model.score(queries, documents) -print([output.outputs.score for output in outputs]) + print([output.outputs.score for output in outputs]) From d4be7cd325db2dd374ebe5121395865fe9b7242a Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 10 Jun 2025 14:44:13 +0800 Subject: [PATCH 21/25] - SupportsPP --- vllm/model_executor/models/qwen3.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index f35e935ef8d0..bad0f6b1ffb7 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -323,7 +323,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights) -class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP, +class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, SupportsCrossEncoding): def __init__( @@ -412,17 +412,14 @@ def load_weights_from_original_qwen3_reranker( prefix=maybe_prefix( self.prefix, "score")).to(device) - if get_pp_group().is_last_rank: - if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size, - quant_config=self.quant_config, - prefix=maybe_prefix( - self.prefix, "lm_head")) + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens else: - self.lm_head = PPMissingLayer() + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix( + self.prefix, "lm_head")) loader = AutoWeightsLoader(self) loaded_weights = loader.load_weights(weights) @@ -442,3 +439,4 @@ def load_weights_from_original_qwen3_reranker( del self.lm_head loaded_weights.add("classifier.weight") + loaded_weights.discard("lm_head.weight") From fb7f05cc839bf98c60ee585f911e9cd9b63dfb94 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 10 Jun 2025 15:14:19 +0800 Subject: [PATCH 22/25] update docs --- docs/models/supported_models.md | 42 ++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index a8a6f3417e54..9cf04ab1ca28 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -387,18 +387,19 @@ See [this page](./pooling_models.md) for more information on how to use pooling Specified using `--task embed`. -| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | -|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------| -| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. 
| | | -| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | -| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | ︎ | | -| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | ︎ | ︎ | -| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | ︎ | ︎ | -| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | ︎ | ︎ | -| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | -| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | -| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|----------------------|---------------------------| +| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | +| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | +| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | ︎ | | +| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | ︎ | ︎ | +| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | ︎ | ︎ | +| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | ︎ | ︎ | +| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | +| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | +| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | +| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | !!! note `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. @@ -450,12 +451,19 @@ If your model is not in the above list, we will try to automatically convert the Specified using `--task score`. -| Architecture | Models | Example HF Models | -|---------------------------------------|-------------------|----------------------------------------------| -| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | -| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | -| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | +| Architecture | Models | Example HF Models | +|---------------------------------------|-------------------|--------------------------------------------------------------------------------------| +| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. 
|
+| `Qwen3ForSequenceClassification` | Qwen3-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. |
+| `RobertaForSequenceClassification` | RoBERTa-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. |
+| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. |
 
+!!! note
+    Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py.
+
+    ```bash
+    vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
+    ```
 
 [](){ #supported-mm-models }
 
 ## List of Multimodal Language Models

From 6203985e7181b33c26449a88302f735a59f1aede Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Tue, 10 Jun 2025 15:16:10 +0800
Subject: [PATCH 23/25] fix

---
 docs/models/supported_models.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 9cf04ab1ca28..a54669b2f5f3 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -454,8 +454,8 @@ Specified using `--task score`.
 | Architecture | Models | Example HF Models |
 |---------------------------------------|-------------------|--------------------------------------------------------------------------------------|
 | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. |
-| `Qwen3ForSequenceClassification` | Qwen3-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. |
-| `RobertaForSequenceClassification` | RoBERTa-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. |
+| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. |
+| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. |
 | `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. |

From 2285fab4f7883559b8935df9483d3f2c0b324380 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Tue, 10 Jun 2025 15:30:21 +0800
Subject: [PATCH 24/25] fix

---
 docs/models/supported_models.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index a54669b2f5f3..a66199a53976 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -459,8 +459,8 @@ Specified using `--task score`.
 | `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. |
 
 !!! note
-    Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py.
-
+    Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: <https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py>.
+ ```bash vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' ``` From 12f13f5b8f832050d30120ea7ca3a48a7d561bd5 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 10 Jun 2025 18:43:17 +0800 Subject: [PATCH 25/25] fix --- tests/models/language/pooling/test_gte.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index cc5e728989f6..6a3a0f150b6d 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -52,10 +52,8 @@ enable_test=True), EmbedModelInfo("Qwen/Qwen3-Embedding-4B", architecture="Qwen3ForCausalLM", - enable_test=True), - EmbedModelInfo("Qwen/Qwen3-Embedding-8B", - architecture="Qwen3ForCausalLM", - enable_test=True), + dtype="float32", + enable_test=False), ]
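For completeness, here is a minimal end-to-end sketch of scoring with the offline-converted checkpoint that the example comments mention (illustrative only: it assumes the `tomaarsen/Qwen3-Reranker-0.6B-seq-cls` checkpoint referenced in patch 20, and reuses the offline score API shown throughout this series):

```python
# A minimal usage sketch, not part of the patch series. No hf_overrides are
# needed because this checkpoint was already converted offline into the
# Qwen3ForSequenceClassification format.
from vllm import LLM

model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")

queries = ["What is the capital of China?"]
documents = ["The capital of China is Beijing."]

# Each output carries one relevance score per (query, document) pair.
outputs = model.score(queries, documents)
print([output.outputs.score for output in outputs])
```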