Skip to content

Commit 175d2a8

Browse files
rahul-tuliclaude
andcommitted
feat: Add support for speculators Eagle checkpoints
- Add SpeculatorsEagleConfig to handle speculators config format - Update config loader to detect speculators Eagle models - Add weight name remapping in Eagle model load_weights - Support both standard Eagle and HASS (with layernorms) variants This enables vLLM to load Eagle models converted using the speculators library's checkpoint converter, mapping config fields and weight names to vLLM's expected format. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 9ff2af6 commit 175d2a8

File tree

4 files changed

+154
-0
lines changed

4 files changed

+154
-0
lines changed

vllm/model_executor/models/eagle.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,24 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
204204
# https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
205205
# Also, here's an example script for converting trained EAGLE
206206
# checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
207+
208+
# Support for speculators format weights
209+
speculators_name_map = {
210+
"fusion_fc.weight": "fc.weight",
211+
"fusion_fc.bias": "fc.bias",
212+
"embedding_layernorm.weight": "enorm.weight",
213+
"pre_lm_head_layernorm.weight": "hnorm.weight",
214+
}
215+
207216
model_weights = {}
208217
for name, loaded_weight in weights:
218+
# Handle speculators format weight names
219+
if name in speculators_name_map:
220+
name = speculators_name_map[name]
221+
elif name.startswith("transformer."):
222+
# transformer.* -> model.model.layers.0.*
223+
suffix = name[len("transformer."):]
224+
name = f"model.model.layers.0.{suffix}"
209225
if name == "token_map":
210226
if self.config.truncated_vocab_size < self.config.vocab_size:
211227
self.token_map = nn.Parameter(loaded_weight,

vllm/transformers_utils/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@
4040
NemotronConfig, NVLM_D_Config,
4141
OvisConfig, RWConfig,
4242
SkyworkR1VChatConfig, SolarConfig,
43+
SpeculatorsEagleConfig,
4344
Telechat2Config, UltravoxConfig)
4445
# yapf: enable
4546
from vllm.transformers_utils.configs.mistral import adapt_config_dict
47+
from vllm.transformers_utils.configs.speculators_eagle import is_speculators_eagle_config
4648
from vllm.transformers_utils.utils import check_gguf_file
4749
from vllm.utils import resolve_obj_by_qualname
4850

