feat: Add support for speculators Eagle checkpoints #20436

Draft pull request: wants to merge 23 commits into base: main

Commits (23)
bf388d8  feat: Add support for speculators Eagle checkpoints (rahul-tuli, Jul 3, 2025)
b2888ab  cleanup: Remove unused imports from speculators_eagle.py (rahul-tuli, Jul 3, 2025)
43c098a  fix: Support HuggingFace model IDs in speculators Eagle config (rahul-tuli, Jul 4, 2025)
dea7fdf  fix: Add method field to speculators Eagle config for V1 compatibility (rahul-tuli, Jul 9, 2025)
31d9af6  fix: Use speculators_model_type for method field instead of hardcoding (rahul-tuli, Jul 9, 2025)
b7d286f  fix: Add eagle model_type detection for automatic method setting (rahul-tuli, Jul 9, 2025)
73a548c  fix: Add speculators Eagle detection as special case in V1 check (rahul-tuli, Jul 9, 2025)
c6b71b2  fix: Add speculators weight remapping to llama_eagle model (rahul-tuli, Jul 9, 2025)
4667abd  fix: Complete speculators Eagle support fixes (rahul-tuli, Jul 9, 2025)
b0f61c2  docs: Add comprehensive V1 engine Eagle support documentation (rahul-tuli, Jul 10, 2025)
e9bda92  feat: Add generic Eagle-3 speculators support (rahul-tuli, Jul 15, 2025)
eef118e  fix: Add HASS Eagle layernorm support for V1 engine (rahul-tuli, Jul 15, 2025)
9262a34  refactor: Clean up speculators Eagle config implementation (rahul-tuli, Jul 15, 2025)
e4e87fb  chore: Remove V1 engine Eagle support documentation (rahul-tuli, Jul 15, 2025)
7d4e0f2  refactor: Focus speculators Eagle support on V1 engine only (rahul-tuli, Jul 15, 2025)
95f6069  feat: Comprehensive code cleanup for speculators Eagle support (rahul-tuli, Jul 15, 2025)
ddd6123  refactor: Consolidate Eagle speculators weight mapping (rahul-tuli, Jul 15, 2025)
00da923  feat: Add support for Eagle models in speculators format (rahul-tuli, Jul 15, 2025)
d63ef14  remove changes to gitignore (rahul-tuli, Jul 15, 2025)
b905811  add back .gitignore (rahul-tuli, Jul 15, 2025)
7df8c9d  Add norm_before_residual support for llama_eagle3.py (rahul-tuli, Jul 15, 2025)
1408fb8  Fix bug (rahul-tuli, Jul 15, 2025)
46e398a  simplify logic (rahul-tuli, Jul 16, 2025)
6 changes: 6 additions & 0 deletions vllm/config.py
@@ -2767,9 +2767,15 @@
# Automatically detect the method
if self.method in ('eagle', 'eagle3'):
pass
elif hasattr(self.draft_model_config.hf_config,
"speculators_model_type") and \
self.draft_model_config.hf_config.speculators_model_type in ("eagle", "eagle3"):

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/config.py:2772:81: Line too long (104 > 80)
self.method = self.draft_model_config.hf_config.speculators_model_type

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/config.py:2773:81: Line too long (90 > 80)
elif "eagle-" in self.draft_model_config.model.lower() or \
"eagle3-" in self.draft_model_config.model.lower():

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/config.py:2775:81: Line too long (134 > 80)
self.method = "eagle"
elif self.draft_model_config.hf_config.model_type == "eagle":
self.method = "eagle"
elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa"
elif (self.draft_model_config.hf_config.model_type ==
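Note (illustrative): the detection order introduced here reads as explicit `method` first, then the `speculators_model_type` field, then the model-name heuristic, then `model_type`. A minimal standalone sketch, using a hypothetical stand-in for the draft model's HF config rather than vLLM's real classes:

```python
# Minimal sketch of the auto-detection precedence above; hf_config is a
# hypothetical stand-in object, not vLLM's config classes.
from types import SimpleNamespace

def detect_method(method, model_name, hf_config):
    if method in ("eagle", "eagle3"):
        return method
    if getattr(hf_config, "speculators_model_type", None) in ("eagle", "eagle3"):
        return hf_config.speculators_model_type
    if "eagle-" in model_name.lower() or "eagle3-" in model_name.lower():
        return "eagle"
    if getattr(hf_config, "model_type", None) == "eagle":
        return "eagle"
    if getattr(hf_config, "model_type", None) == "medusa":
        return "medusa"
    return None  # further branches (mtp, etc.) elided

# A speculators-format checkpoint advertises its type in its config,
# so no explicit method needs to be configured by the user:
cfg = SimpleNamespace(speculators_model_type="eagle3", model_type="llama")
assert detect_method(None, "org/speculators-eagle3-draft", cfg) == "eagle3"
```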
15 changes: 12 additions & 3 deletions vllm/engine/arg_utils.py
@@ -41,6 +41,7 @@
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, get_ip, is_in_ray_actor)
from vllm.transformers_utils.configs.speculators_eagle import is_speculators_eagle_config

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/engine/arg_utils.py:44:81: Line too long (89 > 80)

# yapf: enable

@@ -1416,6 +1417,8 @@
if self.speculative_config is not None:
# This is supported but experimental (handled below).
speculative_method = self.speculative_config.get("method")
speculative_model = self.speculative_config.get("model")

if speculative_method:
if speculative_method in ("ngram", "[ngram]"):
is_ngram_enabled = True
@@ -1424,9 +1427,15 @@
elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
is_eagle_enabled = True
else:
speculative_model = self.speculative_config.get("model")
if speculative_model in ("ngram", "[ngram]"):
is_ngram_enabled = True
# If method is not set, try to detect from model
if speculative_model:
if speculative_model in ("ngram", "[ngram]"):
is_ngram_enabled = True
# Detect speculators format Eagle models which don't set the method

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/engine/arg_utils.py:1434:81: Line too long (87 > 80)
# field explicitly but can be identified by their config structure

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/engine/arg_utils.py:1435:81: Line too long (86 > 80)
elif is_speculators_eagle_config(speculative_model):
is_eagle_enabled = True

if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
# Other speculative decoding methods are not supported yet.
_raise_or_fallback(feature_name="Speculative Decoding",
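For context, a simplified sketch of how the V1 support check treats a speculative config once `method` may be absent. The `is_speculators_eagle_config` callable is passed in so the sketch stays self-contained; the real helper inspects the checkpoint's config, and non-Eagle/ngram branches are elided:

```python
# Simplified, self-contained sketch of the branch above: when "method" is not
# set, the "model" entry alone can enable Eagle if it points at a
# speculators-format checkpoint.
def speculative_flags(speculative_config, is_speculators_eagle_config):
    is_ngram_enabled = False
    is_eagle_enabled = False
    if speculative_config is not None:
        method = speculative_config.get("method")
        model = speculative_config.get("model")
        if method:
            if method in ("ngram", "[ngram]"):
                is_ngram_enabled = True
            elif method in ("eagle", "eagle3", "deepseek_mtp"):
                is_eagle_enabled = True
        elif model:
            if model in ("ngram", "[ngram]"):
                is_ngram_enabled = True
            elif is_speculators_eagle_config(model):
                is_eagle_enabled = True
    return is_ngram_enabled, is_eagle_enabled

# With no explicit method, a speculators Eagle checkpoint still enables Eagle:
assert speculative_flags({"model": "org/speculators-eagle-draft"},
                         lambda m: True) == (False, True)
```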
55 changes: 51 additions & 4 deletions vllm/model_executor/models/llama_eagle.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn
@@ -11,6 +12,7 @@
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -22,6 +24,27 @@

logger = init_logger(__name__)

# Map speculators weight names to vLLM names
SPECULATORS_WEIGHT_MAP = {
"fusion_fc.weight": "model.fc.weight",
"fusion_fc.bias": "model.fc.bias",
"embedding_layernorm.weight": "model.embedding_layernorm.weight",
"pre_lm_head_layernorm.weight": "model.hidden_states_layernorm.weight",
}


def remap_speculators_weight_name(name: str) -> Optional[str]:
"""Remap speculators format weight names to vLLM names.
Returns None for weights that should be skipped.
"""
if name in SPECULATORS_WEIGHT_MAP:
return SPECULATORS_WEIGHT_MAP[name]
elif name.startswith("transformer."):
# Replace "transformer." with "model.layers.0."
return "model.layers.0." + name[len("transformer."):]
return name


class LlamaDecoderLayer(LlamaDecoderLayer):

@@ -70,7 +93,15 @@ def __init__(
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
self.config.hidden_size,
bias=False)
bias=getattr(self.config, "fusion_bias", False))

# HASS variant support
self.has_embedding_layernorms = getattr(self.config, "add_para_norm", False)
if self.has_embedding_layernorms:
self.embedding_layernorm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
self.hidden_states_layernorm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)

def forward(
self,
@@ -79,6 +110,12 @@ def forward(
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
input_embeds = self.embed_tokens(input_ids)

# Apply HASS normalization if enabled
if self.has_embedding_layernorms:
input_embeds = self.embedding_layernorm(input_embeds)
hidden_states = self.hidden_states_layernorm(hidden_states)

hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
residual = None
@@ -104,6 +141,11 @@ def load_weights(self, weights: Iterable[tuple[str,
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
remapped_name = remap_speculators_weight_name(name)
if remapped_name is None:
continue
name = remapped_name

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
@@ -119,6 +161,10 @@
"embed_tokens." in name:
continue

# Skip weights that don't exist in the model
if name not in params_dict:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
@@ -159,7 +205,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

model_weights = {}
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
remapped_name = remap_speculators_weight_name(name)
Contributor (review comment): should this be under a check where we first check if there's a speculators config present in self.config?

if remapped_name is None:
continue
model_weights[remapped_name] = loaded_weight
loader.load_weights(model_weights.items())
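As a quick, hedged illustration of the remapping helper added above: the first names come straight from `SPECULATORS_WEIGHT_MAP` and the `transformer.` prefix rule, while the attention weight name is only an example:

```python
# Illustrative calls to remap_speculators_weight_name as defined above.
assert remap_speculators_weight_name("fusion_fc.weight") == "model.fc.weight"
assert (remap_speculators_weight_name("pre_lm_head_layernorm.weight")
        == "model.hidden_states_layernorm.weight")
# Any "transformer."-prefixed weight is folded into layer 0 of the draft model:
assert (remap_speculators_weight_name("transformer.self_attn.q_proj.weight")
        == "model.layers.0.self_attn.q_proj.weight")
# Everything else passes through unchanged:
assert remap_speculators_weight_name("lm_head.weight") == "lm_head.weight"
```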
10 changes: 8 additions & 2 deletions vllm/model_executor/models/llama_eagle3.py
@@ -50,6 +50,7 @@ def __init__(
)

self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm_before_residual = getattr(config, "norm_before_residual", False)

def forward(
self,
@@ -59,9 +60,14 @@
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:

residual = hidden_states
embeds = self.input_layernorm(embeds)
hidden_states = self.hidden_norm(hidden_states)

if self.norm_before_residual:
hidden_states = self.hidden_norm(hidden_states)
residual = hidden_states
else:
residual = hidden_states
hidden_states = self.hidden_norm(hidden_states)

hidden_states = torch.cat([embeds, hidden_states], dim=-1)
# Self Attention
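The only behavioural difference introduced by `norm_before_residual` is whether the residual captures the hidden states before or after `hidden_norm`. A standalone sketch with a plain RMSNorm stand-in (not vLLM's layer):

```python
# Standalone sketch of the norm_before_residual option above; rms_norm is a
# simple stand-in for vLLM's RMSNorm layer.
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def normalize_hidden(hidden_states: torch.Tensor, norm_before_residual: bool):
    if norm_before_residual:
        # Residual is taken from the *normalized* hidden states.
        hidden_states = rms_norm(hidden_states)
        residual = hidden_states
    else:
        # Previous behaviour: residual is the raw hidden states.
        residual = hidden_states
        hidden_states = rms_norm(hidden_states)
    return hidden_states, residual

h = torch.randn(2, 8)
_, res_default = normalize_hidden(h, norm_before_residual=False)
_, res_new = normalize_hidden(h, norm_before_residual=True)
assert torch.equal(res_default, h) and not torch.equal(res_new, h)
```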
15 changes: 15 additions & 0 deletions vllm/transformers_utils/config.py
@@ -40,9 +40,11 @@
NemotronConfig, NVLM_D_Config,
OvisConfig, RWConfig,
SkyworkR1VChatConfig, SolarConfig,
SpeculatorsEagleConfig,
Telechat2Config, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.configs.speculators_eagle import is_speculators_eagle_config

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/transformers_utils/config.py:47:81: Line too long (89 > 80)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname

@@ -350,6 +352,19 @@
raise ValueError(error_message) from e

if config_format == ConfigFormat.HF:
# Speculators Eagle models use a different config format that requires
# translation to vLLM's expected format. This must be handled before
# the standard config loading to ensure proper model initialization.
if is_speculators_eagle_config(model):
config = SpeculatorsEagleConfig.from_pretrained(
Contributor (review comment): Are all existing supported models just going through the PretrainedConfig pathway?
Contributor Author: Yes!

model,
revision=revision,
code_revision=code_revision,
token=_get_hf_token(),
**kwargs,
)
return config

config_dict, _ = PretrainedConfig.get_config_dict(
model,
revision=revision,
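Conceptually, config resolution for a speculators checkpoint then looks roughly like the sketch below; the model id is hypothetical and the revision/token plumbing shown in the diff is omitted:

```python
# Rough sketch of the dispatch above: speculators-format Eagle checkpoints are
# translated via SpeculatorsEagleConfig, everything else keeps the standard
# PretrainedConfig pathway. The model id is hypothetical.
from transformers import PretrainedConfig

from vllm.transformers_utils.configs.speculators_eagle import (
    SpeculatorsEagleConfig, is_speculators_eagle_config)

model = "org/eagle3-draft-speculators-format"  # hypothetical checkpoint id

if is_speculators_eagle_config(model):
    # Translate the speculators config into vLLM's expected Eagle format.
    config = SpeculatorsEagleConfig.from_pretrained(model)
else:
    # Standard HF config loading path.
    config_dict, _ = PretrainedConfig.get_config_dict(model)
```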
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
@@ -7,6 +7,7 @@
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.exaone import ExaoneConfig
from vllm.transformers_utils.configs.speculators_eagle import SpeculatorsEagleConfig

Check failure (GitHub Actions / pre-commit): Ruff (E501) vllm/transformers_utils/configs/__init__.py:10:81: Line too long (84 > 80)
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
@@ -40,6 +41,7 @@
"MedusaConfig",
"EAGLEConfig",
"ExaoneConfig",
"SpeculatorsEagleConfig",
"MiniMaxText01Config",
"MiniMaxVL01Config",
"MllamaConfig",