feat: Add support for speculators Eagle checkpoints #20436
base: main
The first file in the diff (speculative-decoding method auto-detection):

@@ -2767,9 +2767,15 @@

        # Automatically detect the method
        if self.method in ('eagle', 'eagle3'):
            pass
        elif hasattr(self.draft_model_config.hf_config,
                     "speculators_model_type") and \
                self.draft_model_config.hf_config.speculators_model_type in ("eagle", "eagle3"):
            self.method = self.draft_model_config.hf_config.speculators_model_type
        elif "eagle-" in self.draft_model_config.model.lower() or \
                "eagle3-" in self.draft_model_config.model.lower():
            self.method = "eagle"
        elif self.draft_model_config.hf_config.model_type == "eagle":
            self.method = "eagle"
        elif self.draft_model_config.hf_config.model_type == "medusa":
            self.method = "medusa"
        elif (self.draft_model_config.hf_config.model_type ==

[Reviewer comment, near the `model_type == "eagle"` branch] same as above
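For illustration, a minimal sketch of what the new branch keys on: a `speculators_model_type` field on the draft model's HF config. The config object below is a stand-in (`SimpleNamespace`), not vLLM's real config class, and its field values are assumptions.

```python
# Sketch only: mimic the new detection branch with a mocked draft config.
# Per the detection code above, speculators Eagle checkpoints are expected to
# expose `speculators_model_type` on their HF config; the other field here is
# illustrative.
from types import SimpleNamespace

hf_config = SimpleNamespace(model_type="llama", speculators_model_type="eagle3")

if hasattr(hf_config, "speculators_model_type") and \
        hf_config.speculators_model_type in ("eagle", "eagle3"):
    method = hf_config.speculators_model_type

print(method)  # -> "eagle3"
```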
A second file in the diff (the Llama-based Eagle draft-model implementation):

@@ -2,6 +2,7 @@

# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn
@@ -11,6 +12,7 @@

from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
@@ -22,6 +24,27 @@

logger = init_logger(__name__)

# Map speculators weight names to vLLM names
SPECULATORS_WEIGHT_MAP = {
    "fusion_fc.weight": "model.fc.weight",
    "fusion_fc.bias": "model.fc.bias",
    "embedding_layernorm.weight": "model.embedding_layernorm.weight",
    "pre_lm_head_layernorm.weight": "model.hidden_states_layernorm.weight",
}

[Reviewer comment] Should probably live in the speculators config file


def remap_speculators_weight_name(name: str) -> Optional[str]:
    """Remap speculators format weight names to vLLM names.

    Returns None for weights that should be skipped.
    """
    if name in SPECULATORS_WEIGHT_MAP:
        return SPECULATORS_WEIGHT_MAP[name]
    elif name.startswith("transformer."):
        # Replace "transformer." with "model.layers.0."
        return "model.layers.0." + name[len("transformer."):]
    return name

[Reviewer comment] same as above
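As a quick illustration of the helper above, here is the remapping for a few representative names; the map entries come from the PR itself, while the `transformer.*` example name is made up to show the prefix rule.

```python
# Assumes SPECULATORS_WEIGHT_MAP and remap_speculators_weight_name as defined
# above; the "transformer." example name is illustrative.
examples = [
    "fusion_fc.weight",                    # map entry   -> "model.fc.weight"
    "pre_lm_head_layernorm.weight",        # map entry   -> "model.hidden_states_layernorm.weight"
    "transformer.input_layernorm.weight",  # prefix rule -> "model.layers.0.input_layernorm.weight"
    "lm_head.weight",                      # no rule applies, returned unchanged
]
for name in examples:
    print(name, "->", remap_speculators_weight_name(name))
```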
class LlamaDecoderLayer(LlamaDecoderLayer):
@@ -70,7 +93,15 @@ def __init__(

        ])
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=getattr(self.config, "fusion_bias", False))

        # HASS variant support
        self.has_embedding_layernorms = getattr(self.config, "add_para_norm", False)
        if self.has_embedding_layernorms:
            self.embedding_layernorm = RMSNorm(self.config.hidden_size,
                                               eps=self.config.rms_norm_eps)
            self.hidden_states_layernorm = RMSNorm(self.config.hidden_size,
                                                   eps=self.config.rms_norm_eps)

    def forward(
        self,
@@ -79,6 +110,12 @@ def forward(

        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)

        # Apply HASS normalization if enabled
        if self.has_embedding_layernorms:
            input_embeds = self.embedding_layernorm(input_embeds)
            hidden_states = self.hidden_states_layernorm(hidden_states)

        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
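For intuition, a standalone sketch of the fusion path above: the draft-token embeddings and the target model's hidden states are optionally RMS-normalized (the HASS variant), concatenated, and projected back to the hidden size. This mirrors the code rather than reusing it; it assumes PyTorch 2.4+ for `nn.RMSNorm` (the model itself uses vLLM's `RMSNorm`), and the sizes are toy values.

```python
import torch
import torch.nn as nn

hidden_size, add_para_norm = 64, True  # toy values

fc = nn.Linear(hidden_size * 2, hidden_size, bias=False)
embedding_norm = nn.RMSNorm(hidden_size) if add_para_norm else nn.Identity()
hidden_norm = nn.RMSNorm(hidden_size) if add_para_norm else nn.Identity()

input_embeds = torch.randn(4, hidden_size)   # draft-token embeddings
hidden_states = torch.randn(4, hidden_size)  # hidden states from the target model

# HASS variant: normalize both streams before fusing them
fused = fc(torch.cat((embedding_norm(input_embeds),
                      hidden_norm(hidden_states)), dim=-1))
print(fused.shape)  # torch.Size([4, 64])
```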
@@ -104,6 +141,11 @@ def load_weights(self, weights: Iterable[tuple[str,

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            remapped_name = remap_speculators_weight_name(name)
            if remapped_name is None:
                continue
            name = remapped_name

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
@@ -119,6 +161,10 @@ def load_weights(self, weights: Iterable[tuple[str,

                    "embed_tokens." in name:
                continue

            # Skip weights that don't exist in the model
            if name not in params_dict:
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
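The new `name not in params_dict` guard simply drops checkpoint tensors that the instantiated draft model does not declare. A toy illustration with a made-up module and key names:

```python
import torch
import torch.nn as nn

# Made-up stand-in for the draft model and an incoming checkpoint.
model = nn.Linear(8, 8)
params_dict = dict(model.named_parameters())  # {"weight": ..., "bias": ...}

checkpoint = {
    "weight": torch.zeros(8, 8),
    "bias": torch.zeros(8),
    "embedding_layernorm.weight": torch.ones(8),  # not declared by this module
}

loaded = [name for name in checkpoint if name in params_dict]
print(loaded)  # ['weight', 'bias']; the unknown tensor is skipped
```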
@@ -159,7 +205,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

        model_weights = {}
        for name, loaded_weight in weights:
            remapped_name = remap_speculators_weight_name(name)
            if remapped_name is None:
                continue
            model_weights[remapped_name] = loaded_weight
        loader.load_weights(model_weights.items())

(This replaces the previous loop body, which unconditionally prefixed non-`lm_head` names with "model." and stored `model_weights[name] = loaded_weight`.)

[Reviewer comment, on the `remap_speculators_weight_name` call] should this be under a check where we first check if there's a speculators config present in self.config?
A third file in the diff (the Hugging Face config resolution path in vllm/transformers_utils):

@@ -40,9 +40,11 @@

                                             NemotronConfig, NVLM_D_Config,
                                             OvisConfig, RWConfig,
                                             SkyworkR1VChatConfig, SolarConfig,
                                             SpeculatorsEagleConfig,
                                             Telechat2Config, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.configs.speculators_eagle import is_speculators_eagle_config
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
@@ -350,6 +352,19 @@

            raise ValueError(error_message) from e

    if config_format == ConfigFormat.HF:
        # Speculators Eagle models use a different config format that requires
        # translation to vLLM's expected format. This must be handled before
        # the standard config loading to ensure proper model initialization.
        if is_speculators_eagle_config(model):
            config = SpeculatorsEagleConfig.from_pretrained(
                model,
                revision=revision,
                code_revision=code_revision,
                token=_get_hf_token(),
                **kwargs,
            )
            return config

        config_dict, _ = PretrainedConfig.get_config_dict(
            model,
            revision=revision,

[Reviewer comment, on the `SpeculatorsEagleConfig.from_pretrained` call] Are all existing supported models just going through the PretrainedConfig pathway?
[Reply] Yes!

[Reviewer comment] why do we need this
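To make the new control flow concrete, here is a hedged sketch of routing a checkpoint through it. The checkpoint path is hypothetical, the import locations follow the grouped import in the diff, and calling `from_pretrained` with only the path assumes its remaining arguments are optional (the PR passes `revision`, `code_revision`, and the HF token explicitly).

```python
# Sketch, not the vLLM implementation: try the speculators-Eagle translation
# first, then fall back to the standard Hugging Face config path.
from transformers import PretrainedConfig

from vllm.transformers_utils.configs import SpeculatorsEagleConfig
from vllm.transformers_utils.configs.speculators_eagle import (
    is_speculators_eagle_config)

model = "/path/to/speculators-eagle-checkpoint"  # hypothetical local path

if is_speculators_eagle_config(model):
    # Translated into the Eagle config layout vLLM expects
    config = SpeculatorsEagleConfig.from_pretrained(model)
else:
    config_dict, _ = PretrainedConfig.get_config_dict(model)
```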