
Support Qwen3 Embedding & Reranker on Gaudi for aice/v1.21.0 branch #1456

Open
wants to merge 19 commits into base: aice/v1.21.0
Commits (19)
2e555e5
Optimized Qwen3 and Qwen3-MoE on Gaudi, supporting BF16 and inc based…
gyou2021 May 27, 2025
63716b7
Optimized Qwen3 and Qwen3-MoE on Gaudi, supporting BF16 and INC based…
gyou2021 May 27, 2025
95d8fd2
merged upstream aice/v1.21.0
gyou2021 Jun 11, 2025
9b22301
added files for INC based Qwen3-235B-A22B FP8.
gyou2021 Jun 11, 2025
8d193e1
Merge remote-tracking branch 'upstream/aice/v1.21.0' into gyou/aice/v…
gyou2021 Jun 13, 2025
834822d
Merge remote-tracking branch 'upstream/aice/v1.21.0' into gyou/aice/v…
gyou2021 Jun 13, 2025
07575ea
Supported EP of INC FP8 quant of Qwen3-235B-A22B.
gyou2021 Jun 13, 2025
522b46e
Merge remote-tracking branch 'upstream/aice/v1.21.0' into gyou/aice/v…
gyou2021 Jun 16, 2025
5f02426
Optimized merge_multimodal_embeddings on Gaudi
gyou2021 Jun 16, 2025
54b9ca3
Merge remote-tracking branch 'upstream/aice/v1.21.0' into gyou/aice/v…
gyou2021 Jun 19, 2025
abd2582
[New Model]: Support Qwen3 Embedding & Reranker (#19260)
noooop Jun 11, 2025
b98ec97
Removed assert to support hpu lazy mode of roberta models
gyou2021 Jun 20, 2025
c4b7aa2
Fixed the bug and supported qwen3 reranker on Gaudi.
gyou2021 Jun 20, 2025
610672b
Fixed the bug in the qwen3 reranker score.
gyou2021 Jun 21, 2025
626c905
Transform input data to fp32 for PoolerHead.
gyou2021 Jun 23, 2025
ab64e2a
Merge remote-tracking branch 'upstream/aice/v1.21.0' into gyou/aice/v…
gyou2021 Jun 23, 2025
138df41
rebased utils.py
gyou2021 Jun 23, 2025
a9fe6a7
rebased utils.py
gyou2021 Jun 23, 2025
dfa56f6
rebased utils.py
gyou2021 Jun 23, 2025
700 changes: 700 additions & 0 deletions docs/models/supported_models.md

Large diffs are not rendered by default.

77 changes: 77 additions & 0 deletions examples/offline_inference/qwen3_reranker.py
@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

from vllm import LLM

model_name = "Qwen/Qwen3-Reranker-0.6B"

# What is the difference between the official original version and one
# that has been converted into a sequence classification model?
# Qwen3-Reranker is a language model that performs reranking by comparing
# the logits of the "no" and "yes" tokens.
# It needs to compute logits over all 151,669 vocabulary tokens, which makes
# this method extremely inefficient, not to mention incompatible with the
# vLLM score API.
# A method for converting the original model into a sequence classification
# model was proposed. See:
# https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
# Models converted offline with this method are not only more efficient and
# compatible with the vLLM score API, but also have more concise init
# parameters, for example:
# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")

# If you want to load the official original version, the init parameters are
# as follows.

model = LLM(
    model=model_name,
    task="score",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)

# Why do we need hf_overrides for the official original version:
# vLLM converts it to Qwen3ForSequenceClassification when loaded for
# better performance.
# - First, `"architectures": ["Qwen3ForSequenceClassification"]` manually
#   routes the model to Qwen3ForSequenceClassification.
# - Then, `"classifier_from_token": ["no", "yes"]` extracts the vectors
#   corresponding to the "no" and "yes" tokens from lm_head.
# - Third, these two vectors are combined into a single classifier vector.
#   This conversion logic is enabled by `"is_original_qwen3_reranker": True`.
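# A rough sketch of what this conversion amounts to (illustrative only; the
# names below are assumptions, not code executed by this example):
#   no_id = tokenizer.convert_tokens_to_ids("no")
#   yes_id = tokenizer.convert_tokens_to_ids("yes")
#   # A single-logit classifier with weight w = w_yes - w_no gives the same
#   # score as softmax over the two token logits, since
#   # sigmoid(yes_logit - no_logit) == softmax([no_logit, yes_logit])[1].
#   w = lm_head.weight[yes_id] - lm_head.weight[no_id]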

# Please use the query_template and document_template to format the query and
# document for better reranker results.

prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
document_template = "<Document>: {doc}{suffix}"

if __name__ == "__main__":
    instruction = (
        "Given a web search query, retrieve relevant passages that answer the query"
    )

    queries = [
        "What is the capital of China?",
        "Explain gravity",
    ]

    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]

    queries = [
        query_template.format(prefix=prefix, instruction=instruction, query=query)
        for query in queries
    ]
    documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]

    outputs = model.score(queries, documents)

    print([output.outputs.score for output in outputs])
82 changes: 82 additions & 0 deletions tests/models/language/pooling/test_gte.py
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import pytest

from .embed_utils import EmbedModelInfo, correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models

MODELS = [
    ########## BertModel
    EmbedModelInfo("thenlper/gte-large",
                   architecture="BertModel",
                   enable_test=True),
    EmbedModelInfo("thenlper/gte-base",
                   architecture="BertModel",
                   enable_test=False),
    EmbedModelInfo("thenlper/gte-small",
                   architecture="BertModel",
                   enable_test=False),
    EmbedModelInfo("thenlper/gte-large-zh",
                   architecture="BertModel",
                   enable_test=False),
    EmbedModelInfo("thenlper/gte-base-zh",
                   architecture="BertModel",
                   enable_test=False),
    EmbedModelInfo("thenlper/gte-small-zh",
                   architecture="BertModel",
                   enable_test=False),
    ########## NewModel
    EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
                   architecture="GteNewModel",
                   enable_test=True),
    EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
                   architecture="GteNewModel",
                   enable_test=True),
    EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
                   architecture="GteNewModel",
                   enable_test=True),
    ########## Qwen2ForCausalLM
    EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
                   architecture="Qwen2ForCausalLM",
                   enable_test=True),
    ########## ModernBertModel
    EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
                   architecture="ModernBertModel",
                   enable_test=True),
    ########## Qwen3ForCausalLM
    EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
                   architecture="Qwen3ForCausalLM",
                   dtype="float32",
                   enable_test=True),
    EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
                   architecture="Qwen3ForCausalLM",
                   dtype="float32",
                   enable_test=False),
]


@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
                           model_info: EmbedModelInfo) -> None:

    vllm_extra_kwargs: dict[str, Any] = {}
    if model_info.architecture == "GteNewModel":
        vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}

    mteb_test_embed_models(hf_runner, vllm_runner, model_info,
                           vllm_extra_kwargs)


@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_correctness(hf_runner, vllm_runner,
                                  model_info: EmbedModelInfo,
                                  example_prompts) -> None:

    vllm_extra_kwargs: dict[str, Any] = {}
    if model_info.architecture == "GteNewModel":
        vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}

    correctness_test_embed_models(hf_runner, vllm_runner, model_info,
                                  example_prompts, vllm_extra_kwargs)
87 changes: 87 additions & 0 deletions tests/models/language/pooling/test_qwen3_reranker.py
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

model_name = "Qwen/Qwen3-Reranker-4B"

text_1 = "What is the capital of France?"
texts_2 = [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris.",
]


def vllm_reranker(model_name):
    from vllm import LLM

    model = LLM(model=model_name,
                task="score",
                hf_overrides={
                    "architectures": ["Qwen3ForSequenceClassification"],
                    "classifier_from_token": ["no", "yes"],
                    "is_original_qwen3_reranker": True,
                },
                dtype="float32")

    text_1 = "What is the capital of France?"
    texts_2 = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
    ]

    outputs = model.score(text_1, texts_2)

    return [output.outputs.score for output in outputs]


def hf_reranker(model_name):
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()

    token_false_id = tokenizer.convert_tokens_to_ids("no")
    token_true_id = tokenizer.convert_tokens_to_ids("yes")

    max_length = 8192

    def process_inputs(pairs):
        inputs = tokenizer(pairs,
                           padding=False,
                           truncation='longest_first',
                           return_attention_mask=False,
                           max_length=max_length)
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = ele
        inputs = tokenizer.pad(inputs,
                               padding=True,
                               return_tensors="pt",
                               max_length=max_length)
        for key in inputs:
            inputs[key] = inputs[key].to(model.device)
        return inputs

    @torch.no_grad()
    def compute_logits(inputs, **kwargs):
        batch_scores = model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, token_true_id]
        false_vector = batch_scores[:, token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
    inputs = process_inputs(pairs)
    scores = compute_logits(inputs)

    return scores


@pytest.mark.parametrize("model_name", [model_name])
def test_model(model_name):
    hf_outputs = hf_reranker(model_name)
    vllm_outputs = vllm_reranker(model_name)

    assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
    assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
73 changes: 73 additions & 0 deletions tests/models/language/pooling/test_qwen3_reranker_seq_cls.py
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

model_name = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"

text_1 = "What is the capital of France?"
texts_2 = [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris.",
]


def vllm_reranker(model_name):
    from vllm import LLM

    model = LLM(model=model_name, task="score")
    outputs = model.score(text_1, texts_2)

    return [output.outputs.score for output in outputs]


def hf_reranker(model_name):
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()

    token_false_id = tokenizer.convert_tokens_to_ids("no")
    token_true_id = tokenizer.convert_tokens_to_ids("yes")

    max_length = 8192

    def process_inputs(pairs):
        inputs = tokenizer(pairs,
                           padding=False,
                           truncation='longest_first',
                           return_attention_mask=False,
                           max_length=max_length)
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = ele
        inputs = tokenizer.pad(inputs,
                               padding=True,
                               return_tensors="pt",
                               max_length=max_length)
        for key in inputs:
            inputs[key] = inputs[key].to(model.device)
        return inputs

    @torch.no_grad()
    def compute_logits(inputs, **kwargs):
        batch_scores = model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, token_true_id]
        false_vector = batch_scores[:, token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
    inputs = process_inputs(pairs)
    scores = compute_logits(inputs)

    return scores


@pytest.mark.parametrize("model_name", [model_name])
def test_model(model_name):
    hf_outputs = hf_reranker(model_name)
    vllm_outputs = vllm_reranker(model_name)

    assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
    assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
4 changes: 2 additions & 2 deletions tests/models/registry.py
@@ -225,8 +225,8 @@ def check_available_online(
     "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
     "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
-    "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b",
-                                     is_available_online=False),
+    "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"),  # noqa: E501
+    "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
     "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b",  # noqa: E501
                                                 is_available_online=False),
     "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
18 changes: 16 additions & 2 deletions vllm/model_executor/layers/pooler.py
@@ -238,6 +238,13 @@ def __init__(self, *, normalize: bool, softmax: bool) -> None:
     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):

+        # Using float32 in PoolerHead
+        if isinstance(pooled_data, list):
+            for i in range(len(pooled_data)):
+                pooled_data[i] = pooled_data[i].to(torch.float32)
+        else:
+            pooled_data = pooled_data.to(torch.float32)
+
         dimensions_list = [
             pooling_param.dimensions
             for _, pooling_param in pooling_metadata.seq_groups
@@ -261,9 +268,16 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],

         if self.softmax:
             if isinstance(pooled_data, list):
-                pooled_data = [F.softmax(data, dim=-1) for data in pooled_data]
+                pooled_data = [
+                    F.softmax(data, dim=-1)
+                    if data.shape[-1] >= 2 else F.sigmoid(data)
+                    for data in pooled_data
+                ]
             else:
-                pooled_data = F.softmax(pooled_data, dim=-1)
+                if pooled_data.shape[-1] >= 2:
+                    pooled_data = F.softmax(pooled_data, dim=-1)
+                else:
+                    pooled_data = F.sigmoid(pooled_data)

         return pooled_data

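A note on the sigmoid fallback in PoolerHead above: when the converted reranker emits a single logit per (query, document) pair, applying sigmoid to that logit is mathematically equivalent to taking the softmax over the original two-logit ("no", "yes") formulation. A minimal sketch of the equivalence, using illustrative values rather than code from this PR:

import torch
import torch.nn.functional as F

# Hypothetical last-token logits for the "no" and "yes" tokens of one pair.
no_logit = torch.tensor(0.3)
yes_logit = torch.tensor(2.1)

# Two-logit path: probability assigned to "yes" via softmax.
p_softmax = F.softmax(torch.stack([no_logit, yes_logit]), dim=0)[1]

# Single-logit path: the converted classifier emits (yes_logit - no_logit),
# and sigmoid recovers the same probability.
p_sigmoid = torch.sigmoid(yes_logit - no_logit)

assert torch.allclose(p_softmax, p_sigmoid)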