[New Model]: Support Qwen3 Embedding & Reranker #19260

Merged · 28 commits · Jun 11, 2025
40 changes: 37 additions & 3 deletions vllm/model_executor/models/qwen3.py
@@ -38,13 +38,16 @@
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import LastPool
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import (IntermediateTensors, PoolerOutput,
+                           PoolingSequenceGroupOutput)
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
 from .qwen2 import Qwen2MLP as Qwen3MLP
 from .qwen2 import Qwen2Model
 from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
@@ -245,7 +248,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                          decoder_layer_type=Qwen3DecoderLayer)
 
 
-class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
+                       SupportsCrossEncoding):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -288,6 +292,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+        if vllm_config.model_config.task == "score":
+            pooler_config = vllm_config.model_config.pooler_config
+            assert pooler_config is not None
+
+            self._pooler = LastPool(normalize=False, softmax=False)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -311,6 +321,30 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        hidden_states = self._pooler.extract_states(hidden_states,
+                                                    pooling_metadata)
+        logits = self.logits_processor._get_logits(hidden_states=hidden_states,
+                                                   lm_head=self.lm_head,
+                                                   embedding_bias=None)
+
+        token_false_id = 2152
+        token_true_id = 9693
+
+        true_vector = logits[:, token_true_id]
+        false_vector = logits[:, token_false_id]
+        batch_scores = torch.stack([false_vector, true_vector], dim=1)
+
+        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+        scores = batch_scores[:, 1].exp()
+
+        pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
+        return PoolerOutput(outputs=pooled_outputs)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(
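
For context, here is a minimal end-to-end sketch of the reranker path this diff wires up. It is a hedged illustration, not part of the PR: the checkpoint name Qwen/Qwen3-Reranker-0.6B, the LLM.score call, and the output layout are assumptions, and the instruction-style prompt template the reranker expects is omitted for brevity.

# Hedged sketch: exercising the task="score" path added above.
from transformers import AutoTokenizer

from vllm import LLM

# The hard-coded ids in pooler() are the "no" and "yes" tokens of the
# Qwen3 tokenizer; they can be recovered like this:
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B")
print(tok.convert_tokens_to_ids("no"), tok.convert_tokens_to_ids("yes"))
# -> 2152 9693

# With task="score", Qwen3ForCausalLM pools the last token, reads the
# "no"/"yes" logits, and returns P("yes") for each (query, document) pair.
llm = LLM(model="Qwen/Qwen3-Reranker-0.6B", task="score")
outputs = llm.score("What is the capital of France?",
                    ["Paris is the capital of France.",
                     "The Eiffel Tower is made of iron."])
for out in outputs:
    print(out.outputs.score)

Note that the log_softmax followed by exp in pooler() is just a numerically stable softmax over the two logits, so each score is P("yes") in (0, 1) and directly comparable across pairs.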