Separate base model from TransformersModel #15467

Merged · 6 commits · Mar 26, 2025
6 changes: 3 additions & 3 deletions docs/source/models/supported_models.md
@@ -57,10 +57,10 @@ llm = LLM(model=..., task="generate") # Name or path of your model
llm.apply_model(lambda model: print(type(model)))
```

If it is `TransformersModel` then it means it's based on Transformers!
If it is `TransformersForCausalLM` then it means it's based on Transformers!

:::{tip}
You can force the use of `TransformersModel` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
:::
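
For context, a minimal offline-inference sketch of the tip above (the model name is a placeholder, not something from this PR):

```python
from vllm import LLM

# Force the Transformers backend instead of a native vLLM implementation.
# "my-org/my-model" is a placeholder; any Hugging Face model name or local
# path can be used here.
llm = LLM(model="my-org/my-model", model_impl="transformers")

# Confirm which implementation was selected; after this PR the fallback class
# for decoder-only models is TransformersForCausalLM.
llm.apply_model(lambda model: print(type(model)))
```

The server-side equivalent, per the docs line above, is `vllm serve my-org/my-model --model-impl transformers`.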

:::{note}
@@ -119,7 +119,7 @@ Here is what happens in the background:

1. The config is loaded
2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
3. The `TransformersModel` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverages `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTIONS`.
3. The `TransformersForCausalLM` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverages `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTIONS`.
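
To illustrate step 3, here is a rough sketch (not part of this diff) of the dispatch convention the docs refer to: a Transformers-style attention module looks up its kernel in `ALL_ATTENTION_FUNCTIONS`, keyed by `config._attn_implementation`, which is what lets vLLM substitute its own backend by setting that field to `"vllm"`. The class and projection names below are illustrative assumptions, not code from the PR.

```python
import torch
from torch import nn
from transformers import PretrainedConfig
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class MyAttention(nn.Module):
    """Illustrative attention block following the Transformers convention."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        bsz, q_len, _ = hidden_states.shape
        shape = (bsz, q_len, -1, self.head_dim)
        query = self.q_proj(hidden_states).view(shape).transpose(1, 2)
        key = self.k_proj(hidden_states).view(shape).transpose(1, 2)
        value = self.v_proj(hidden_states).view(shape).transpose(1, 2)
        # The registry lookup is the key point: when vLLM sets
        # config._attn_implementation = "vllm", its attention kernel is
        # picked up here instead of eager/SDPA/FlashAttention.
        attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, _ = attn_fn(self, query, key, value,
                                 attention_mask=None, **kwargs)
        return self.o_proj(attn_output.reshape(bsz, q_len, -1))
```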

To make your model compatible with tensor parallel, it needs:

2 changes: 1 addition & 1 deletion tests/distributed/test_pipeline_parallel.py
@@ -175,7 +175,7 @@ def iter_params(self, model_id: str):
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
# Tests TransformersModel
# Tests TransformersForCausalLM
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(),
2 changes: 1 addition & 1 deletion tests/models/registry.py
@@ -319,7 +319,7 @@ def check_available_online(
}

_FALLBACK_MODEL = {
"TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
}

_EXAMPLE_MODELS = {
6 changes: 3 additions & 3 deletions vllm/model_executor/model_loader/utils.py
@@ -45,7 +45,7 @@ def is_transformers_impl_compatible(
def resolve_transformers_fallback(model_config: ModelConfig,
architectures: list[str]):
for i, arch in enumerate(architectures):
if arch == "TransformersModel":
if arch == "TransformersForCausalLM":
continue
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
None) or dict()
@@ -69,7 +69,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM.")
architectures[i] = "TransformersModel"
architectures[i] = "TransformersForCausalLM"
if model_config.model_impl == ModelImpl.AUTO:
if not is_transformers_impl_compatible(arch, custom_model_module):
raise ValueError(
@@ -80,7 +80,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.", arch)
architectures[i] = "TransformersModel"
architectures[i] = "TransformersForCausalLM"
return architectures


4 changes: 2 additions & 2 deletions vllm/model_executor/models/registry.py
@@ -201,7 +201,7 @@
}

_FALLBACK_MODEL = {
"TransformersModel": ("transformers", "TransformersModel"),
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable

@@ -425,7 +425,7 @@ def _normalize_archs(

# make sure Transformers fallback are put at the last
if len(normalized_arch) != len(architectures):
normalized_arch.append("TransformersModel")
normalized_arch.append("TransformersForCausalLM")
return normalized_arch

def inspect_model_cls(
149 changes: 100 additions & 49 deletions vllm/model_executor/models/transformers.py
@@ -43,7 +43,8 @@
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (PPMissingLayer, is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, maybe_prefix)

logger = init_logger(__name__)
@@ -110,13 +111,9 @@ def replace_linear_class(
)


@support_torch_compile
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it
class TransformersModel(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
logger.info("Using Transformers backend.")

@@ -134,23 +131,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.parallel_config = parallel_config
self.quant_config = quant_config

self.vocab_size = model_config.get_vocab_size()
self.unpadded_vocab_size = model_config.get_vocab_size()

self.pp_group = get_pp_group()
self.pp_size = self.pp_group.world_size
self.pp_rank = self.pp_group.rank_in_group
self.tp_size = get_tensor_model_parallel_world_size()

# Use meta device to delay allocating GPU tensors
with torch.device("meta"):
# FIXME(Isotr0py): We need to refactor this part in the future to
# avoid registering an extra model layer, otherwise we will need a
# weights mapper to rename weights.
self.model: PreTrainedModel = AutoModel.from_config(
config,
attn_implementation="vllm",
torch_dtype=model_config.dtype,
trust_remote_code=model_config.trust_remote_code,
)
prefix = self.model.base_model_prefix

self.pipeline_parallel()
self.tensor_parallel()
@@ -168,32 +164,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
# Attention layers
self.attention_instances = self.create_attention_instances()

# Output embeddings
if not isinstance(getattr(self, "lm_head", None), PPMissingLayer):
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings())

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)

# Initialize buffers (e.g. rotary embedding inverse frequency)
self.init_buffers(self.model)

# Move remaining meta tensors to device (should happen last)
self.meta_to_empty(self.model)

self.sampler = get_sampler()

self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
@@ -248,9 +224,6 @@ def pipeline_parallel(self):
if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer())

if not self.pp_group.is_last_rank:
self.lm_head = PPMissingLayer()

def tensor_parallel(self):
"""
Apply the model's tensor parallelization plan.
@@ -331,6 +304,9 @@ def meta_to_empty(self, module: nn.Module):
for child in module.children():
self.meta_to_empty(child)

def get_input_embeddings(self) -> nn.Module:
return self.model.get_input_embeddings()

def forward(
self,
input_ids: Optional[torch.Tensor],
@@ -361,21 +337,6 @@ def forward(

return hidden_states

def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:

next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
@@ -393,3 +354,93 @@ def load_weights(self, weights: Iterable[tuple[str,
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


@support_torch_compile
class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
quant_config: QuantizationConfig = vllm_config.quant_config

self.config = config

self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)

if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings())

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
else:
self.lm_head = PPMissingLayer()

self.sampler = get_sampler()

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

# FIXME(Isotr0py): Don't use any weights mapper for Transformers fallback,
# this makes things complicated. We need to remove this mapper after
# refactoring `TransformersModel` in the future.
@property
def hf_to_vllm_mapper(self):
prefix_mapper = {
name: "model." + name
for name, _ in self.model.model.named_children()
}
return WeightsMapper(
orig_to_new_substr={"model.": "model.model."},
orig_to_new_prefix=prefix_mapper,
)

def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return model_output

def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:

next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
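
As a closing usage sketch (not part of the diff): the remote-code example model referenced in the tests and registry changes above should now resolve to the renamed fallback class.

```python
from vllm import LLM

# ArthurZ/Ilama-3.2-1B is the example model from tests/models/registry.py
# above; trust_remote_code mirrors that entry.
llm = LLM(model="ArthurZ/Ilama-3.2-1B", trust_remote_code=True)

# After this PR the fallback architecture is registered as
# TransformersForCausalLM, so this is expected to print something like:
# <class 'vllm.model_executor.models.transformers.TransformersForCausalLM'>
llm.apply_model(lambda model: print(type(model)))
```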