Skip to content

[New Model]: Support Qwen3 Embedding & Reranker #19260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tests/models/language/pooling/test_gte.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
architecture="ModernBertModel",
enable_test=True),
########## Qwen3ForCausalLM
EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
architecture="Qwen3ForCausalLM",
dtype="float32",
enable_test=True),
EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
architecture="Qwen3ForCausalLM",
enable_test=True),
]


Expand Down
86 changes: 86 additions & 0 deletions tests/models/language/pooling/test_qwen3_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

# Model under test; also the sole parametrization of test_model below.
model_name = "Qwen/Qwen3-Reranker-4B"

# One query and two candidate documents; only the second document actually
# answers the query, so it should receive the higher relevance score.
text_1 = "What is the capital of France?"
texts_2 = [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris.",
]


def vllm_reranker(model_name):
    """Score the (query, document) pairs with vLLM's scoring task.

    The reranker checkpoint is published as a ``Qwen3ForCausalLM``; the
    ``hf_overrides`` remap it to ``Qwen3ForSequenceClassification`` and tell
    vLLM which two vocabulary tokens ("no"/"yes") form the binary classifier
    head.

    Returns:
        A list of float relevance scores, one per document in ``texts_2``,
        in the same order.
    """
    from vllm import LLM

    model = LLM(model=model_name,
                task="score",
                hf_overrides={
                    "architectures": ["Qwen3ForSequenceClassification"],
                    "classifier_from_token": ["no", "yes"]
                },
                dtype="float32")

    # Use the module-level query/documents instead of redefining them locally
    # (the original shadowed text_1/texts_2 with identical copies).
    outputs = model.score(text_1, texts_2)

    return [output.outputs.score for output in outputs]


