 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
-from .utils import WeightsMapper, maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 
 class BertEmbedding(nn.Module):
@@ -44,9 +43,11 @@ def __init__(self, config: BertConfig):
             config.type_vocab_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.position_ids = nn.Parameter(
-            torch.empty((1, config.max_position_embeddings)), )
 
+        self.register_buffer(
+            "position_ids",
+            torch.arange(config.max_position_embeddings).unsqueeze(0),
+        )
         self.position_embedding_type = config.position_embedding_type
         if self.position_embedding_type != "absolute":
             raise ValueError("Only 'absolute' position_embedding_type" +
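
Why the change above matters: an `nn.Parameter` built from `torch.empty(...)` shows up in `named_parameters()`, so it is treated as trainable, checkpoint-loadable state with uninitialized contents; `register_buffer` stores deterministic, non-trainable state that still moves with the module across devices but is invisible to weight loading. A minimal standalone sketch (class name and sizes invented for illustration, not the vLLM module):

    import torch
    import torch.nn as nn

    class PositionIds(nn.Module):
        def __init__(self, max_position_embeddings: int = 8):
            super().__init__()
            # Buffer, not Parameter: deterministic contents, no gradient,
            # and absent from named_parameters(), so loaders skip it.
            self.register_buffer(
                "position_ids",
                torch.arange(max_position_embeddings).unsqueeze(0))

    m = PositionIds()
    print([n for n, _ in m.named_parameters()])  # [] -> nothing to load
    print(m.position_ids.shape)                  # torch.Size([1, 8])
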
@@ -358,45 +359,45 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("qkv_proj", "value", "v"),
         ]
 
+        loaded_stacked_params = []
+        other_weights = []
         params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if self.pooler is None and "pooler" in name:
-                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
                 name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
+                loaded_stacked_params.append(name)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
+                if name in params_dict:
+                    other_weights.append((name, loaded_weight))
+
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["pooler."] if self.pooler is None else []),
+        )
+        loaded_params = loader.load_weights(other_weights)
+        loaded_params.update(loaded_stacked_params)
         return loaded_params
 
 
 class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
-   This class encapsulates the BertModel and provides an interface for
-   embedding operations and customized pooling functions.
+    This class encapsulates the BertModel and provides an interface for
+    embedding operations and customized pooling functions.
 
-   Attributes:
-       model: An instance of BertModel used for forward operations.
-       _pooler: An instance of Pooler used for pooling operations.
-   """
-    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
+    Attributes:
+        model: An instance of BertModel used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
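
The rewritten BertModel.load_weights above is a two-pass split: checkpoint tensors that feed the fused qkv_proj still go through the shard-aware weight_loader (first pass), and every remaining tensor is handed off to AutoWeightsLoader (second pass). A toy, self-contained illustration of the for/else dispatch, with all module and weight names invented:

    stacked_params_mapping = [
        ("qkv_proj", "query", "q"),
        ("qkv_proj", "key", "k"),
        ("qkv_proj", "value", "v"),
    ]
    params_dict = {"attn.qkv_proj.weight"}  # fused destination parameters
    checkpoint = [
        ("attn.query.weight", "W_q"),
        ("attn.output.dense.weight", "W_o"),
    ]

    loaded_stacked, other_weights = [], []
    for name, w in checkpoint:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            if name not in params_dict:
                continue
            loaded_stacked.append((name, shard_id, w))  # shard-aware load
            break
        else:  # no break: nothing matched, defer to the generic loader
            other_weights.append((name, w))

    print(loaded_stacked)  # [('attn.qkv_proj.weight', 'q', 'W_q')]
    print(other_weights)   # [('attn.output.dense.weight', 'W_o')]
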
@@ -425,10 +426,16 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        weights = self.hf_to_vllm_mapper.apply(weights)
-        weights = ((name, data) for name, data in weights
-                   if not name.startswith("lm_head."))
-        self.model.load_weights(weights)
+        weights_list = list(weights)
+
+        mapper = None
+        has_model_prefix = any(
+            name.startswith("model.") for name, _ in weights_list)
+        if not has_model_prefix:
+            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
+
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
+        return loader.load_weights(weights_list, mapper=mapper)
 
     def _build_model(self,
                      vllm_config: VllmConfig,
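
Context for the hunk above: BertEmbeddingModel stores its backbone as `self.model`, so its parameters are all named `model.*`, but plain Hugging Face BERT checkpoints use bare names (`embeddings.*`, `encoder.*`), while some exports already carry the `model.` prefix; the code therefore only installs a mapper when the prefix is missing. A hedged stand-in for what `orig_to_new_prefix={"": "model."}` does to each name (the helper below is an invented illustration, not vLLM's WeightsMapper):

    def add_prefix(weights, orig="", new="model."):
        # Every name starts with the empty prefix, so all names gain "model.".
        for name, tensor in weights:
            if name.startswith(orig):
                name = new + name[len(orig):]
            yield name, tensor

    ckpt = [("embeddings.word_embeddings.weight", None),
            ("encoder.layer.0.attention.self.query.weight", None)]
    print([n for n, _ in add_prefix(ckpt)])
    # ['model.embeddings.word_embeddings.weight',
    #  'model.encoder.layer.0.attention.self.query.weight']
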
@@ -470,26 +476,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                         self.classifier, self.bert.pooler)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-
-        self_weights = []
-
-        def weight_filter():
-            for name, weight in weights:
-                if name.startswith("bert."):
-                    yield (name[len("bert."):], weight)
-                else:
-                    self_weights.append((name, weight))
-
-        self.bert.load_weights(weight_filter())
-
-        params_dict = dict(self.named_parameters())
-
-        for name, loaded_weight in self_weights:
-            if name.startswith("classifier"):
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(self)
+        loaded_params = loader.load_weights(weights)
+        return loaded_params
 
     def pooler(
         self,
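
With the manual `bert.`/`classifier.` routing deleted, the classification model relies on AutoWeightsLoader matching checkpoint names against the module tree and returning the set of names it loaded (which is why `load_weights` can simply return `loaded_params`). A minimal sketch of that name-matching contract, assuming a flat lookup over `named_parameters()` (toy code, not vLLM's implementation):

    import torch
    import torch.nn as nn

    def load_by_name(module: nn.Module, weights):
        # Match dotted checkpoint names against the module's parameters
        # and report which names were actually loaded.
        params = dict(module.named_parameters())
        loaded = set()
        for name, tensor in weights:
            if name in params:
                params[name].data.copy_(tensor)
                loaded.add(name)
        return loaded

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.classifier = nn.Linear(4, 2)

    ckpt = [("classifier.weight", torch.zeros(2, 4)),
            ("classifier.bias", torch.zeros(2))]
    print(sorted(load_by_name(Toy(), ckpt)))
    # ['classifier.bias', 'classifier.weight']
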