
[Bugfix] Fix tensor parallel issue in Qwen3 reranker weight loading #20682

Merged: 6 commits, merged Jul 12, 2025

Changes from 1 commit
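For context, the failure this PR targets only appears when the reranker is loaded with tensor parallelism enabled. A minimal sketch of the affected setup (the checkpoint name and hf_overrides follow vLLM's published Qwen3 reranker example; treat them as assumptions rather than part of this PR):

# Sketch: loading a Qwen3 reranker with tensor parallelism enabled.
# Before this fix, the score weights extracted from lm_head were copied
# unsharded, which broke weight loading when tensor_parallel_size > 1.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-Reranker-0.6B",  # illustrative checkpoint
    task="score",
    tensor_parallel_size=2,  # the configuration this PR fixes
    hf_overrides={  # assumed from vLLM's Qwen3 reranker example
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)
scores = llm.score("query text", ["candidate document text"])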
36 changes: 36 additions & 0 deletions vllm/model_executor/models/adapters.py
@@ -352,6 +352,24 @@
    weight = model.lm_head.weight.data[true_id].to(device).to(
        torch.float32) - model.lm_head.weight.data[false_id].to(device).to(
            torch.float32)

    # Handle tensor parallelism: shard the weight vector if needed
    from vllm.distributed import (get_tensor_model_parallel_rank,
                                  get_tensor_model_parallel_world_size)

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()

    if tp_size > 1:
        # The score layer uses RowParallelLinear where the input dimension is sharded

[Check failure: GitHub Actions / pre-commit, Ruff (E501): vllm/model_executor/models/adapters.py:364:81: Line too long (85 > 80)]
        # Score weight shape: (num_labels, hidden_size // tp_size)
        # We need to shard the weight vector along the hidden dimension
        assert weight.shape[0] % tp_size == 0, (
            f"Hidden size {weight.shape[0]} must be divisible by tensor parallel size {tp_size}")

[Check failure: GitHub Actions / pre-commit, Ruff (E501): vllm/model_executor/models/adapters.py:368:81: Line too long (96 > 80)]
Contributor (review comment, severity: medium):

Consider raising a more descriptive error message that includes the actual hidden size and tensor parallel size values for easier debugging:

    assert weight.shape[0] % tp_size == 0, (
        f"Hidden size {weight.shape[0]} must be divisible by tensor parallel size {tp_size}. "
        f"Got hidden_size={weight.shape[0]} and tp_size={tp_size}")

        shard_size = weight.shape[0] // tp_size
        start_idx = tp_rank * shard_size
        weight = weight[start_idx:start_idx + shard_size]

    model.score.weight.data.copy_(weight)

    del model.lm_head
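Why the slice is correct: RowParallelLinear shards the input (hidden) dimension across tensor-parallel ranks and all-reduces the partial products, so each rank needs only the matching slice of the full classification vector. A self-contained sketch of that arithmetic (plain PyTorch, no vLLM; all names illustrative):

# Sketch: row-parallel matmul equivalence, simulating 2 "ranks" on CPU.
import torch

hidden, tp_size = 8, 2
x = torch.randn(hidden)          # hidden states entering the score layer
w = torch.randn(hidden)          # full (unsharded) classification vector

full = x @ w                     # what a single GPU would compute

shard = hidden // tp_size
partials = []
for rank in range(tp_size):
    s = rank * shard
    # each rank holds only its slice of both activation and weight
    partials.append(x[s:s + shard] @ w[s:s + shard])
sharded = sum(partials)          # stands in for the all-reduce across ranks

assert torch.allclose(full, sharded, atol=1e-6)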
@@ -392,6 +410,24 @@

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    score_weight = model.lm_head.weight.data[token_ids].to(device)

    # Handle tensor parallelism: shard the weight matrix if needed
    from vllm.distributed import (get_tensor_model_parallel_rank,
                                  get_tensor_model_parallel_world_size)

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()

    if tp_size > 1:
        # The score layer uses RowParallelLinear where the input dimension is sharded
        # Score weight shape: (num_labels, hidden_size // tp_size)

[Check failure: GitHub Actions / pre-commit, Ruff (E501): vllm/model_executor/models/adapters.py:423:81: Line too long (85 > 80)]
        # We need to shard the weight matrix along the hidden dimension (last dim)
        assert score_weight.shape[-1] % tp_size == 0, (
            f"Hidden size {score_weight.shape[-1]} must be divisible by tensor parallel size {tp_size}")

[Check failure: GitHub Actions / pre-commit, Ruff (E501): vllm/model_executor/models/adapters.py:425:81: Line too long (82 > 80)]
Contributor (review comment, severity: medium):

Consider raising a more descriptive error message that includes the actual hidden size and tensor parallel size values for easier debugging:

    assert score_weight.shape[-1] % tp_size == 0, (
        f"Hidden size {score_weight.shape[-1]} must be divisible by tensor parallel size {tp_size}. "
        f"Got hidden_size={score_weight.shape[-1]} and tp_size={tp_size}")

        shard_size = score_weight.shape[-1] // tp_size

[Check failure: GitHub Actions / pre-commit, Ruff (E501): vllm/model_executor/models/adapters.py:427:81: Line too long (103 > 80)]
        start_idx = tp_rank * shard_size
        score_weight = score_weight[:, start_idx:start_idx + shard_size]

    model.score.weight.data.copy_(score_weight)

    del model.lm_head
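The multi-label path above applies the same idea to a (num_labels, hidden_size) matrix, slicing the last (hidden) dimension per rank. A quick standalone check of that slicing (illustrative shapes, plain PyTorch):

# Sketch: sharding a (num_labels, hidden) score matrix along the hidden dim.
import torch

num_labels, hidden, tp_size = 3, 8, 2
x = torch.randn(hidden)
w = torch.randn(num_labels, hidden)

full = w @ x                                  # (num_labels,) logits

shard = hidden // tp_size
partials = [
    w[:, r * shard:(r + 1) * shard] @ x[r * shard:(r + 1) * shard]
    for r in range(tp_size)
]
# summing the per-rank partial products recovers the unsharded logits
assert torch.allclose(full, sum(partials), atol=1e-6)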