Commit 85bd659

[Model] Add AutoWeightsLoader support for BERT, RoBERTa (#20534)
Signed-off-by: Jennifer He <islandhe@gmail.com>
Signed-off-by: <islandhe@gmail.com>
Signed-off-by: Jen H <islandhe@gmail.com>
1 parent 91b3d19 commit 85bd659

2 files changed: +59 −100 lines
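At a high level, the commit replaces hand-rolled weight-loading loops in the BERT and RoBERTa models with vLLM's AutoWeightsLoader, plus a WeightsMapper that renames checkpoint keys by prefix. The following is a minimal, self-contained sketch of that loading pattern; it mimics, rather than imports, the vLLM helpers, and the names load_by_name, remap, and Toy are hypothetical.

import torch
from torch import nn


def remap(name: str, orig_to_new_prefix: dict[str, str]) -> str:
    """Rename a checkpoint key by prefix, e.g. "roberta." -> "model."."""
    for old, new in orig_to_new_prefix.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name


def load_by_name(module: nn.Module, weights, *, skip_prefixes=(), mapping=None):
    """Copy (name, tensor) pairs into same-named parameters of `module`."""
    params = dict(module.named_parameters())
    loaded: set[str] = set()
    for name, tensor in weights:
        if mapping is not None:
            name = remap(name, mapping)
        if any(name.startswith(p) for p in skip_prefixes):
            continue  # e.g. ignore "lm_head." or "pooler." checkpoint entries
        if name in params:
            params[name].data.copy_(tensor)
            loaded.add(name)
    return loaded


class Toy(nn.Module):
    """Stand-in for BertEmbeddingModel: the encoder lives under `model.`."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(2, 2)


ckpt = [("roberta.weight", torch.ones(2, 2)),
        ("roberta.bias", torch.zeros(2)),
        ("lm_head.weight", torch.ones(2, 2))]
print(sorted(load_by_name(Toy(), ckpt,
                          skip_prefixes=("lm_head.",),
                          mapping={"roberta.": "model."})))
# -> ['model.bias', 'model.weight']

As in this sketch, the vLLM helper returns the set of loaded parameter names, which the new load_weights methods in the diff below propagate to their callers.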

vllm/model_executor/models/bert.py

Lines changed: 37 additions & 48 deletions
@@ -22,12 +22,11 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
-from .utils import WeightsMapper, maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 
 class BertEmbedding(nn.Module):
@@ -44,9 +43,11 @@ def __init__(self, config: BertConfig):
             config.type_vocab_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.position_ids = nn.Parameter(
-            torch.empty((1, config.max_position_embeddings)), )
 
+        self.register_buffer(
+            "position_ids",
+            torch.arange(config.max_position_embeddings).unsqueeze(0),
+        )
         self.position_embedding_type = config.position_embedding_type
         if self.position_embedding_type != "absolute":
             raise ValueError("Only 'absolute' position_embedding_type" +
@@ -358,45 +359,45 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("qkv_proj", "value", "v"),
         ]
 
+        loaded_stacked_params = []
+        other_weights = []
         params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if self.pooler is None and "pooler" in name:
-                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
                 name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
+                loaded_stacked_params.append(name)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
+                if name in params_dict:
+                    other_weights.append((name, loaded_weight))
+
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["pooler."] if self.pooler is None else []),
+        )
+        loaded_params = loader.load_weights(other_weights)
+        loaded_params.update(loaded_stacked_params)
         return loaded_params
 
 
 class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
-   This class encapsulates the BertModel and provides an interface for
-   embedding operations and customized pooling functions.
+    This class encapsulates the BertModel and provides an interface for
+    embedding operations and customized pooling functions.
 
-   Attributes:
-       model: An instance of BertModel used for forward operations.
-       _pooler: An instance of Pooler used for pooling operations.
-   """
-    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
+    Attributes:
+        model: An instance of BertModel used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -425,10 +426,15 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        weights = self.hf_to_vllm_mapper.apply(weights)
-        weights = ((name, data) for name, data in weights
-                   if not name.startswith("lm_head."))
-        self.model.load_weights(weights)
+        weights_list = list(weights)
+
+        has_model_prefix = any(
+            name.startswith("model.") for name, _ in weights_list)
+        if not has_model_prefix:
+            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
+
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
+        return loader.load_weights(weights_list, mapper=mapper)
 
     def _build_model(self,
                      vllm_config: VllmConfig,
@@ -470,26 +476,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                         self.classifier, self.bert.pooler)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-
-        self_weights = []
-
-        def weight_filter():
-            for name, weight in weights:
-                if name.startswith("bert."):
-                    yield (name[len("bert."):], weight)
-                else:
-                    self_weights.append((name, weight))
-
-        self.bert.load_weights(weight_filter())
-
-        params_dict = dict(self.named_parameters())
-
-        for name, loaded_weight in self_weights:
-            if name.startswith("classifier"):
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(self)
+        loaded_params = loader.load_weights(weights)
+        return loaded_params
 
     def pooler(
         self,
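The BertModel.load_weights hunk above still walks stacked_params_mapping by hand before delegating everything else to AutoWeightsLoader: Hugging Face checkpoints store separate query/key/value tensors, while vLLM fuses them into a single qkv_proj parameter that is filled shard by shard. A toy sketch of that fusion follows; the sizes, the checkpoint names beyond the ("qkv_proj", "value", "v") entry visible in the hunk, and the shard-offset table are illustrative assumptions, not vLLM code.

import torch

hidden = 4
ckpt = {
    "encoder.layer.0.attention.self.query.weight": torch.full((hidden, hidden), 1.0),
    "encoder.layer.0.attention.self.key.weight": torch.full((hidden, hidden), 2.0),
    "encoder.layer.0.attention.self.value.weight": torch.full((hidden, hidden), 3.0),
}

# (fused param name, substring in checkpoint name, shard id)
stacked_params_mapping = [
    ("qkv_proj", "query", "q"),
    ("qkv_proj", "key", "k"),
    ("qkv_proj", "value", "v"),
]
shard_rows = {"q": 0, "k": hidden, "v": 2 * hidden}

qkv_proj = torch.zeros(3 * hidden, hidden)  # fused parameter: rows stacked q|k|v
for name, tensor in ckpt.items():
    for _, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            start = shard_rows[shard_id]
            qkv_proj[start:start + hidden] = tensor
            break

print(qkv_proj[:, 0].tolist())
# [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]

Because three checkpoint names map onto shards of one fused parameter, a purely name-matched loader cannot place them, which is why these weights are loaded manually and only the remainder is handed to AutoWeightsLoader.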

vllm/model_executor/models/roberta.py

Lines changed: 22 additions & 52 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import itertools
 from collections.abc import Iterable
 from typing import Optional, Union
 
@@ -13,9 +12,9 @@
 from vllm.model_executor.layers.pooler import ClassifierPooler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
-from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
+from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
+                                              maybe_prefix)
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
@@ -39,8 +38,10 @@ def __init__(self, config: RobertaConfig):
                                              config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.position_ids = nn.Parameter(
-            torch.empty((1, config.max_position_embeddings)), )
+        self.register_buffer(
+            "position_ids",
+            torch.arange(config.max_position_embeddings).unsqueeze(0),
+        )
 
         self.position_embedding_type = config.position_embedding_type
         if self.position_embedding_type != "absolute":
@@ -136,16 +137,20 @@ def _build_model(self,
                          embedding_class=RobertaEmbedding)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        weights = self.hf_to_vllm_mapper.apply(weights)
-        # Separate weights in "roberta"-prefixed and all else (not in memory).
-        # For use with models like FacebookAI/roberta-base.
-        bert_weights, task_weights = roberta_task_weights_filter(weights)
-        loaded = self.model.load_weights(bert_weights)
-        if not len(loaded):
-            # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
-            # which use the same architecture, but have no "roberta" prefix.
-            loaded = self.model.load_weights(task_weights)
-        assert len(loaded), "Unable to load RobertaEmbeddingModel"
+        weights_list = list(weights)
+        has_roberta_prefix = any(
+            name.startswith("roberta.") for name, _ in weights_list)
+        if has_roberta_prefix:
+            # For models with the `roberta.` prefix e.g.
+            # `FacebookAI/roberta-base`
+            mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
+        else:
+            # For models without the `roberta.` prefix e.g.
+            # `sentence-transformers/stsb-roberta-base-v2`
+            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
+
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
+        return loader.load_weights(weights_list, mapper=mapper)
 
 
 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
@@ -187,19 +192,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                         self.classifier)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        bert_weights, task_weights = roberta_task_weights_filter(weights)
-        bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
-
-        self.roberta.load_weights(bert_weights)
-
-        params_dict = dict(self.named_parameters())
-
-        for name, loaded_weight in task_weights:
-            if name.startswith("classifier"):
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
 
     def pooler(
         self,
@@ -245,27 +239,3 @@ def create_position_ids_from_input_ids(input_ids,
                            past_key_values_length) * mask
 
     return incremental_indices.long() + padding_idx
-
-
-def roberta_task_weights_filter(
-    all_weights: Iterable[tuple[str, torch.Tensor]]
-) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str,
-                                                               torch.Tensor]]]:
-    """
-    Separate task-specific weights that are applied on top
-    of the encoder-decoder bert base.
-    To do so, return two generators over the original iterator.
-    Also, remove the "roberta." prefix to make it loadable
-    from vanilla BertModel.
-    """
-    # Copy of a lazy iterator without in-memory overhead so both
-    # iterators can be iterated upon independently.
-    all_weights1, all_weights2 = itertools.tee(all_weights)
-
-    def encoder_decoder_weights():
-        for name, weight in all_weights1:
-            if name.startswith("roberta."):
-                yield (name[len("roberta."):], weight)
-
-    return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
-                                       if not n.startswith("roberta."))
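Both embedding classes also switch position_ids from an uninitialized nn.Parameter to a registered buffer. The distinction is easy to check in plain PyTorch; only the register_buffer call below mirrors the diff, and the toy module is hypothetical.

import torch
from torch import nn


class ToyEmbedding(nn.Module):
    def __init__(self, max_position_embeddings: int = 8):
        super().__init__()
        # A buffer stays on the module and follows .to(device)/dtype moves,
        # but is not a trainable weight a checkpoint or optimizer must cover.
        self.register_buffer(
            "position_ids",
            torch.arange(max_position_embeddings).unsqueeze(0),
        )


emb = ToyEmbedding()
print(emb.position_ids.shape)                  # torch.Size([1, 8])
print([n for n, _ in emb.named_parameters()])  # []  (no trainable parameters)
print([n for n, _ in emb.named_buffers()])     # ['position_ids']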
