@@ -452,8 +452,10 @@ def __init__(
         else:
             self.register_parameter("bias", None)
 
+        self.tp_rank = get_tensor_model_parallel_rank()
+
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
+
         output_dim = getattr(param, "output_dim", None)
 
         is_sharded_weight = getattr(param, "is_sharded_weight", False)
@@ -472,15 +474,15 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             final_shape = list(loaded_weight.shape)
             if output_dim is not None:
-                tp_size = get_tensor_model_parallel_world_size()
-                assert final_shape[output_dim] % tp_size == 0
-                final_shape[output_dim] = final_shape[output_dim] // tp_size
+                assert final_shape[output_dim] % self.tp_size == 0
+                final_shape[output_dim] = (final_shape[output_dim] //
+                                           self.tp_size)
             param.materialize(final_shape, dtype=loaded_weight.dtype)
 
         param_data = param.data
         if output_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[output_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
 
@@ -565,8 +567,11 @@ def __init__(
         return_bias: bool = True,
     ):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        assert all(output_size % self.tp_size == 0
+                   for output_size in output_sizes)
         super().__init__(input_size=input_size,
                          output_size=sum(output_sizes),
                          bias=bias,
@@ -598,12 +603,10 @@ def weight_loader(self,
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
 
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             if loaded_shard_id is not None:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -669,11 +672,10 @@ def weight_loader(self,
             return
 
         assert loaded_shard_id < len(self.output_sizes)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
+                            self.tp_size)
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -701,7 +703,7 @@ def weight_loader(self,
 
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             if not is_sharded_weight:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
@@ -991,12 +993,9 @@ def weight_loader(self,
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             if loaded_shard_id is not None:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -1071,7 +1070,6 @@ def weight_loader(self,
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
-        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]
 
         # If output dim is defined, use the default loading process.
@@ -1123,9 +1121,9 @@ def weight_loader(self,
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
-                shard_id = tp_rank
+                shard_id = self.tp_rank
             else:
-                shard_id = tp_rank // self.num_kv_head_replicas
+                shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
 
             if not is_sharded_weight:
@@ -1245,8 +1243,6 @@ def __init__(
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
         is_sharded_weight = getattr(param, "is_sharded_weight", False)
@@ -1264,13 +1260,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             weight_shape = list(loaded_weight.shape)
             if input_dim:
-                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+                weight_shape[input_dim] = (weight_shape[input_dim] //
+                                           self.tp_size)
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
         if input_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
 
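For reference, below is a minimal self-contained sketch of the pattern this diff applies across the linear layers: resolve the tensor-parallel rank and world size once at construction time and reuse the cached self.tp_rank / self.tp_size inside weight_loader, instead of calling get_tensor_model_parallel_rank() / get_tensor_model_parallel_world_size() on every weight load. The ShardedLinearSketch class and its explicit tp_rank / tp_size constructor arguments are illustrative stand-ins so the example runs without a distributed process group; they are not vLLM's actual API, where these values come from the get_tensor_model_parallel_* helpers.

import torch
from torch.nn import Module, Parameter


class ShardedLinearSketch(Module):
    """Toy column-parallel layer illustrating cached tp_rank/tp_size."""

    def __init__(self, input_size: int, output_size: int, tp_rank: int,
                 tp_size: int):
        super().__init__()
        assert output_size % tp_size == 0
        # Cached once here; every later weight_loader call reuses these,
        # mirroring the self.tp_rank / self.tp_size usage in the diff.
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        # Each rank holds only its slice of the output dimension.
        self.weight = Parameter(
            torch.empty(output_size // tp_size, input_size))

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        # Narrow the full checkpoint tensor to this rank's shard along the
        # output dimension, using the cached rank instead of a fresh lookup.
        shard_size = param.data.shape[0]
        start_idx = self.tp_rank * shard_size
        param.data.copy_(loaded_weight.narrow(0, start_idx, shard_size))


# Usage: rank 1 of a 4-way tensor-parallel group owns output rows 8..15.
layer = ShardedLinearSketch(input_size=16, output_size=32, tp_rank=1, tp_size=4)
full_weight = torch.randn(32, 16)
layer.weight_loader(layer.weight, full_weight)
assert torch.equal(layer.weight.data, full_weight[8:16])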