def hf_reranker(model_name):
    """Reference implementation: score the pairs directly with HF transformers.

    Mirrors the official Qwen3-Reranker usage: tokenize each (query, document)
    pair, run the causal LM, and read the logits of the "no"/"yes" tokens at
    the final position.  The relevance score is P("yes") under a softmax over
    those two logits.

    Returns:
        A list of float relevance scores, one per document in ``texts_2``.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Left padding so the last position of every row is a real token.
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()

    token_false_id = tokenizer.convert_tokens_to_ids("no")
    token_true_id = tokenizer.convert_tokens_to_ids("yes")

    max_length = 8192

    def process_inputs(pairs):
        # Tokenize without padding first, then left-pad the batch to a common
        # length and move the tensors to the model's device.  (The original
        # also had a loop re-assigning each input_ids element to itself — a
        # no-op, removed.)
        inputs = tokenizer(pairs,
                           padding=False,
                           truncation='longest_first',
                           return_attention_mask=False,
                           max_length=max_length)
        inputs = tokenizer.pad(inputs,
                               padding=True,
                               return_tensors="pt",
                               max_length=max_length)
        for key in inputs:
            inputs[key] = inputs[key].to(model.device)
        return inputs

    @torch.no_grad()
    def compute_logits(inputs):
        # Logits at the last position; stack [no, yes] and take P("yes").
        batch_scores = model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, token_true_id]
        false_vector = batch_scores[:, token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
    inputs = process_inputs(pairs)
    scores = compute_logits(inputs)

    return scores


@pytest.mark.parametrize("model_name", [model_name])
def test_model(model_name):
    """vLLM reranker scores must agree with the HF reference within 1%."""
    hf_outputs = hf_reranker(model_name)
    vllm_outputs = vllm_reranker(model_name)

    # Compare score-by-score rather than with two literal asserts.
    for hf_score, vllm_score in zip(hf_outputs, vllm_outputs):
        assert hf_score == pytest.approx(vllm_score, rel=0.01)
2 changes: 2 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ def check_available_online(
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3ForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-Reranker-0.6B", # noqa: E501
hf_overrides={"architectures": ["Qwen3ForSequenceClassification"]}), # noqa: E501
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
v0_only=True),
Expand Down
136 changes: 134 additions & 2 deletions vllm/model_executor/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,16 @@
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import LastPool
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput)

from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
Expand Down Expand Up @@ -319,3 +322,132 @@ def load_weights(self, weights: Iterable[tuple[str,
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)


class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP,
                                     SupportsCrossEncoding):
    """Qwen3 decoder with a sequence-classification head (cross-encoder).

    Wraps the base ``Qwen3Model`` and applies a linear ``classifier`` to the
    last-token hidden state of each sequence.  Supports reranker checkpoints
    that ship no classifier weights: when ``config.classifier_from_token`` is
    set, the classifier row(s) are derived from the lm_head rows of those
    tokens in :meth:`load_weights`.
    """

    # Map vLLM's fused projection names to the separate HF weight names so
    # the weight loader can pack them.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        # This model only makes sense as a pooling model, so a pooler config
        # must have been set up by the engine.
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.config = config
        self.model_config = vllm_config.model_config
        self.lora_config = lora_config
        self.quant_config = quant_config
        self.model = Qwen3Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.prefix = prefix
        self.num_labels = getattr(config, "num_labels", 2)
        # Binary classification is collapsed to a single logit scored with a
        # sigmoid in pooler() below.
        if self.num_labels == 2:
            self.num_labels = 1

        # input_is_parallel=False: the pooled hidden state is replicated, not
        # sharded, when it reaches the classifier.
        self.classifier = RowParallelLinear(config.hidden_size,
                                            self.num_labels,
                                            quant_config=quant_config,
                                            input_is_parallel=False,
                                            bias=False,
                                            prefix=maybe_prefix(
                                                prefix, "classifier"))
        # Raw last-token pooling; normalization/softmax are done explicitly
        # in pooler(), so both are disabled here.
        self._pooler = LastPool(normalize=False, softmax=False)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding lookup to the wrapped base model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder; classification happens later in pooler()."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool last-token states and turn classifier logits into scores.

        Single-label (collapsed binary) case: sigmoid of the lone logit.
        Multi-label case: softmax over logits, returning the probability of
        class 1 — matching the HF reranker's "no"/"yes" convention.
        """
        hidden_states = self._pooler.extract_states(hidden_states,
                                                    pooling_metadata)
        logits, _ = self.classifier(hidden_states)

        if self.num_labels == 1:
            # Compute in float32 for numerical stability of the sigmoid.
            scores = logits.squeeze(-1).to(torch.float32).sigmoid()
        else:
            batch_scores = torch.nn.functional.log_softmax(logits.to(
                torch.float32),
                                                           dim=1)
            scores = batch_scores[:, 1].exp()

        pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
        return PoolerOutput(outputs=pooled_outputs)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, optionally building the classifier head.

        If ``config.classifier_from_token`` is set (reranker checkpoints with
        no classifier weights), a temporary ``lm_head`` is materialized so the
        checkpoint's LM weights can be loaded, then the classifier weight is
        derived from the lm_head rows of the listed tokens and the lm_head is
        discarded.
        """
        tokens = getattr(self.config, "classifier_from_token", None)
        if tokens is not None:
            # Only the last PP rank holds the head; earlier ranks get a
            # placeholder so the loader can still resolve the attribute.
            if get_pp_group().is_last_rank:
                if self.config.tie_word_embeddings:
                    self.lm_head = self.model.embed_tokens
                else:
                    self.lm_head = ParallelLMHead(
                        self.config.vocab_size,
                        self.config.hidden_size,
                        quant_config=self.quant_config,
                        prefix=maybe_prefix(self.prefix, "lm_head"))
            else:
                self.lm_head = PPMissingLayer()

        loader = AutoWeightsLoader(
            self,
            # With tied embeddings the lm_head weights duplicate
            # embed_tokens, so skip them in the checkpoint.
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )

        loaded_weights = loader.load_weights(weights)

        if tokens is None:
            return loaded_weights

        # Either one token per label, or exactly two tokens for the collapsed
        # binary (num_labels == 1) case.
        assert len(tokens) == self.num_labels or (len(tokens) == 2
                                                  and self.num_labels == 1)

        from vllm.transformers_utils.tokenizer import get_tokenizer
        tokenizer = get_tokenizer(
            self.model_config.tokenizer,
            revision=self.model_config.tokenizer_revision,
            tokenizer_mode=self.model_config.tokenizer_mode,
            trust_remote_code=self.model_config.trust_remote_code)

        if len(tokens) == 2:
            # Collapsed binary head: a single logit equal to
            # logit("yes") - logit("no"), so sigmoid of it reproduces the
            # two-way softmax probability of "yes".
            a = tokenizer.convert_tokens_to_ids(tokens[0])
            b = tokenizer.convert_tokens_to_ids(tokens[1])
            self.classifier.weight.data = self.lm_head.weight.data[
                b] - self.lm_head.weight.data[a]
        else:
            # One classifier row per label, copied from the lm_head row of
            # the corresponding token.
            for i, token in enumerate(tokens):
                token_id = tokenizer.convert_tokens_to_ids(token)
                self.classifier.weight.data[i] = self.lm_head.weight.data[
                    token_id]

        # The lm_head was only needed to seed the classifier; free it.
        del self.lm_head
        loaded_weights.add("classifier.weight")
        return loaded_weights
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@
"RobertaForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
}

_MULTIMODAL_MODELS = {
Expand Down