From 81bf8b116db4726827b6527b6faddc9031906634 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 25 Mar 2025 13:23:22 +0100 Subject: [PATCH 1/5] Separate base model from `TransformersModel` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/source/models/supported_models.md | 6 +- tests/distributed/test_pipeline_parallel.py | 2 +- tests/models/registry.py | 2 +- vllm/model_executor/model_loader/utils.py | 6 +- vllm/model_executor/models/registry.py | 4 +- vllm/model_executor/models/transformers.py | 124 +++++++++++++------- 6 files changed, 89 insertions(+), 55 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 56ea8c5d837..ccfa22f4d92 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -57,10 +57,10 @@ llm = LLM(model=..., task="generate") # Name or path of your model llm.apply_model(lambda model: print(type(model))) ``` -If it is `TransformersModel` then it means it's based on Transformers! +If it is `TransformersForCausalLM` then it means it's based on Transformers! :::{tip} -You can force the use of `TransformersModel` by setting `model_impl="transformers"` for or `--model-impl transformers` for the . +You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for or `--model-impl transformers` for the . ::: :::{note} @@ -119,7 +119,7 @@ Here is what happens in the background: 1. The config is loaded 2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`. -3. The `TransformersModel` backend is used. See , which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`. +3. The `TransformersForCausalLM` backend is used. See , which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`. 
To make your model compatible with tensor parallel, it needs: diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index e757db45c8c..751c4eb096a 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -175,7 +175,7 @@ def iter_params(self, model_id: str): "inceptionai/jais-13b-chat": PPTestSettings.fast(), "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(), - # Tests TransformersModel + # Tests TransformersForCausalLM "ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(), "openbmb/MiniCPM3-4B": PPTestSettings.fast(), diff --git a/tests/models/registry.py b/tests/models/registry.py index 5c84e85aaa9..d7946b75b79 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -319,7 +319,7 @@ def check_available_online( } _FALLBACK_MODEL = { - "TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 + "TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 } _EXAMPLE_MODELS = { diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index ce906143297..a252c7f8e57 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -45,7 +45,7 @@ def is_transformers_impl_compatible( def resolve_transformers_fallback(model_config: ModelConfig, architectures: list[str]): for i, arch in enumerate(architectures): - if arch == "TransformersModel": + if arch == "TransformersForCausalLM": continue auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map", None) or dict() @@ -69,7 +69,7 @@ def resolve_transformers_fallback(model_config: ModelConfig, raise ValueError( f"The Transformers implementation of {arch} is not " "compatible with vLLM.") - architectures[i] = "TransformersModel" + architectures[i] = "TransformersForCausalLM" if model_config.model_impl == ModelImpl.AUTO: if not is_transformers_impl_compatible(arch, custom_model_module): raise ValueError( @@ -80,7 +80,7 @@ def resolve_transformers_fallback(model_config: ModelConfig, "%s has no vLLM implementation, falling back to Transformers " "implementation. 
Some features may not be supported and " "performance may not be optimal.", arch) - architectures[i] = "TransformersModel" + architectures[i] = "TransformersForCausalLM" return architectures diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 7c8e5067138..7797d9a2cc2 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -201,7 +201,7 @@ } _FALLBACK_MODEL = { - "TransformersModel": ("transformers", "TransformersModel"), + "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), } # yapf: enable @@ -425,7 +425,7 @@ def _normalize_archs( # make sure Transformers fallback are put at the last if len(normalized_arch) != len(architectures): - normalized_arch.append("TransformersModel") + normalized_arch.append("TransformersForCausalLM") return normalized_arch def inspect_model_cls( diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index fe6a9d7a4aa..8aa460da684 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -42,7 +42,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, maybe_prefix) logger = init_logger(__name__) @@ -109,12 +109,12 @@ def replace_linear_class( ) -class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): +class TransformersModel(nn.Module): embedding_padding_modules = ["lm_head"] embedding_modules = ["embed_tokens" ] # TODO transformers will have a util to get it - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() logger.info("Using Transformers backend.") @@ -132,9 +132,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.parallel_config = parallel_config self.quant_config = quant_config - self.vocab_size = model_config.get_vocab_size() - self.unpadded_vocab_size = model_config.get_vocab_size() - self.pp_group = get_pp_group() self.pp_size = self.pp_group.world_size self.pp_rank = self.pp_group.rank_in_group @@ -148,7 +145,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: torch_dtype=model_config.dtype, trust_remote_code=model_config.trust_remote_code, ) - prefix = self.model.base_model_prefix self.pipeline_parallel() self.tensor_parallel() @@ -166,32 +162,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # Attention layers self.attention_instances = self.create_attention_instances() - # Output embeddings - if not isinstance(getattr(self, "lm_head", None), PPMissingLayer): - self.unpadded_vocab_size = config.vocab_size - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - if config.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.get_input_embeddings()) - - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) - # Initialize buffers (e.g. 
rotary embedding inverse frequency) self.init_buffers(self.model) # Move remaining meta tensors to device (should happen last) self.meta_to_empty(self.model) - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) @@ -246,9 +222,6 @@ def pipeline_parallel(self): if not self.pp_group.is_last_rank: setattr(self.model, name, PPMissingLayer()) - if not self.pp_group.is_last_rank: - self.lm_head = PPMissingLayer() - def tensor_parallel(self): """ Apply the model's tensor parallelization plan. @@ -329,6 +302,9 @@ def meta_to_empty(self, module: nn.Module): for child in module.children(): self.meta_to_empty(child) + def get_input_embeddings(self) -> nn.Module: + return self.model.get_input_embeddings() + def forward( self, input_ids: Optional[torch.Tensor], @@ -359,21 +335,6 @@ def forward( return hidden_states - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) @@ -391,3 +352,76 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + +class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, + SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: PretrainedConfig = vllm_config.model_config.hf_config + quant_config: QuantizationConfig = vllm_config.quant_config + + self.config = config + + self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.get_input_embeddings()) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> 
set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights( + (name, loaded_weight) for name, loaded_weight in weights) From a5024f7571099b352b9c86a792bad10cf4fbca30 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 26 Mar 2025 00:55:11 +0800 Subject: [PATCH 2/5] fix weight loading Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/transformers.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 8aa460da684..7f1028c9f48 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -110,9 +110,6 @@ def replace_linear_class( class TransformersModel(nn.Module): - embedding_padding_modules = ["lm_head"] - embedding_modules = ["embed_tokens" - ] # TODO transformers will have a util to get it def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -356,6 +353,9 @@ def load_weights(self, weights: Iterable[tuple[str, class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): + embedding_padding_modules = ["lm_head"] + embedding_modules = ["embed_tokens" + ] # TODO transformers will have a util to get it def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -418,10 +418,14 @@ def sample(self, logits: torch.Tensor, def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + + def maybe_add_model_prefix(name: str): + return name if name.startswith("lm_head.") else "model." + name + loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights( - (name, loaded_weight) for name, loaded_weight in weights) + (maybe_add_model_prefix(name), weight) for name, weight in weights) From 2e666f196a6f8674ad6f5ada7fbcb821253c0b20 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 25 Mar 2025 18:02:15 +0100 Subject: [PATCH 3/5] Add supports torch compile Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 7f1028c9f48..1ecd8269842 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -24,6 +24,7 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, VllmConfig) from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -351,6 +352,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params +@support_torch_compile class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): embedding_padding_modules = ["lm_head"] From 0ebbc8bde824ad565758347f716bdb3866dee152 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 26 Mar 2025 15:08:56 +0800 Subject: [PATCH 4/5] fix lora weights mapping Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/transformers.py | 29 ++++++++++++++++------ 1 file changed, 21 insertions(+), 8 
deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 1ecd8269842..0fa0de9ebd7 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, maybe_prefix) logger = init_logger(__name__) @@ -137,6 +138,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Use meta device to delay allocating GPU tensors with torch.device("meta"): + # FIXME(Isotr0py): We need to refactor this part in the future to + # avoid registering an extra model layer, otherwise we will need a + # weights mapper to rename weights. self.model: PreTrainedModel = AutoModel.from_config( config, attn_implementation="vllm", @@ -356,7 +360,7 @@ def load_weights(self, weights: Iterable[tuple[str, class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): embedding_padding_modules = ["lm_head"] - embedding_modules = ["embed_tokens" + embedding_modules = ["model.embed_tokens" ] # TODO transformers will have a util to get it def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -392,6 +396,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + # FIXME(Isotr0py): Don't use any weights mapper for Transformers fallback, + # this makes thing complicated. We need to remove this mapper after refactor + # `TransformersModel` in the future. + @property + def hf_to_vllm_mapper(self): + prefix_mapper = { + name: "model." + name + for name, _ in self.model.model.named_children() + } + return WeightsMapper( + orig_to_new_substr={"model.": "model.model."}, + orig_to_new_prefix=prefix_mapper, + ) + def forward( self, input_ids: Optional[torch.Tensor], @@ -420,14 +438,9 @@ def sample(self, logits: torch.Tensor, def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - - def maybe_add_model_prefix(name: str): - return name if name.startswith("lm_head.") else "model." 
+ name - loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights( - (maybe_add_model_prefix(name), weight) for name, weight in weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) From 62459deec2b7aef71e1a32b616bd16b616c0370b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 26 Mar 2025 15:14:53 +0800 Subject: [PATCH 5/5] oops Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 0fa0de9ebd7..6ea14950658 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -360,7 +360,7 @@ def load_weights(self, weights: Iterable[tuple[str, class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): embedding_padding_modules = ["lm_head"] - embedding_modules = ["model.embed_tokens" + embedding_modules = ["embed_tokens" ] # TODO transformers will have a util to get it def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
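
Taken together, this series renames the Transformers fallback architecture from `TransformersModel` to `TransformersForCausalLM` and keeps the bare backbone as a nested module. A quick way to check which implementation vLLM resolved, following the updated `supported_models.md` hunk above, is sketched below; the model name is only an example, and the printed class assumes this series is applied:

```python
from vllm import LLM

# Force the Transformers fallback (equivalent to `--model-impl transformers`
# on the CLI); any decoder-only HF model name can stand in here.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", model_impl="transformers")

# Print the resolved model class. With this series applied it should be
# TransformersForCausalLM rather than the old TransformersModel.
llm.apply_model(lambda model: print(type(model)))
```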
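
The weight-loading changes in patches 2 and 4 follow from the new nesting: `TransformersForCausalLM.model` is the `TransformersModel` wrapper, whose `.model` is the HF `PreTrainedModel`, so checkpoint tensors such as `model.layers.0...` need one extra `model.` prefix, while `lm_head.*` stays at the top level and is skipped entirely when the embeddings are tied. A toy stand-in for that renaming rule — not vLLM's actual `WeightsMapper`, just its net effect on Llama-style checkpoint names — is sketched below:

```python
from typing import Optional


def remap_checkpoint_name(name: str, tie_word_embeddings: bool) -> Optional[str]:
    """Toy approximation of the renaming done by patch 4's hf_to_vllm_mapper."""
    if name.startswith("lm_head."):
        # lm_head lives on the outer TransformersForCausalLM; AutoWeightsLoader
        # skips it via skip_prefixes when the checkpoint ties it to embed_tokens.
        return None if tie_word_embeddings else name
    # The HF backbone now sits one module deeper, so "model.*" gains a prefix.
    return name.replace("model.", "model.model.", 1)


# Example: a Llama-style attention weight and a tied output embedding.
print(remap_checkpoint_name("model.layers.0.self_attn.q_proj.weight", False))
# -> model.model.layers.0.self_attn.q_proj.weight
print(remap_checkpoint_name("lm_head.weight", True))
# -> None (skipped)
```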