Commit 11c0198

yurhett and Isotr0py authored
[Bugfix] Fix tensor parallel issue in Qwen3 reranker weight loading (#20682)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
1 parent b1235c3 commit 11c0198

3 files changed: +38, -7 lines changed

tests/models/language/pooling/mteb_utils.py

Lines changed: 3 additions & 2 deletions
@@ -268,7 +268,8 @@ def mteb_test_rerank_models(hf_runner,
                             model_info: RerankModelInfo,
                             vllm_extra_kwargs=None,
                             hf_model_callback=None,
-                            vllm_mteb_encoder=VllmMtebEncoder):
+                            vllm_mteb_encoder=VllmMtebEncoder,
+                            atol=MTEB_RERANK_TOL):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
@@ -301,4 +302,4 @@ def mteb_test_rerank_models(hf_runner,
     print("SentenceTransformers:", st_dtype, st_main_score)
     print("Difference:", st_main_score - vllm_main_score)
 
-    assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL)
+    assert st_main_score == pytest.approx(vllm_main_score, abs=atol)
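
The helper's final score comparison now takes its tolerance from the new atol parameter (defaulting to the existing MTEB_RERANK_TOL), so individual tests can relax it when extra numerical drift is expected. A minimal usage sketch; the kwargs dict here is illustrative only, not a real configuration:

# Illustrative only: pass a looser absolute tolerance than the default.
mteb_test_rerank_models(Qwen3RerankerHfRunner,
                        vllm_runner,
                        model_info,
                        vllm_extra_kwargs={"tensor_parallel_size": 2},
                        atol=1.2e-2)

The new tensor parallel test below uses exactly this hook.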

tests/models/language/pooling/test_qwen3_reranker.py

Lines changed: 27 additions & 0 deletions
@@ -6,6 +6,7 @@
 import torch
 
 from tests.conftest import HfRunner
+from tests.utils import multi_gpu_test
 
 from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
 
@@ -87,3 +88,29 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
 
     mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
                             vllm_extra_kwargs)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+@multi_gpu_test(num_gpus=2)
+def test_rerank_models_mteb_tp(vllm_runner,
+                               model_info: RerankModelInfo) -> None:
+
+    assert model_info.architecture == "Qwen3ForSequenceClassification"
+
+    vllm_extra_kwargs: dict[str, Any] = {
+        "hf_overrides": {
+            "architectures": ["Qwen3ForSequenceClassification"],
+            "classifier_from_token": ["no", "yes"],
+            "is_original_qwen3_reranker": True,
+        },
+        "tensor_parallel_size": 2,
+    }
+
+    if model_info.name == "Qwen/Qwen3-Reranker-4B":
+        vllm_extra_kwargs["max_num_seqs"] = 1
+
+    mteb_test_rerank_models(Qwen3RerankerHfRunner,
+                            vllm_runner,
+                            model_info,
+                            vllm_extra_kwargs,
+                            atol=1.2e-2)
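
The hf_overrides above ask vLLM to expose the original reranker, a causal LM scored through its "no"/"yes" token logits, as a single-logit Qwen3ForSequenceClassification model. The conversion in adapters.py below relies on a 2-way softmax over those two logits reducing to a sigmoid of their difference; a small self-contained sketch of that identity, with random tensors and illustrative names only:

import torch

hidden = torch.randn(16)                         # final hidden state
w_no, w_yes = torch.randn(16), torch.randn(16)   # lm_head rows for "no" / "yes"

# Relevance score from the original 2-way softmax over the two token logits...
p_two_way = torch.softmax(torch.stack([w_no @ hidden, w_yes @ hidden]), dim=-1)[1]
# ...equals a sigmoid over one logit built from the row difference, which is
# the single row the adapter loads into model.score.
p_single = torch.sigmoid((w_yes - w_no) @ hidden)

assert torch.allclose(p_two_way, p_single, atol=1e-6)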

vllm/model_executor/models/adapters.py

Lines changed: 8 additions & 5 deletions
@@ -322,15 +322,15 @@ def load_weights_using_from_2_way_softmax(
     # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
     from vllm.model_executor.layers.vocab_parallel_embedding import (
         ParallelLMHead)
+    from vllm.model_executor.model_loader.weight_utils import (
+        default_weight_loader)
     from vllm.model_executor.models.utils import AutoWeightsLoader
 
     model_config = model.vllm_config.model_config
     tokens = getattr(model.config, "classifier_from_token", [])
     tokens = cast(list[int], tokens)
     assert len(tokens) == 2
 
-    device = model.score.weight.device
-
     if model.config.tie_word_embeddings:
         model.lm_head = model.model.embed_tokens
     else:
@@ -349,10 +349,13 @@ def load_weights_using_from_2_way_softmax(
 
     false_id = tokenizer.convert_tokens_to_ids(tokens[0])
     true_id = tokenizer.convert_tokens_to_ids(tokens[1])
-    weight = model.lm_head.weight.data[true_id].to(device).to(
-        torch.float32) - model.lm_head.weight.data[false_id].to(device).to(
+    weight = model.lm_head.weight.data[[true_id]].to(
+        torch.float32) - model.lm_head.weight.data[[false_id]].to(
         torch.float32)
-    model.score.weight.data.copy_(weight)
+
+    param = model.score.weight
+    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+    weight_loader(param, weight)
 
     del model.lm_head
     loaded_weights.add("score.weight")
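
This is the core of the fix: instead of copying the derived row straight into model.score.weight, whose rank-local shape may be sharded under tensor parallelism, the adapter keeps the row 2-D (the [[true_id]] indexing) and hands the full-size row to the weight_loader attached to the parameter, falling back to default_weight_loader when none is set. A self-contained sketch of that pattern, not vLLM's real loader; the sharding function, rank, and sizes below are illustrative assumptions:

import torch
from functools import partial

def default_weight_loader(param, loaded_weight):
    # Unsharded fallback: shapes already match on every rank.
    param.data.copy_(loaded_weight)

def row_parallel_weight_loader(param, loaded_weight, tp_rank, tp_size):
    # Each rank keeps only its slice of the full [num_labels, hidden] row;
    # copying the whole row directly (the old code path) cannot match the
    # rank-local shape once tp_size > 1.
    shard = loaded_weight.shape[-1] // tp_size
    param.data.copy_(loaded_weight.narrow(-1, tp_rank * shard, shard))

# Pretend setup: hidden size 8 split across 2 ranks; this rank holds 4 columns.
tp_rank, tp_size, hidden = 0, 2, 8
score_weight = torch.nn.Parameter(torch.empty(1, hidden // tp_size))
score_weight.weight_loader = partial(row_parallel_weight_loader,
                                     tp_rank=tp_rank, tp_size=tp_size)

# Stand-in for the "yes" row minus the "no" row of lm_head.
full_row = torch.randn(1, hidden, dtype=torch.float32)
loader = getattr(score_weight, "weight_loader", default_weight_loader)
loader(score_weight, full_row)   # copies only this rank's shard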