@@ -347,6 +349,17 @@ def get_config(
347349
raise ValueError(error_message) from e
348350

349351
if config_format == ConfigFormat.HF:
352+
# Check if this is a speculators Eagle model
353+
if is_speculators_eagle_config(model):
354+
config = SpeculatorsEagleConfig.from_pretrained(
355+
model,
356+
revision=revision,
357+
code_revision=code_revision,
358+
token=_get_hf_token(),
359+
**kwargs,
360+
)
361+
return config
362+
350363
config_dict, _ = PretrainedConfig.get_config_dict(
351364
model,
352365
revision=revision,

vllm/transformers_utils/configs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
88
from vllm.transformers_utils.configs.eagle import EAGLEConfig
99
from vllm.transformers_utils.configs.exaone import ExaoneConfig
10+
from vllm.transformers_utils.configs.speculators_eagle import SpeculatorsEagleConfig
1011
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
1112
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
1213
# `FalconConfig` class from the official HuggingFace transformers library.
@@ -40,6 +41,7 @@
4041
"MedusaConfig",
4142
"EAGLEConfig",
4243
"ExaoneConfig",
44+
"SpeculatorsEagleConfig",
4345
"MiniMaxText01Config",
4446
"MiniMaxVL01Config",
4547
"MllamaConfig",
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import json
5+
import os
6+
from pathlib import Path
7+
from typing import Optional, Union
8+
9+
from transformers import PretrainedConfig
10+
11+
from vllm.transformers_utils.configs.eagle import EAGLEConfig
12+
13+
14+
class SpeculatorsEagleConfig(EAGLEConfig):
15+
"""
16+
Adapter for speculators Eagle configs to make them compatible with vLLM.
17+
18+
This class handles the conversion between speculators config format and
19+
vLLM's expected Eagle config format.
20+
"""
21+
22+
@classmethod
23+
def from_pretrained(
24+
cls,
25+
pretrained_model_name_or_path: Union[str, os.PathLike],
26+
**kwargs,
27+
) -> "SpeculatorsEagleConfig":
28+
"""
29+
Load a speculators Eagle config and convert it to vLLM format.
30+
"""
31+
config_path = Path(pretrained_model_name_or_path) / "config.json"
32+
33+
if not config_path.exists():
34+
# Fall back to standard loading if not a local path
35+
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
36+
37+
with open(config_path, "r") as f:
38+
config_dict = json.load(f)
39+
40+
# Check if this is a speculators format config
41+
if "speculators_model_type" not in config_dict:
42+
# Not a speculators config, use standard loading
43+
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
44+
45+
# Convert speculators format to vLLM format
46+
vllm_config = cls._convert_speculators_to_vllm(config_dict)
47+
48+
return cls(**vllm_config)
49+
50+
@classmethod
51+
def _convert_speculators_to_vllm(cls, speculators_config: dict) -> dict:
52+
"""
53+
Convert speculators Eagle config format to vLLM format.
54+
55+
Speculators format:
56+
{
57+
"speculators_model_type": "eagle",
58+
"transformer_layer_config": {...},
59+
"layernorms": true/false,
60+
"fusion_bias": true/false
61+
}
62+
63+
vLLM format:
64+
{
65+
"model_type": "eagle",
66+
"model": {...},
67+
"eagle_fc_bias": true/false,
68+
"truncated_vocab_size": vocab_size
69+
}
70+
"""
71+
# Extract transformer config
72+
transformer_config = speculators_config.get("transformer_layer_config", {})
73+
74+
# Handle layernorms flag
75+
if speculators_config.get("layernorms", False):
76+
transformer_config["add_para_norm"] = True
77+
# Ensure skip flags are set correctly for extra layernorms
78+
transformer_config["skip_prenorm"] = False
79+
transformer_config["skip_output_norm"] = False
80+
81+
# Ensure transformer config has required fields
82+
if "architectures" not in transformer_config:
83+
# Infer from transformer_layer_architecture
84+
arch = speculators_config.get("transformer_layer_architecture", "LlamaDecoderLayer")
85+
if arch == "LlamaDecoderLayer":
86+
transformer_config["architectures"] = ["LlamaForCausalLM"]
87+
else:
88+
transformer_config["architectures"] = [arch]
89+
90+
# Build vLLM config
91+
vllm_config = {
92+
"model_type": "eagle",
93+
"model": transformer_config,
94+
"eagle_fc_bias": speculators_config.get("fusion_bias", False),
95+
"truncated_vocab_size": transformer_config.get("vocab_size"),
96+
}
97+
98+
# Preserve any additional fields that might be needed
99+
for key, value in speculators_config.items():
100+
if key not in ["speculators_model_type", "transformer_layer_config",
101+
"layernorms", "fusion_bias", "architectures"]:
102+
vllm_config[key] = value
103+
104+
# Set architectures for vLLM
105+
vllm_config["architectures"] = ["EAGLEModel"]
106+
107+
return vllm_config
108+
109+
110+
def is_speculators_eagle_config(config_path: Union[str, os.PathLike]) -> bool:
111+
"""
112+
Check if a config file is in speculators Eagle format.
113+
"""
114+
config_file = Path(config_path) / "config.json"
115+
if not config_file.exists():
116+
return False
117+
118+
try:
119+
with open(config_file, "r") as f:
120+
config = json.load(f)
121+
return config.get("speculators_model_type") == "eagle"
122+
except:
123+
return False

0 commit comments

Comments
 (0)