Commit a3ce5fa

rahul-tuli and claude committed
refactor: Clean up speculators Eagle config implementation
- Remove redundant model_type field from vllm_config (already defined in EAGLEConfig)
- Extract num_lookahead_tokens from proposal_methods in speculators config
- Add proper assertions for required speculators config structure
- Remove unnecessary intermediate variable speculators_cfg

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
1 parent 3dea5ef commit a3ce5fa
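
For context, the speculators-format config this change parses looks roughly like the sketch below. The field names (speculators_model_type, transformer_layer_config, speculators_config.proposal_methods, speculative_tokens) come from the diff; the concrete values are made up for illustration:

    # Hypothetical speculators-format config dict (illustrative values only)
    speculators_config = {
        "speculators_model_type": "eagle3",  # "eagle" or "eagle3"
        "transformer_layer_config": {"hidden_size": 4096, "vocab_size": 128256},
        "speculators_config": {
            "proposal_methods": [
                # Only the first entry is read; its speculative_tokens value
                # becomes vllm_config["num_lookahead_tokens"] instead of the
                # previously hardcoded default of 5.
                {"speculative_tokens": 5},
            ],
        },
    }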

1 file changed: +26 −24 lines changed

vllm/transformers_utils/configs/speculators_eagle.py

Lines changed: 26 additions & 24 deletions
@@ -50,21 +50,27 @@ def _convert_speculators_to_vllm(cls, speculators_config: dict) -> dict:
 
         Supports both Eagle and Eagle-3 models based on speculators_model_type.
         """
-        speculators_type = speculators_config.get("speculators_model_type", "eagle")
-
-        # Extract transformer config
+        speculators_model_type = speculators_config.get("speculators_model_type")
+        assert speculators_model_type, "`speculators_model_type` must be specified in the config"
+
         transformer_config = speculators_config.get("transformer_layer_config", {})
 
-        # Build base vLLM config
+        # Extract num_lookahead_tokens from proposal_methods
+        proposal_methods = speculators_config.get("speculators_config", {}).get("proposal_methods", [])
+        assert proposal_methods, "speculators_config must have at least one proposal method"
+
+        # Only one proposal method is supported for now
+        proposal_method: dict = proposal_methods[0]
+        num_lookahead_tokens = proposal_method.get("speculative_tokens")
+        assert num_lookahead_tokens, "speculative_tokens must be specified in proposal_methods[0]"
+
         vllm_config = {
-            "model_type": "eagle",
             "model": transformer_config,
-            "method": speculators_type,  # Use speculators_model_type as method
-            "num_lookahead_tokens": 5,  # Default number of speculative tokens
+            "method": speculators_model_type,
+            "num_lookahead_tokens": num_lookahead_tokens,
         }
 
-        # Handle version-specific config
-        if speculators_type == "eagle":
+        if speculators_model_type == "eagle":
             # Eagle-1 specific handling
             # Handle layernorms flag
             if speculators_config.get("layernorms", False):
@@ -78,15 +84,15 @@ def _convert_speculators_to_vllm(cls, speculators_config: dict) -> dict:
             vllm_config["truncated_vocab_size"] = transformer_config.get("vocab_size")
             vllm_config["architectures"] = ["EAGLEModel"]
 
-        elif speculators_type == "eagle3":
+        elif speculators_model_type == "eagle3":
             # Eagle-3 specific handling
             # Copy Eagle-3 specific fields from speculators config
-            if "draft_vocab_size" in speculators_config:
+            if speculators_config.get("draft_vocab_size") is not None:
                 vllm_config["draft_vocab_size"] = speculators_config["draft_vocab_size"]
 
             # Handle target_hidden_size - if not provided, it should be set by vLLM
             # based on the target model, but we can try to infer from transformer config
-            if "target_hidden_size" in speculators_config and speculators_config["target_hidden_size"] is not None:
+            if speculators_config.get("target_hidden_size") is not None:
                 vllm_config["target_hidden_size"] = speculators_config["target_hidden_size"]
             else:
                 # Use the draft model's hidden size as target_hidden_size
@@ -108,25 +114,21 @@ def _convert_speculators_to_vllm(cls, speculators_config: dict) -> dict:
             else:
                 transformer_config["architectures"] = [arch]
 
+        speculators_specific_fields: set = {"speculators_model_type", "transformer_layer_config",
+                                            "layernorms", "fusion_bias", "architectures",
+                                            "draft_vocab_size", "target_hidden_size", "norm_before_residual"}
+
         # Preserve any additional fields that might be needed
         for key, value in speculators_config.items():
-            if key not in ["speculators_model_type", "transformer_layer_config",
-                           "layernorms", "fusion_bias", "architectures",
-                           "draft_vocab_size", "target_hidden_size", "norm_before_residual"]:
+            if key not in speculators_specific_fields:
                 vllm_config[key] = value
-
         return vllm_config
 
 
 def is_speculators_eagle_config(config_path: Union[str, os.PathLike]) -> bool:
     """
     Check if a config file is in speculators Eagle format.
     """
-    try:
-        # Use PretrainedConfig to load from both local and HF paths
-        config_dict, _ = PretrainedConfig.get_config_dict(config_path)
-        # Check for speculators format by looking for speculators_model_type key
-        return "speculators_model_type" in config_dict and \
-            config_dict.get("speculators_model_type") in ["eagle", "eagle3"]
-    except:
-        return False
+    supported_model_types = ["eagle", "eagle3"]
+    config_dict, _ = PretrainedConfig.get_config_dict(config_path)
+    return config_dict.get("speculators_model_type") in supported_model_types
