-
-
Notifications
You must be signed in to change notification settings - Fork 8.8k
[Bugfix] Fix tensor parallel issue in Qwen3 reranker weight loading #20682
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
e9d8f73
fix: Correctly calculate tensor parallel dimension
yurhett b0116ab
use weight loader and add tp test
Isotr0py 90c34b6
mteb atol
Isotr0py 18039a5
Merge remote-tracking branch 'upstream/main'
Isotr0py 7823e59
Merge remote-tracking branch 'upstream/main'
Isotr0py b260e69
increase atol
Isotr0py File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -352,6 +352,24 @@ | |
weight = model.lm_head.weight.data[true_id].to(device).to( | ||
torch.float32) - model.lm_head.weight.data[false_id].to(device).to( | ||
torch.float32) | ||
|
||
# Handle tensor parallel: shard the weight vector if needed | ||
from vllm.distributed import (get_tensor_model_parallel_rank, | ||
get_tensor_model_parallel_world_size) | ||
|
||
tp_rank = get_tensor_model_parallel_rank() | ||
tp_size = get_tensor_model_parallel_world_size() | ||
|
||
if tp_size > 1: | ||
# The score layer uses RowParallelLinear where the input dimension is sharded | ||
# Score weight shape: (num_labels, hidden_size // tp_size) | ||
# We need to shard the weight vector along the hidden dimension | ||
assert weight.shape[0] % tp_size == 0, ( | ||
f"Hidden size {weight.shape[0]} must be divisible by tensor parallel size {tp_size}") | ||
shard_size = weight.shape[0] // tp_size | ||
start_idx = tp_rank * shard_size | ||
weight = weight[start_idx:start_idx + shard_size] | ||
|
||
model.score.weight.data.copy_(weight) | ||
|
||
del model.lm_head | ||
|
@@ -392,6 +410,24 @@ | |
|
||
token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] | ||
score_weight = model.lm_head.weight.data[token_ids].to(device) | ||
|
||
# Handle tensor parallelism: shard the weight matrix if needed | ||
from vllm.distributed import (get_tensor_model_parallel_rank, | ||
get_tensor_model_parallel_world_size) | ||
|
||
tp_rank = get_tensor_model_parallel_rank() | ||
tp_size = get_tensor_model_parallel_world_size() | ||
|
||
if tp_size > 1: | ||
Isotr0py marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# The score layer uses RowParallelLinear where the input dimension is sharded | ||
# Score weight shape: (num_labels, hidden_size // tp_size) | ||
# We need to shard the weight matrix along the hidden dimension (last dim) | ||
assert score_weight.shape[-1] % tp_size == 0, ( | ||
f"Hidden size {score_weight.shape[-1]} must be divisible by tensor parallel size {tp_size}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider raising a more descriptive error message that includes the actual hidden size and tensor parallel size values for easier debugging. assert score_weight.shape[-1] % tp_size == 0, (
f"Hidden size {score_weight.shape[-1]} must be divisible by tensor parallel size {tp_size}"
f"Got hidden_size={score_weight.shape[-1]} and tp_size={tp_size}") |
||
shard_size = score_weight.shape[-1] // tp_size | ||
start_idx = tp_rank * shard_size | ||
score_weight = score_weight[:, start_idx:start_idx + shard_size] | ||
|
||
model.score.weight.data.copy_(score_weight) | ||
|
||
del model.lm_head | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider raising a more descriptive error message that includes the actual hidden size and tensor parallel size values for easier debugging.