
Commit c9ba810

[Bugfix] weight loading use correct tp_group with patch_tensor_parallel_group (#21024)
Signed-off-by: KevinXiong-C <kevin_xiong1997@outlook.com>
1 parent 4e7dfbe commit c9ba810

File tree

1 file changed: +25 -28 lines changed

vllm/model_executor/layers/linear.py

Lines changed: 25 additions & 28 deletions
@@ -452,8 +452,10 @@ def __init__(
         else:
             self.register_parameter("bias", None)
 
+        self.tp_rank = get_tensor_model_parallel_rank()
+
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
+
         output_dim = getattr(param, "output_dim", None)
 
         is_sharded_weight = getattr(param, "is_sharded_weight", False)
@@ -472,15 +474,15 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             final_shape = list(loaded_weight.shape)
             if output_dim is not None:
-                tp_size = get_tensor_model_parallel_world_size()
-                assert final_shape[output_dim] % tp_size == 0
-                final_shape[output_dim] = final_shape[output_dim] // tp_size
+                assert final_shape[output_dim] % self.tp_size == 0
+                final_shape[output_dim] = (final_shape[output_dim] //
+                                           self.tp_size)
             param.materialize(final_shape, dtype=loaded_weight.dtype)
 
         param_data = param.data
         if output_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[output_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
 
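Note on the sharding arithmetic in the hunk above: each tensor-parallel rank owns one contiguous slice of the output dimension, so the loader starts its slice at tp_rank * shard_size and takes shard_size rows with narrow(). A minimal standalone sketch in plain PyTorch (toy shapes and hypothetical variable names, not vLLM code):

import torch

# Hypothetical full checkpoint weight: 8 output rows, 4 inputs, output_dim = 0.
full_weight = torch.arange(32, dtype=torch.float32).reshape(8, 4)

tp_size, output_dim = 2, 0
shard_size = full_weight.shape[output_dim] // tp_size  # 4 rows per rank

for tp_rank in range(tp_size):
    start_idx = tp_rank * shard_size
    # Same call the loader uses: take this rank's contiguous slice.
    shard = full_weight.narrow(output_dim, start_idx, shard_size)
    print(tp_rank, shard.shape)  # rank 0 -> rows 0..3, rank 1 -> rows 4..7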
@@ -565,8 +567,11 @@ def __init__(
                  return_bias: bool = True,
                  ):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        assert all(output_size % self.tp_size == 0
+                   for output_size in output_sizes)
         super().__init__(input_size=input_size,
                          output_size=sum(output_sizes),
                          bias=bias,
@@ -598,12 +603,10 @@ def weight_loader(self,
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
 
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             if loaded_shard_id is not None:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -669,11 +672,10 @@ def weight_loader(self,
             return
 
         assert loaded_shard_id < len(self.output_sizes)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
+                            self.tp_size)
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -701,7 +703,7 @@ def weight_loader(self,
 
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             if not is_sharded_weight:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
@@ -991,12 +993,9 @@ def weight_loader(self,
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             if loaded_shard_id is not None:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -1071,7 +1070,6 @@ def weight_loader(self,
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
-        tp_rank = get_tensor_model_parallel_rank()
         assert loaded_shard_id in ["q", "k", "v"]
 
         # If output dim is defined, use the default loading process.
@@ -1123,9 +1121,9 @@ def weight_loader(self,
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
-                shard_id = tp_rank
+                shard_id = self.tp_rank
             else:
-                shard_id = tp_rank // self.num_kv_head_replicas
+                shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
 
             if not is_sharded_weight:
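Note on the num_kv_head_replicas division above: when the TP world size exceeds the number of KV heads, several consecutive ranks replicate the same KV head, so the KV shard index is the rank divided by the replication factor, while every rank keeps its own Q slice. A rough illustration with assumed head counts (numbers and names are made up, not taken from this diff):

# Assumed setup: 2 KV heads replicated across tp_size = 8 ranks.
tp_size, total_num_kv_heads = 8, 2
num_kv_head_replicas = tp_size // total_num_kv_heads  # 4 ranks per KV head

for tp_rank in range(tp_size):
    q_shard_id = tp_rank                            # each rank has its own Q slice
    kv_shard_id = tp_rank // num_kv_head_replicas   # ranks 0-3 -> 0, ranks 4-7 -> 1
    print(tp_rank, q_shard_id, kv_shard_id)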
@@ -1245,8 +1243,6 @@ def __init__(
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
         is_sharded_weight = getattr(param, "is_sharded_weight", False)
@@ -1264,13 +1260,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             weight_shape = list(loaded_weight.shape)
             if input_dim:
-                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+                weight_shape[input_dim] = (weight_shape[input_dim] //
+                                           self.tp_size)
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
         if input_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
 
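Taken together, the change is uniform: weight_loader no longer re-queries the distributed state for the TP rank and world size at load time, but reuses the values captured in __init__, i.e. the same group that fixed the shard shapes. A minimal standalone sketch of why that matters when the TP group can be patched; the helpers below are simplified stand-ins for vLLM's get_tensor_model_parallel_* and patch_tensor_parallel_group, not the actual API:

from contextlib import contextmanager

# Toy stand-in for the "current tensor-parallel group" global state.
_TP = {"rank": 0, "size": 1}

def get_tp_rank():  # stand-in for get_tensor_model_parallel_rank()
    return _TP["rank"]

def get_tp_size():  # stand-in for get_tensor_model_parallel_world_size()
    return _TP["size"]

@contextmanager
def patch_tp_group(rank, size):  # stand-in for patch_tensor_parallel_group()
    saved = dict(_TP)
    _TP.update(rank=rank, size=size)
    try:
        yield
    finally:
        _TP.update(saved)

class ShardedLayer:
    def __init__(self, output_size):
        # Shard shape is fixed here, using the group active at construction.
        self.tp_rank, self.tp_size = get_tp_rank(), get_tp_size()
        self.shard_size = output_size // self.tp_size

    def load_buggy(self):
        # Pre-fix behavior: rank looked up again at load time.
        return get_tp_rank() * self.shard_size

    def load_fixed(self):
        # Post-fix behavior: reuse the rank captured in __init__.
        return self.tp_rank * self.shard_size

# Layer built while a patched TP group (rank 1 of 2) is active,
# e.g. a model constructed under a patched group for illustration.
with patch_tp_group(rank=1, size=2):
    layer = ShardedLayer(output_size=8)

# Weights loaded later, outside the patch: the buggy path silently
# uses rank 0 and picks the wrong shard offset.
print(layer.load_buggy())  # 0 (wrong start index)
print(layer.load_fixed())  # 4 (matches the shard computed at init)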