From e9d8f732d3a2f9fa85dcf41f9e0a06bb4ed07dfc Mon Sep 17 00:00:00 2001 From: yurhett <46419702+yurhett@users.noreply.github.com> Date: Wed, 9 Jul 2025 20:11:05 +0800 Subject: [PATCH 1/4] fix: Correctly calculate tensor parallel dimension --- vllm/model_executor/models/adapters.py | 36 ++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 6584c84436c..71d335dbb4b 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -352,6 +352,24 @@ def load_weights_using_from_2_way_softmax( weight = model.lm_head.weight.data[true_id].to(device).to( torch.float32) - model.lm_head.weight.data[false_id].to(device).to( torch.float32) + + # Handle tensor parallel: shard the weight vector if needed + from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + if tp_size > 1: + # The score layer uses RowParallelLinear where the input dimension is sharded + # Score weight shape: (num_labels, hidden_size // tp_size) + # We need to shard the weight vector along the hidden dimension + assert weight.shape[0] % tp_size == 0, ( + f"Hidden size {weight.shape[0]} must be divisible by tensor parallel size {tp_size}") + shard_size = weight.shape[0] // tp_size + start_idx = tp_rank * shard_size + weight = weight[start_idx:start_idx + shard_size] + model.score.weight.data.copy_(weight) del model.lm_head @@ -392,6 +410,24 @@ def load_weights_no_post_processing(model, token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] score_weight = model.lm_head.weight.data[token_ids].to(device) + + # Handle tensor parallelism: shard the weight matrix if needed + from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + if tp_size > 1: + # The score layer uses RowParallelLinear where the input dimension is sharded + # Score weight shape: (num_labels, hidden_size // tp_size) + # We need to shard the weight matrix along the hidden dimension (last dim) + assert score_weight.shape[-1] % tp_size == 0, ( + f"Hidden size {score_weight.shape[-1]} must be divisible by tensor parallel size {tp_size}") + shard_size = score_weight.shape[-1] // tp_size + start_idx = tp_rank * shard_size + score_weight = score_weight[:, start_idx:start_idx + shard_size] + model.score.weight.data.copy_(score_weight) del model.lm_head From b0116ab792757d65ca75cb7b52c61080207cd8c6 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 10 Jul 2025 17:03:19 +0800 Subject: [PATCH 2/4] use weight loader and add tp test Signed-off-by: Isotr0py <2037008807@qq.com> --- .../language/pooling/test_qwen3_reranker.py | 24 +++++++++ vllm/model_executor/models/adapters.py | 49 +++---------------- 2 files changed, 32 insertions(+), 41 deletions(-) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index 9f040639c78..c38223f42cd 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -6,6 +6,7 @@ import torch from tests.conftest import HfRunner +from tests.utils import multi_gpu_test from .mteb_utils import RerankModelInfo, mteb_test_rerank_models @@ -87,3 +88,26 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, vllm_extra_kwargs) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +@multi_gpu_test(num_gpus=2) +def test_rerank_models_mteb_tp(vllm_runner, + model_info: RerankModelInfo) -> None: + + assert model_info.architecture == "Qwen3ForSequenceClassification" + + vllm_extra_kwargs: dict[str, Any] = { + "hf_overrides": { + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, + }, + "tensor_parallel_size": 2, + } + + if model_info.name == "Qwen/Qwen3-Reranker-4B": + vllm_extra_kwargs["max_num_seqs"] = 1 + + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, + vllm_extra_kwargs) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 71d335dbb4b..dcdf69f773a 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -322,6 +322,8 @@ def load_weights_using_from_2_way_softmax( # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3 from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead) + from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader) from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config @@ -329,8 +331,6 @@ def load_weights_using_from_2_way_softmax( tokens = cast(list[int], tokens) assert len(tokens) == 2 - device = model.score.weight.device - if model.config.tie_word_embeddings: model.lm_head = model.model.embed_tokens else: @@ -349,28 +349,13 @@ def load_weights_using_from_2_way_softmax( false_id = tokenizer.convert_tokens_to_ids(tokens[0]) true_id = tokenizer.convert_tokens_to_ids(tokens[1]) - weight = model.lm_head.weight.data[true_id].to(device).to( - torch.float32) - model.lm_head.weight.data[false_id].to(device).to( + weight = model.lm_head.weight.data[[true_id]].to( + torch.float32) - model.lm_head.weight.data[[false_id]].to( torch.float32) - - # Handle tensor parallel: shard the weight vector if needed - from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - if tp_size > 1: - # The score layer uses RowParallelLinear where the input dimension is sharded - # Score weight shape: (num_labels, hidden_size // tp_size) - # We need to shard the weight vector along the hidden dimension - assert weight.shape[0] % tp_size == 0, ( - f"Hidden size {weight.shape[0]} must be divisible by tensor parallel size {tp_size}") - shard_size = weight.shape[0] // tp_size - start_idx = tp_rank * shard_size - weight = weight[start_idx:start_idx + shard_size] - - model.score.weight.data.copy_(weight) + + param = model.score.weight + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, weight) del model.lm_head loaded_weights.add("score.weight") @@ -410,24 +395,6 @@ def load_weights_no_post_processing(model, token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] score_weight = model.lm_head.weight.data[token_ids].to(device) - - # Handle tensor parallelism: shard the weight matrix if needed - from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - if tp_size > 1: - # The score layer uses RowParallelLinear where the input dimension is sharded - # Score weight shape: (num_labels, hidden_size // tp_size) - # We need to shard the weight matrix along the hidden dimension (last dim) - assert score_weight.shape[-1] % tp_size == 0, ( - f"Hidden size {score_weight.shape[-1]} must be divisible by tensor parallel size {tp_size}") - shard_size = score_weight.shape[-1] // tp_size - start_idx = tp_rank * shard_size - score_weight = score_weight[:, start_idx:start_idx + shard_size] - model.score.weight.data.copy_(score_weight) del model.lm_head From 90c34b6839d78a2a568c1fb8cecc2b609ab6954b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 11 Jul 2025 11:21:02 +0800 Subject: [PATCH 3/4] mteb atol Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/models/language/pooling/mteb_utils.py | 5 +++-- tests/models/language/pooling/test_qwen3_reranker.py | 7 +++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 847ea5f623f..6c4fde5fdfa 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -268,7 +268,8 @@ def mteb_test_rerank_models(hf_runner, model_info: RerankModelInfo, vllm_extra_kwargs=None, hf_model_callback=None, - vllm_mteb_encoder=VllmMtebEncoder): + vllm_mteb_encoder=VllmMtebEncoder, + atol=MTEB_RERANK_TOL): if not model_info.enable_test: # A model family has many models with the same architecture, # and we don't need to test each one. @@ -301,4 +302,4 @@ def mteb_test_rerank_models(hf_runner, print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) + assert st_main_score == pytest.approx(vllm_main_score, abs=atol) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index c38223f42cd..e3a0f35b83c 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -109,5 +109,8 @@ def test_rerank_models_mteb_tp(vllm_runner, if model_info.name == "Qwen/Qwen3-Reranker-4B": vllm_extra_kwargs["max_num_seqs"] = 1 - mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(Qwen3RerankerHfRunner, + vllm_runner, + model_info, + vllm_extra_kwargs, + atol=1e-2) From b260e69ef30c9e6bf3a9b594b9725633b3aa51e9 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 11 Jul 2025 22:20:56 +0800 Subject: [PATCH 4/4] increase atol Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/models/language/pooling/test_qwen3_reranker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index e3a0f35b83c..9c6a833b413 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -113,4 +113,4 @@ def test_rerank_models_mteb_tp(vllm_runner, vllm_runner, model_info, vllm_extra_kwargs, - atol=1e-2) + atol=1.2e-2)