Commit 5d158c6

rahul-tuli and claude committed
feat: Add generic Eagle-3 speculators support
- Updated speculators config detection to check for the speculators_model_type key
- Support both eagle and eagle3 in is_speculators_eagle_config
- Handle Eagle-3-specific config fields (draft_vocab_size, target_hidden_size)
- Infer target_hidden_size from the transformer config if not provided
- Skip non-existent weights in llama_eagle to handle HASS models gracefully
- Eagle-3 models don't need weight translation (they already use the correct names)

This enables support for:
- nm-testing/eagle3-llama3.1-8b-instruct-speculators
- nm-testing/EAGLE3-LLaMA3.3-Instruct-70B-speculators

while maintaining backward compatibility with Eagle-1 models.

Signed-off-by: rtuli@redhat.com

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
1 parent b09d1bc commit 5d158c6
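
For context, a minimal sketch of the speculators-format config fields this commit keys off of. The concrete values below are invented for illustration and are not taken from the repository or from either model card:

    speculators_config = {
        # "eagle" selects Eagle-1 handling, "eagle3" selects Eagle-3 handling
        "speculators_model_type": "eagle3",
        # Nested draft-transformer config; becomes the vLLM "model" sub-config
        "transformer_layer_config": {
            "hidden_size": 4096,
            "vocab_size": 128256,
        },
        # Eagle-3-only fields, copied through by the converter when present
        "draft_vocab_size": 32000,
        "target_hidden_size": None,   # inferred from hidden_size when missing
        "norm_before_residual": True,
    }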

File tree

vllm/model_executor/models/llama_eagle.py
vllm/transformers_utils/configs/speculators_eagle.py

2 files changed (+57, -37 lines)


vllm/model_executor/models/llama_eagle.py

Lines changed: 4 additions & 0 deletions
@@ -135,6 +135,10 @@ def load_weights(self, weights: Iterable[tuple[str,
                     "embed_tokens." in name:
                 continue
 
+            # Skip weights that don't exist in the model
+            if name not in params_dict:
+                continue
+
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
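
The new guard simply ignores checkpoint tensors that have no matching parameter in the draft model (as can happen with HASS checkpoints carrying extra weights). A standalone sketch of that pattern, with invented names rather than vLLM's actual loader:

    import torch

    def load_weights_skipping_unknown(model: torch.nn.Module,
                                      weights: dict[str, torch.Tensor]) -> set[str]:
        # Map parameter names to parameters once, then copy only known tensors.
        params_dict = dict(model.named_parameters())
        loaded: set[str] = set()
        for name, tensor in weights.items():
            # Skip weights that don't exist in the model instead of raising KeyError.
            if name not in params_dict:
                continue
            params_dict[name].data.copy_(tensor)
            loaded.add(name)
        return loaded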

vllm/transformers_utils/configs/speculators_eagle.py

Lines changed: 53 additions & 37 deletions
@@ -33,7 +33,8 @@ def from_pretrained(
         config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
 
         # Check if this is a speculators format config
-        if config_dict.get("speculators_model_type") != "eagle":
+        speculators_type = config_dict.get("speculators_model_type")
+        if speculators_type not in ["eagle", "eagle3"]:
             # Not a speculators config, use standard loading
             return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
 

@@ -47,31 +48,56 @@ def _convert_speculators_to_vllm(cls, speculators_config: dict) -> dict:
         """
         Convert speculators Eagle config format to vLLM format.
 
-        Speculators format:
-        {
-            "speculators_model_type": "eagle",
-            "transformer_layer_config": {...},
-            "layernorms": true/false,
-            "fusion_bias": true/false
-        }
-
-        vLLM format:
-        {
-            "model_type": "eagle",
-            "model": {...},
-            "eagle_fc_bias": true/false,
-            "truncated_vocab_size": vocab_size
-        }
+        Supports both Eagle and Eagle-3 models based on speculators_model_type.
         """
+        speculators_type = speculators_config.get("speculators_model_type", "eagle")
+
         # Extract transformer config
         transformer_config = speculators_config.get("transformer_layer_config", {})
 
-        # Handle layernorms flag
-        if speculators_config.get("layernorms", False):
-            transformer_config["add_para_norm"] = True
-            # Ensure skip flags are set correctly for extra layernorms
-            transformer_config["skip_prenorm"] = False
-            transformer_config["skip_output_norm"] = False
+        # Build base vLLM config
+        vllm_config = {
+            "model_type": "eagle",
+            "model": transformer_config,
+            "method": speculators_type,  # Use speculators_model_type as method
+            "num_lookahead_tokens": 5,  # Default number of speculative tokens
+        }
+
+        # Handle version-specific config
+        if speculators_type == "eagle":
+            # Eagle-1 specific handling
+            # Handle layernorms flag
+            if speculators_config.get("layernorms", False):
+                transformer_config["add_para_norm"] = True
+                # Ensure skip flags are set correctly for extra layernorms
+                transformer_config["skip_prenorm"] = False
+                transformer_config["skip_output_norm"] = False
+
+            # Eagle-1 specific fields
+            vllm_config["eagle_fc_bias"] = speculators_config.get("fusion_bias", False)
+            vllm_config["truncated_vocab_size"] = transformer_config.get("vocab_size")
+            vllm_config["architectures"] = ["EAGLEModel"]
+
+        elif speculators_type == "eagle3":
+            # Eagle-3 specific handling
+            # Copy Eagle-3 specific fields from speculators config
+            if "draft_vocab_size" in speculators_config:
+                vllm_config["draft_vocab_size"] = speculators_config["draft_vocab_size"]
+
+            # Handle target_hidden_size - if not provided, it should be set by vLLM
+            # based on the target model, but we can try to infer from transformer config
+            if "target_hidden_size" in speculators_config and speculators_config["target_hidden_size"] is not None:
+                vllm_config["target_hidden_size"] = speculators_config["target_hidden_size"]
+            else:
+                # Use the draft model's hidden size as target_hidden_size
+                # This will be the same as the target model's hidden size
+                vllm_config["target_hidden_size"] = transformer_config.get("hidden_size", 4096)
+
+            if "norm_before_residual" in speculators_config:
+                vllm_config["norm_before_residual"] = speculators_config["norm_before_residual"]
+
+            # Eagle-3 uses different architecture
+            vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
 
         # Ensure transformer config has required fields
         if "architectures" not in transformer_config:
@@ -82,25 +108,13 @@ def _convert_speculators_to_vllm(cls, speculators_config: dict) -> dict:
         else:
             transformer_config["architectures"] = [arch]
 
-        # Build vLLM config
-        vllm_config = {
-            "model_type": "eagle",
-            "model": transformer_config,
-            "eagle_fc_bias": speculators_config.get("fusion_bias", False),
-            "truncated_vocab_size": transformer_config.get("vocab_size"),
-            "method": speculators_config.get("speculators_model_type", "eagle"),  # Use speculators_model_type
-            "num_lookahead_tokens": 5,  # Default number of speculative tokens for Eagle
-        }
-
         # Preserve any additional fields that might be needed
         for key, value in speculators_config.items():
             if key not in ["speculators_model_type", "transformer_layer_config",
-                           "layernorms", "fusion_bias", "architectures"]:
+                           "layernorms", "fusion_bias", "architectures",
+                           "draft_vocab_size", "target_hidden_size", "norm_before_residual"]:
                 vllm_config[key] = value
 
-        # Set architectures for vLLM
-        vllm_config["architectures"] = ["EAGLEModel"]
-
         return vllm_config
 
 
@@ -111,6 +125,8 @@ def is_speculators_eagle_config(config_path: Union[str, os.PathLike]) -> bool:
     try:
         # Use PretrainedConfig to load from both local and HF paths
         config_dict, _ = PretrainedConfig.get_config_dict(config_path)
-        return config_dict.get("speculators_model_type") == "eagle"
+        # Check for speculators format by looking for speculators_model_type key
+        return "speculators_model_type" in config_dict and \
+               config_dict.get("speculators_model_type") in ["eagle", "eagle3"]
     except:
         return False
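
Taken together, the conversion roughly maps an Eagle-3 speculators config onto a vLLM-style dict as sketched below. The values are illustrative only, and the nested model config may additionally receive a default "architectures" entry that is omitted here:

    # Input in speculators format (invented values for illustration)
    speculators_config = {
        "speculators_model_type": "eagle3",
        "transformer_layer_config": {"hidden_size": 4096, "vocab_size": 128256},
        "draft_vocab_size": 32000,
        "norm_before_residual": True,
    }

    # Approximate result of the conversion logic shown in this diff
    expected_vllm_config = {
        "model_type": "eagle",
        "model": speculators_config["transformer_layer_config"],
        "method": "eagle3",
        "num_lookahead_tokens": 5,
        "draft_vocab_size": 32000,
        "target_hidden_size": 4096,   # falls back to the draft hidden_size when not provided
        "norm_before_residual": True,
        "architectures": ["Eagle3LlamaForCausalLM"],
    }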
