
Commit 0736b04

hmellor authored and Isotr0py committed
Separate base model from TransformersModel (vllm-project#15467)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent 041e217 commit 0736b04

File tree: 6 files changed (+110, −59 lines)
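
Note: at a glance, this commit splits the old all-in-one `TransformersModel` into a bare backbone (`TransformersModel`) and a causal-LM wrapper (`TransformersForCausalLM`) that owns the LM head, logits processing and sampling. A minimal, illustrative sketch of that shape only (a plain `nn.Linear` stands in for vLLM's `ParallelLMHead`; the real classes live in `vllm/model_executor/models/transformers.py` and take a `VllmConfig`):

```python
import torch
import torch.nn as nn


class TransformersModel(nn.Module):
    """Bare Transformers backbone: returns hidden states, no LM head or sampler."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.model = backbone  # in vLLM this is AutoModel.from_config(...)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.model(hidden_states)


class TransformersForCausalLM(nn.Module):
    """Causal-LM wrapper: nests the backbone and adds the output head."""

    def __init__(self, backbone: TransformersModel, hidden: int, vocab: int):
        super().__init__()
        self.model = backbone                      # note the extra nesting level
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.model(hidden_states)           # hidden states only

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states)         # logits computed separately


# Toy usage with an identity backbone:
wrapper = TransformersForCausalLM(TransformersModel(nn.Identity()), 8, 32)
logits = wrapper.compute_logits(wrapper(torch.randn(2, 8)))
```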

docs/source/models/supported_models.md

Lines changed: 3 additions & 3 deletions
@@ -57,10 +57,10 @@ llm = LLM(model=..., task="generate") # Name or path of your model
 llm.apply_model(lambda model: print(type(model)))
 ```
 
-If it is `TransformersModel` then it means it's based on Transformers!
+If it is `TransformersForCausalLM` then it means it's based on Transformers!
 
 :::{tip}
-You can force the use of `TransformersModel` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
+You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
 :::
 
 :::{note}
@@ -119,7 +119,7 @@ Here is what happens in the background:
 
 1. The config is loaded
 2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
-3. The `TransformersModel` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.
+3. The `TransformersForCausalLM` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.
 
 To make your model compatible with tensor parallel, it needs:
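
Note: per the updated doc above, the quickest way to see which implementation was picked is to print the runtime model class. A hedged offline-inference sketch (the model name is the trust-remote-code test model used elsewhere in this PR; any model without a native vLLM implementation behaves the same):

```python
from vllm import LLM

# Force the Transformers fallback, as described in the tip above.
llm = LLM(model="ArthurZ/Ilama-3.2-1B",
          model_impl="transformers",
          trust_remote_code=True)

# With the fallback active, this is expected to print TransformersForCausalLM.
llm.apply_model(lambda model: print(type(model)))
```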

tests/distributed/test_pipeline_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ def iter_params(self, model_id: str):
     "inceptionai/jais-13b-chat": PPTestSettings.fast(),
     "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
     "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
-    # Tests TransformersModel
+    # Tests TransformersForCausalLM
     "ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
     "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
     "openbmb/MiniCPM3-4B": PPTestSettings.fast(),

tests/models/registry.py

Lines changed: 1 addition & 1 deletion
@@ -319,7 +319,7 @@ def check_available_online(
 }
 
 _FALLBACK_MODEL = {
-    "TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
+    "TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
 }
 
 _EXAMPLE_MODELS = {

vllm/model_executor/model_loader/utils.py

Lines changed: 3 additions & 3 deletions
@@ -45,7 +45,7 @@ def is_transformers_impl_compatible(
 def resolve_transformers_fallback(model_config: ModelConfig,
                                   architectures: list[str]):
     for i, arch in enumerate(architectures):
-        if arch == "TransformersModel":
+        if arch == "TransformersForCausalLM":
             continue
         auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                            None) or dict()
@@ -69,7 +69,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
                 raise ValueError(
                     f"The Transformers implementation of {arch} is not "
                     "compatible with vLLM.")
-            architectures[i] = "TransformersModel"
+            architectures[i] = "TransformersForCausalLM"
         if model_config.model_impl == ModelImpl.AUTO:
             if not is_transformers_impl_compatible(arch, custom_model_module):
                 raise ValueError(
@@ -80,7 +80,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
                 "%s has no vLLM implementation, falling back to Transformers "
                 "implementation. Some features may not be supported and "
                 "performance may not be optimal.", arch)
-            architectures[i] = "TransformersModel"
+            architectures[i] = "TransformersForCausalLM"
     return architectures
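
Note: the behavioural core of this hunk is that unsupported architecture names now get rewritten to `TransformersForCausalLM` instead of `TransformersModel`. A stripped-down, illustrative sketch of just that rewrite (vLLM's config objects and compatibility check are replaced by a stubbed predicate):

```python
def resolve_fallback_sketch(architectures: list[str],
                            is_compatible=lambda arch: True) -> list[str]:
    """Illustrative approximation of resolve_transformers_fallback."""
    for i, arch in enumerate(architectures):
        if arch == "TransformersForCausalLM":
            continue  # already the fallback name, nothing to rewrite
        if not is_compatible(arch):
            raise ValueError(f"The Transformers implementation of {arch} "
                             "is not compatible with vLLM.")
        architectures[i] = "TransformersForCausalLM"
    return architectures


print(resolve_fallback_sketch(["MyModelForCausalLM"]))
# ['TransformersForCausalLM']
```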

vllm/model_executor/models/registry.py

Lines changed: 2 additions & 2 deletions
@@ -201,7 +201,7 @@
 }
 
 _FALLBACK_MODEL = {
-    "TransformersModel": ("transformers", "TransformersModel"),
+    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
 }
 # yapf: enable
 
@@ -425,7 +425,7 @@ def _normalize_archs(
 
         # make sure Transformers fallback are put at the last
        if len(normalized_arch) != len(architectures):
-            normalized_arch.append("TransformersModel")
+            normalized_arch.append("TransformersForCausalLM")
         return normalized_arch
 
     def inspect_model_cls(
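
Note: registry behaviour is unchanged apart from the name: if normalization dropped any requested architecture, the Transformers fallback is appended as the last candidate. A minimal illustration of just that appending step:

```python
def append_fallback_sketch(normalized_arch: list[str],
                           architectures: list[str]) -> list[str]:
    # Illustrative only: mirrors the tail of _normalize_archs. If some of the
    # requested architectures were not recognized, append the fallback last.
    if len(normalized_arch) != len(architectures):
        normalized_arch.append("TransformersForCausalLM")
    return normalized_arch


print(append_fallback_sketch(["LlamaForCausalLM"],
                             ["LlamaForCausalLM", "SomeUnknownArch"]))
# ['LlamaForCausalLM', 'TransformersForCausalLM']
```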

vllm/model_executor/models/transformers.py

Lines changed: 100 additions & 49 deletions
@@ -43,7 +43,8 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, maybe_prefix)
 
 logger = init_logger(__name__)
@@ -110,13 +111,9 @@ def replace_linear_class(
     )
 
 
-@support_torch_compile
-class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"
-                         ]  # TODO transformers will have a util to get it
+class TransformersModel(nn.Module):
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         logger.info("Using Transformers backend.")
 
@@ -134,23 +131,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.parallel_config = parallel_config
         self.quant_config = quant_config
 
-        self.vocab_size = model_config.get_vocab_size()
-        self.unpadded_vocab_size = model_config.get_vocab_size()
-
         self.pp_group = get_pp_group()
         self.pp_size = self.pp_group.world_size
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()
 
         # Use meta device to delay allocating GPU tensors
         with torch.device("meta"):
+            # FIXME(Isotr0py): We need to refactor this part in the future to
+            # avoid registering an extra model layer, otherwise we will need a
+            # weights mapper to rename weights.
             self.model: PreTrainedModel = AutoModel.from_config(
                 config,
                 attn_implementation="vllm",
                 torch_dtype=model_config.dtype,
                 trust_remote_code=model_config.trust_remote_code,
             )
-        prefix = self.model.base_model_prefix
 
         self.pipeline_parallel()
         self.tensor_parallel()
@@ -168,32 +164,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         # Attention layers
         self.attention_instances = self.create_attention_instances()
 
-        # Output embeddings
-        if not isinstance(getattr(self, "lm_head", None), PPMissingLayer):
-            self.unpadded_vocab_size = config.vocab_size
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "lm_head"),
-            )
-            if config.tie_word_embeddings:
-                self.lm_head = self.lm_head.tie_weights(
-                    self.model.get_input_embeddings())
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-
         # Initialize buffers (e.g. rotary embedding inverse frequency)
         self.init_buffers(self.model)
 
         # Move remaining meta tensors to device (should happen last)
         self.meta_to_empty(self.model)
 
-        self.sampler = get_sampler()
-
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
@@ -248,9 +224,6 @@ def pipeline_parallel(self):
             if not self.pp_group.is_last_rank:
                 setattr(self.model, name, PPMissingLayer())
 
-        if not self.pp_group.is_last_rank:
-            self.lm_head = PPMissingLayer()
-
     def tensor_parallel(self):
         """
         Apply the model's tensor parallelization plan.
@@ -331,6 +304,9 @@ def meta_to_empty(self, module: nn.Module):
         for child in module.children():
             self.meta_to_empty(child)
 
+    def get_input_embeddings(self) -> nn.Module:
+        return self.model.get_input_embeddings()
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -361,21 +337,6 @@ def forward(
 
         return hidden_states
 
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(self, logits: torch.Tensor,
-               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
-
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str,
                                              torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
@@ -393,3 +354,93 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+@support_torch_compile
+class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
+                              SupportsPP):
+    embedding_padding_modules = ["lm_head"]
+    embedding_modules = ["embed_tokens"
+                         ]  # TODO transformers will have a util to get it
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config: PretrainedConfig = vllm_config.model_config.hf_config
+        quant_config: QuantizationConfig = vllm_config.quant_config
+
+        self.config = config
+
+        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
+
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "lm_head"),
+            )
+            if config.tie_word_embeddings:
+                self.lm_head = self.lm_head.tie_weights(
+                    self.model.get_input_embeddings())
+
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                    config.vocab_size,
+                                                    logit_scale)
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    # FIXME(Isotr0py): Don't use any weights mapper for Transformers fallback,
+    # this makes thing complicated. We need to remove this mapper after refactor
+    # `TransformersModel` in the future.
+    @property
+    def hf_to_vllm_mapper(self):
+        prefix_mapper = {
+            name: "model." + name
+            for name, _ in self.model.model.named_children()
+        }
+        return WeightsMapper(
+            orig_to_new_substr={"model.": "model.model."},
+            orig_to_new_prefix=prefix_mapper,
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        model_output = self.model(input_ids, positions, intermediate_tensors,
+                                  inputs_embeds)
+        return model_output
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(self, logits: torch.Tensor,
+               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
+
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                             torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
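
Note: because the HF backbone now sits one module deeper (`TransformersForCausalLM.model` is a `TransformersModel`, whose own `.model` is the HF `AutoModel`), checkpoint weight names need an extra `model.` level at load time, which is what `hf_to_vllm_mapper` supplies to `AutoWeightsLoader`. A plain-Python approximation of that renaming for a couple of made-up checkpoint names (the real `WeightsMapper` also builds a per-child prefix map):

```python
def remap_checkpoint_name(name: str) -> str:
    # Illustrative only: mirrors the "model." -> "model.model." substring
    # rewrite from hf_to_vllm_mapper. Backbone weights gain one nesting level;
    # lm_head stays on the wrapper (and is skipped entirely when
    # tie_word_embeddings is set, per skip_prefixes in load_weights).
    if name.startswith("model."):
        return "model." + name
    return name


for ckpt_name in ("model.embed_tokens.weight",
                  "model.layers.0.self_attn.q_proj.weight",
                  "lm_head.weight"):
    print(f"{ckpt_name} -> {remap_checkpoint_name(ckpt_name)}")
```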
