diff --git a/vllm/config.py b/vllm/config.py
index 6c56ac1eec8..dafd924217a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2767,9 +2767,15 @@ def __post_init__(self):
             # Automatically detect the method
             if self.method in ('eagle', 'eagle3'):
                 pass
+            elif hasattr(self.draft_model_config.hf_config,
+                         "speculators_model_type") and \
+                    self.draft_model_config.hf_config.speculators_model_type in ("eagle", "eagle3"):
+                self.method = self.draft_model_config.hf_config.speculators_model_type
             elif "eagle-" in self.draft_model_config.model.lower() or \
                     "eagle3-" in self.draft_model_config.model.lower():
                 self.method = "eagle"
+            elif self.draft_model_config.hf_config.model_type == "eagle":
+                self.method = "eagle"
             elif self.draft_model_config.hf_config.model_type == "medusa":
                 self.method = "medusa"
             elif (self.draft_model_config.hf_config.model_type ==
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 7b73060e349..1cb30851cc0 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -41,6 +41,7 @@
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
                         GiB_bytes, get_ip, is_in_ray_actor)
+from vllm.transformers_utils.configs.speculators_eagle import is_speculators_eagle_config
 # yapf: enable
 
@@ -1416,6 +1417,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         if self.speculative_config is not None:
             # This is supported but experimental (handled below).
             speculative_method = self.speculative_config.get("method")
+            speculative_model = self.speculative_config.get("model")
+
             if speculative_method:
                 if speculative_method in ("ngram", "[ngram]"):
                     is_ngram_enabled = True
@@ -1424,9 +1427,15 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
                     is_eagle_enabled = True
             else:
-                speculative_model = self.speculative_config.get("model")
-                if speculative_model in ("ngram", "[ngram]"):
-                    is_ngram_enabled = True
+                # If method is not set, try to detect from model
+                if speculative_model:
+                    if speculative_model in ("ngram", "[ngram]"):
+                        is_ngram_enabled = True
+                    # Detect speculators format Eagle models which don't set the method
+                    # field explicitly but can be identified by their config structure
+                    elif is_speculators_eagle_config(speculative_model):
+                        is_eagle_enabled = True
+
             if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
                 # Other speculative decoding methods are not supported yet.
                 _raise_or_fallback(feature_name="Speculative Decoding",
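
The two hunks above establish a detection order: an explicit `speculators_model_type` on the draft model's config wins, then the `eagle-`/`eagle3-` naming convention, then a plain `model_type == "eagle"`. A minimal standalone sketch of that precedence, with `detect_method` as a hypothetical helper and `SimpleNamespace` standing in for the real `hf_config` object:

```python
# Sketch of the detection precedence; SimpleNamespace stands in for
# ModelConfig.hf_config, which carries far more fields in reality.
from types import SimpleNamespace


def detect_method(model_name: str, hf_config) -> str:
    """Mirror the elif chain added to SpeculativeConfig.__post_init__."""
    speculators_type = getattr(hf_config, "speculators_model_type", None)
    if speculators_type in ("eagle", "eagle3"):
        # Speculators-format configs are checked first and carry the
        # exact method name.
        return speculators_type
    if "eagle-" in model_name.lower() or "eagle3-" in model_name.lower():
        return "eagle"  # name-based fallback, unchanged from before
    if getattr(hf_config, "model_type", None) == "eagle":
        return "eagle"  # new: plain Eagle configs without the name hint
    raise ValueError("speculative method could not be auto-detected")


# A speculators checkpoint is recognized from its config alone, even when
# the repo name contains no "eagle" hint.
cfg = SimpleNamespace(model_type="llama", speculators_model_type="eagle3")
assert detect_method("org/some-draft-model", cfg) == "eagle3"
```
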
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index c7690604c1d..8f088f7dab2 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -11,6 +12,7 @@
 from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -22,6 +24,27 @@
 
 logger = init_logger(__name__)
 
+# Map speculators weight names to vLLM names
+SPECULATORS_WEIGHT_MAP = {
+    "fusion_fc.weight": "model.fc.weight",
+    "fusion_fc.bias": "model.fc.bias",
+    "embedding_layernorm.weight": "model.embedding_layernorm.weight",
+    "pre_lm_head_layernorm.weight": "model.hidden_states_layernorm.weight",
+}
+
+
+def remap_speculators_weight_name(name: str) -> Optional[str]:
+    """Remap speculators format weight names to vLLM names.
+
+    Returns None for weights that should be skipped.
+    """
+    if name in SPECULATORS_WEIGHT_MAP:
+        return SPECULATORS_WEIGHT_MAP[name]
+    elif name.startswith("transformer."):
+        # Replace "transformer." with "model.layers.0."
+        return "model.layers.0." + name[len("transformer."):]
+    return name
+
 
 class LlamaDecoderLayer(LlamaDecoderLayer):
@@ -70,7 +93,15 @@ def __init__(
         ])
         self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                   self.config.hidden_size,
-                                  bias=False)
+                                  bias=getattr(self.config, "fusion_bias", False))
+
+        # HASS variant support
+        self.has_embedding_layernorms = getattr(self.config, "add_para_norm", False)
+        if self.has_embedding_layernorms:
+            self.embedding_layernorm = RMSNorm(self.config.hidden_size,
+                                               eps=self.config.rms_norm_eps)
+            self.hidden_states_layernorm = RMSNorm(self.config.hidden_size,
+                                                   eps=self.config.rms_norm_eps)
 
     def forward(
         self,
@@ -79,6 +110,12 @@ def forward(
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
+
+        # Apply HASS normalization if enabled
+        if self.has_embedding_layernorms:
+            input_embeds = self.embedding_layernorm(input_embeds)
+            hidden_states = self.hidden_states_layernorm(hidden_states)
+
         hidden_states = self.fc(
             torch.cat((input_embeds, hidden_states), dim=-1))
         residual = None
@@ -104,6 +141,11 @@ def load_weights(self, weights: Iterable[tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            remapped_name = remap_speculators_weight_name(name)
+            if remapped_name is None:
+                continue
+            name = remapped_name
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -119,6 +161,10 @@ def load_weights(self, weights: Iterable[tuple[str,
                     "embed_tokens." in name:
                 continue
 
+            # Skip weights that don't exist in the model
+            if name not in params_dict:
+                continue
+
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
@@ -159,7 +205,8 @@ def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]):
         model_weights = {}
         for name, loaded_weight in weights:
-            if "lm_head" not in name:
-                name = "model." + name
-            model_weights[name] = loaded_weight
+            remapped_name = remap_speculators_weight_name(name)
+            if remapped_name is None:
+                continue
+            model_weights[remapped_name] = loaded_weight
         loader.load_weights(model_weights.items())
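
Since `remap_speculators_weight_name` is pure string manipulation, the translation can be checked in isolation. A self-contained restatement of the rules (the table and `remap` helper are duplicated here so the snippet runs without vLLM installed):

```python
# Restatement of the speculators -> vLLM weight-name remapping rules.
from typing import Optional

SPECULATORS_WEIGHT_MAP = {
    "fusion_fc.weight": "model.fc.weight",
    "fusion_fc.bias": "model.fc.bias",
    "embedding_layernorm.weight": "model.embedding_layernorm.weight",
    "pre_lm_head_layernorm.weight": "model.hidden_states_layernorm.weight",
}


def remap(name: str) -> Optional[str]:
    if name in SPECULATORS_WEIGHT_MAP:
        return SPECULATORS_WEIGHT_MAP[name]
    if name.startswith("transformer."):
        # The draft model's single decoder layer lives at model.layers.0.
        return "model.layers.0." + name[len("transformer."):]
    return name  # vLLM-native names pass through untouched


assert remap("fusion_fc.bias") == "model.fc.bias"
assert (remap("transformer.self_attn.q_proj.weight")
        == "model.layers.0.self_attn.q_proj.weight")
assert remap("lm_head.weight") == "lm_head.weight"
```
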
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 7fc9fe2ebb6..ec153c28b34 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -50,6 +50,7 @@ def __init__(
         )
 
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm_before_residual = getattr(config, "norm_before_residual", False)
 
     def forward(
         self,
@@ -59,9 +60,14 @@ def forward(
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        residual = hidden_states
         embeds = self.input_layernorm(embeds)
-        hidden_states = self.hidden_norm(hidden_states)
+
+        if self.norm_before_residual:
+            hidden_states = self.hidden_norm(hidden_states)
+            residual = hidden_states
+        else:
+            residual = hidden_states
+            hidden_states = self.hidden_norm(hidden_states)
 
         hidden_states = torch.cat([embeds, hidden_states], dim=-1)
 
         # Self Attention
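
The only behavioral change here is which tensor the residual connection keeps. A toy trace, with `hidden_norm` stubbed out as a string-tagging function (`trace` is a hypothetical illustration, not the real RMSNorm path):

```python
# Trace the two orderings controlled by norm_before_residual.
def trace(norm_before_residual: bool) -> tuple[str, str]:
    hidden_states = "h"
    hidden_norm = lambda x: f"norm({x})"
    if norm_before_residual:
        hidden_states = hidden_norm(hidden_states)
        residual = hidden_states  # residual keeps the *normalized* states
    else:
        residual = hidden_states  # residual keeps the *raw* states
        hidden_states = hidden_norm(hidden_states)
    return hidden_states, residual


assert trace(True) == ("norm(h)", "norm(h)")
assert trace(False) == ("norm(h)", "h")
```
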
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index cf3f519b027..222b4bba711 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -40,9 +40,11 @@
                                              NemotronConfig, NVLM_D_Config,
                                              OvisConfig, RWConfig,
                                              SkyworkR1VChatConfig, SolarConfig,
+                                             SpeculatorsEagleConfig,
                                              Telechat2Config, UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.configs.mistral import adapt_config_dict
+from vllm.transformers_utils.configs.speculators_eagle import is_speculators_eagle_config
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import resolve_obj_by_qualname
 
@@ -350,6 +352,19 @@ def get_config(
             raise ValueError(error_message) from e
 
     if config_format == ConfigFormat.HF:
+        # Speculators Eagle models use a different config format that requires
+        # translation to vLLM's expected format. This must be handled before
+        # the standard config loading to ensure proper model initialization.
+        if is_speculators_eagle_config(model):
+            config = SpeculatorsEagleConfig.from_pretrained(
+                model,
+                revision=revision,
+                code_revision=code_revision,
+                token=_get_hf_token(),
+                **kwargs,
+            )
+            return config
+
         config_dict, _ = PretrainedConfig.get_config_dict(
             model,
             revision=revision,
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 734f1e09d0f..0256969a83c 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -7,6 +7,7 @@
 from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
 from vllm.transformers_utils.configs.eagle import EAGLEConfig
 from vllm.transformers_utils.configs.exaone import ExaoneConfig
+from vllm.transformers_utils.configs.speculators_eagle import SpeculatorsEagleConfig
 # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
@@ -40,6 +41,7 @@
     "MedusaConfig",
     "EAGLEConfig",
     "ExaoneConfig",
+    "SpeculatorsEagleConfig",
     "MiniMaxText01Config",
     "MiniMaxVL01Config",
     "MllamaConfig",
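
`is_speculators_eagle_config` (defined in the new file below) only inspects the checkpoint's JSON, so the `get_config` interception can be exercised against a synthetic checkpoint directory. A hedged end-to-end check; the field values are illustrative, not a real published config, and the snippet assumes `vllm` and `transformers` are importable:

```python
# Write a minimal speculators-style config.json to a temp dir and confirm
# the detection helper recognizes it.
import json
import pathlib
import tempfile

from vllm.transformers_utils.configs.speculators_eagle import (
    is_speculators_eagle_config)

with tempfile.TemporaryDirectory() as tmp:
    cfg = {
        "speculators_model_type": "eagle3",
        "transformer_layer_config": {"hidden_size": 4096},
        "speculators_config": {
            "proposal_methods": [{"speculative_tokens": 5}],
        },
    }
    pathlib.Path(tmp, "config.json").write_text(json.dumps(cfg))
    assert is_speculators_eagle_config(tmp)
    # Load failures are swallowed and reported as "not speculators format".
    assert not is_speculators_eagle_config("/nonexistent/path")
```
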
diff --git a/vllm/transformers_utils/configs/speculators_eagle.py b/vllm/transformers_utils/configs/speculators_eagle.py
new file mode 100644
index 00000000000..3994bcaea86
--- /dev/null
+++ b/vllm/transformers_utils/configs/speculators_eagle.py
@@ -0,0 +1,284 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+from typing import Any, Union
+
+from transformers import PretrainedConfig
+
+from vllm.transformers_utils.configs.eagle import EAGLEConfig
+
+# Constants for speculators format
+SUPPORTED_SPECULATORS_TYPES = frozenset({"eagle", "eagle3"})
+DEFAULT_HIDDEN_SIZE = 4096
+DEFAULT_NUM_LOOKAHEAD_TOKENS = 5
+
+
+class SpeculatorsEagleConfig(EAGLEConfig):
+    """Configuration adapter for speculators Eagle models.
+
+    Translates between speculators library format and vLLM's Eagle format.
+    Supports both Eagle-1 and Eagle-3 variants.
+    """
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        **kwargs,
+    ) -> "SpeculatorsEagleConfig":
+        """Load speculators Eagle config and convert to vLLM format."""
+        config_dict, _ = cls.get_config_dict(
+            pretrained_model_name_or_path, **kwargs
+        )
+
+        speculators_type = config_dict.get("speculators_model_type")
+        if speculators_type not in SUPPORTED_SPECULATORS_TYPES:
+            return super().from_pretrained(
+                pretrained_model_name_or_path, **kwargs
+            )
+
+        cls._validate_speculators_config(config_dict)
+        vllm_config = cls._convert_speculators_to_vllm(config_dict)
+
+        return cls(**vllm_config)
+
+    @classmethod
+    def _validate_speculators_config(cls, config: dict[str, Any]) -> None:
+        """Validate required speculators format fields."""
+        # Check required top-level fields
+        if "speculators_model_type" not in config:
+            raise ValueError(
+                "Missing 'speculators_model_type' in config. "
+                f"Expected one of: {sorted(SUPPORTED_SPECULATORS_TYPES)}. "
+                "Please ensure you're loading a speculators-format Eagle model."
+            )
+
+        model_type = config["speculators_model_type"]
+        if model_type not in SUPPORTED_SPECULATORS_TYPES:
+            raise ValueError(
+                f"Unsupported speculators_model_type: '{model_type}'. "
+                f"Supported types: {sorted(SUPPORTED_SPECULATORS_TYPES)}"
+            )
+
+        # Check transformer config
+        if "transformer_layer_config" not in config:
+            raise ValueError(
+                "Missing 'transformer_layer_config' in speculators config. "
+                "This field should contain the transformer architecture configuration."
+            )
+
+        # Check proposal methods
+        speculators_cfg = config.get("speculators_config", {})
+        if not isinstance(speculators_cfg, dict):
+            raise ValueError(
+                "'speculators_config' must be a dictionary. "
+                f"Got: {type(speculators_cfg).__name__}"
+            )
+
+        proposal_methods = speculators_cfg.get("proposal_methods", [])
+        if not proposal_methods:
+            raise ValueError(
+                "No proposal methods found in speculators_config. "
+                "Expected: {'speculators_config': {'proposal_methods': "
+                "[{'speculative_tokens': N}]}}. "
+                "Check that your model config follows the speculators format."
+            )
+
+    @classmethod
+    def _convert_speculators_to_vllm(cls, speculators_config: dict[str, Any]) -> dict[str, Any]:
+        """
+        Convert speculators Eagle config format to vLLM format.
+
+        This method handles the translation of field names and structure
+        between speculators and vLLM formats. It supports both Eagle-1
+        and Eagle-3 variants based on speculators_model_type.
+
+        Args:
+            speculators_config: Dictionary containing speculators format config
+
+        Returns:
+            Dictionary with vLLM-compatible Eagle configuration
+        """
+        speculators_model_type = speculators_config["speculators_model_type"]
+        transformer_config = speculators_config["transformer_layer_config"]
+
+        # Extract num_lookahead_tokens from proposal_methods
+        num_lookahead_tokens = cls._extract_num_lookahead_tokens(speculators_config)
+
+        # Build base vLLM config
+        vllm_config = {
+            "model": transformer_config,
+            "method": speculators_model_type,  # Use speculators_model_type as method
+            "num_lookahead_tokens": num_lookahead_tokens,
+        }
+
+        # Apply version-specific conversions
+        if speculators_model_type == "eagle":
+            cls._apply_eagle_v1_config(speculators_config, transformer_config, vllm_config)
+        elif speculators_model_type == "eagle3":
+            cls._apply_eagle_v3_config(speculators_config, transformer_config, vllm_config)
+
+        # Ensure transformer config has required fields
+        cls._ensure_transformer_architectures(speculators_config, transformer_config)
+
+        # Preserve additional fields not handled by specific conversions
+        cls._preserve_additional_fields(speculators_config, vllm_config)
+
+        return vllm_config
+
+    @classmethod
+    def _extract_num_lookahead_tokens(cls, config: dict[str, Any]) -> int:
+        """
+        Extract number of lookahead tokens from proposal methods.
+
+        Args:
+            config: Speculators config dictionary
+
+        Returns:
+            Number of speculative tokens
+
+        Note:
+            Currently only supports the first proposal method.
+            Future versions may support multiple proposal methods.
+        """
+        speculators_cfg = config["speculators_config"]
+        proposal_methods = speculators_cfg["proposal_methods"]
+
+        # Currently we only support one proposal method
+        first_method = proposal_methods[0]
+        num_lookahead_tokens = first_method.get("speculative_tokens")
+
+        if num_lookahead_tokens is None:
+            raise ValueError(
+                "Missing 'speculative_tokens' in proposal method. "
+                f"Got: {first_method}"
+            )
+
+        return num_lookahead_tokens
+
+    @classmethod
+    def _apply_eagle_v1_config(
+        cls,
+        speculators_config: dict[str, Any],
+        transformer_config: dict[str, Any],
+        vllm_config: dict[str, Any]
+    ) -> None:
+        """
+        Apply Eagle-1 specific configuration transformations.
+
+        Eagle-1 specific fields:
+        - fusion_bias → fusion_bias on the transformer config (bias for the
+          fusion FC layer in llama_eagle.py)
+        - layernorms → add_para_norm (for HASS variant)
+        - Uses truncated_vocab_size
+        """
+        # Handle HASS variant with additional layernorms
+        if speculators_config.get("layernorms", False):
+            transformer_config["add_para_norm"] = True
+            # When using extra layernorms, ensure skip flags are set correctly
+            # to maintain the expected architecture behavior
+            transformer_config["skip_prenorm"] = False
+            transformer_config["skip_output_norm"] = False
+
+        if speculators_config.get("fusion_bias", False):
+            # If fusion_bias is set, add it to the transformer config
+            transformer_config["fusion_bias"] = True
+
+        # Map Eagle-1 specific fields
+        vocab_size = transformer_config.get("vocab_size")
+        vllm_config["truncated_vocab_size"] = vocab_size
+        vllm_config["architectures"] = ["EAGLEModel"]
+
+    @classmethod
+    def _apply_eagle_v3_config(
+        cls,
+        speculators_config: dict[str, Any],
+        transformer_config: dict[str, Any],
+        vllm_config: dict[str, Any]
+    ) -> None:
+        """
+        Apply Eagle-3 specific configuration transformations.
+
+        Eagle-3 specific fields:
+        - draft_vocab_size: Size of the draft model's vocabulary
+        - target_hidden_size: Hidden size of the target model
+        - norm_before_residual: Whether to apply norm before residual connection
+        """
+        # Copy Eagle-3 specific fields
+        if speculators_config.get("draft_vocab_size") is not None:
+            draft_vocab_size = speculators_config["draft_vocab_size"]
+            vllm_config["draft_vocab_size"] = draft_vocab_size
+
+        # Handle target_hidden_size
+        if speculators_config.get("target_hidden_size") is not None:
+            target_hidden_size = speculators_config["target_hidden_size"]
+            vllm_config["target_hidden_size"] = target_hidden_size
+        else:
+            # Default to the draft model's hidden size
+            # In practice, this should match the target model's hidden size
+            vllm_config["target_hidden_size"] = transformer_config.get(
+                "hidden_size", DEFAULT_HIDDEN_SIZE
+            )
+
+        if "norm_before_residual" in speculators_config:
+            # Add to transformer config which becomes the model config
+            transformer_config["norm_before_residual"] = speculators_config["norm_before_residual"]
+
+        # Eagle-3 uses a different architecture
+        vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
+
+    @classmethod
+    def _ensure_transformer_architectures(
+        cls,
+        speculators_config: dict[str, Any],
+        transformer_config: dict[str, Any]
+    ) -> None:
+        """Ensure transformer config has required architecture field."""
+        if "architectures" not in transformer_config:
+            default_arch = "LlamaDecoderLayer"
+            arch = speculators_config.get(
+                "transformer_layer_architecture", default_arch
+            )
+            if arch == "LlamaDecoderLayer":
+                transformer_config["architectures"] = ["LlamaForCausalLM"]
+            else:
+                transformer_config["architectures"] = [arch]
+
+    @classmethod
+    def _preserve_additional_fields(
+        cls,
+        speculators_config: dict[str, Any],
+        vllm_config: dict[str, Any]
+    ) -> None:
+        """Preserve additional fields for forward compatibility."""
+        handled_fields = {
+            "speculators_model_type",
+            "transformer_layer_config",
+            "speculators_config",
+            "layernorms",
+            "fusion_bias",
+            "architectures",
+            "draft_vocab_size",
+            "target_hidden_size",
+            "norm_before_residual",
+        }
+
+        for key, value in speculators_config.items():
+            if key not in handled_fields:
+                vllm_config[key] = value
+
+
+def is_speculators_eagle_config(config_path: Union[str, os.PathLike]) -> bool:
+    """Check if a config file is in speculators Eagle format."""
+    try:
+        config_dict, _ = PretrainedConfig.get_config_dict(config_path)
+
+        if "speculators_model_type" not in config_dict:
+            return False
+
+        model_type = config_dict.get("speculators_model_type")
+        return model_type in SUPPORTED_SPECULATORS_TYPES
+    except Exception:
+        return False
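
To see what the adapter produces end to end, the conversion helper can be fed a minimal Eagle-3 dict directly (reaching into the private `_convert_speculators_to_vllm` classmethod purely for illustration; the input is a trimmed, hypothetical config, not a real checkpoint's):

```python
# Check the shape of the converted config for a minimal Eagle-3 input.
from vllm.transformers_utils.configs.speculators_eagle import (
    SpeculatorsEagleConfig)

speculators_cfg = {
    "speculators_model_type": "eagle3",
    "transformer_layer_config": {"hidden_size": 4096, "vocab_size": 128256},
    "speculators_config": {
        "proposal_methods": [{"speculative_tokens": 5}],
    },
    "norm_before_residual": True,
}

vllm_cfg = SpeculatorsEagleConfig._convert_speculators_to_vllm(speculators_cfg)
assert vllm_cfg["method"] == "eagle3"
assert vllm_cfg["num_lookahead_tokens"] == 5
assert vllm_cfg["architectures"] == ["Eagle3LlamaForCausalLM"]
# norm_before_residual travels on the transformer config, which becomes the
# draft model config consumed by llama_eagle3.py above.
assert vllm_cfg["model"]["norm_before_residual"] is True
assert vllm_cfg["model"]["architectures"] == ["LlamaForCausalLM"]
```

In serving terms, this path should be hit by passing a speculators-format checkpoint as the speculative model, e.g. something like `--speculative-config '{"model": "<speculators-eagle-checkpoint>", "num_speculative_tokens": 5}'`, with `method` then filled in by the detection added in `vllm/config.py`.
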