@@ -50,21 +50,27 @@ def _convert_speculators_to_vllm(cls, speculators_config: dict) -> dict:
50
50
51
51
Supports both Eagle and Eagle-3 models based on speculators_model_type.
52
52
"""
53
- speculators_type = speculators_config .get ("speculators_model_type" , "eagle " )
54
-
55
- # Extract transformer config
53
+ speculators_model_type = speculators_config .get ("speculators_model_type" )
54
+ assert speculators_model_type , "`speculators_model_type` must be specified in the config"
55
+
56
56
transformer_config = speculators_config .get ("transformer_layer_config" , {})
57
57
58
- # Build base vLLM config
58
+ # Extract num_lookahead_tokens from proposal_methods
59
+ proposal_methods = speculators_config .get ("speculators_config" , {}).get ("proposal_methods" , [])
60
+ assert proposal_methods , "speculators_config must have at least one proposal method"
61
+
62
+ # Only one proposal method is supported for now
63
+ proposal_method : dict = proposal_methods [0 ]
64
+ num_lookahead_tokens = proposal_method .get ("speculative_tokens" )
65
+ assert num_lookahead_tokens , "speculative_tokens must be specified in proposal_methods[0]"
66
+
59
67
vllm_config = {
60
- "model_type" : "eagle" ,
61
68
"model" : transformer_config ,
62
- "method" : speculators_type , # Use speculators_model_type as method
63
- "num_lookahead_tokens" : 5 , # Default number of speculative tokens
69
+ "method" : speculators_model_type ,
70
+ "num_lookahead_tokens" : num_lookahead_tokens ,
64
71
}
65
72
66
- # Handle version-specific config
67
- if speculators_type == "eagle" :
73
+ if speculators_model_type == "eagle" :
68
74
# Eagle-1 specific handling
69
75
# Handle layernorms flag
70
76
if speculators_config .get ("layernorms" , False ):
@@ -78,15 +84,15 @@ def _convert_speculators_to_vllm(cls, speculators_config: dict) -> dict:
78
84
vllm_config ["truncated_vocab_size" ] = transformer_config .get ("vocab_size" )
79
85
vllm_config ["architectures" ] = ["EAGLEModel" ]
80
86
81
- elif speculators_type == "eagle3" :
87
+ elif speculators_model_type == "eagle3" :
82
88
# Eagle-3 specific handling
83
89
# Copy Eagle-3 specific fields from speculators config
84
- if "draft_vocab_size" in speculators_config :
90
+ if speculators_config . get ( "draft_vocab_size" ) is not None :
85
91
vllm_config ["draft_vocab_size" ] = speculators_config ["draft_vocab_size" ]
86
92
87
93
# Handle target_hidden_size - if not provided, it should be set by vLLM
88
94
# based on the target model, but we can try to infer from transformer config
89
- if "target_hidden_size" in speculators_config and speculators_config [ "target_hidden_size" ] is not None :
95
+ if speculators_config . get ( "target_hidden_size" ) is not None :
90
96
vllm_config ["target_hidden_size" ] = speculators_config ["target_hidden_size" ]
91
97
else :
92
98
# Use the draft model's hidden size as target_hidden_size
@@ -108,25 +114,21 @@ def _convert_speculators_to_vllm(cls, speculators_config: dict) -> dict:
108
114
else :
109
115
transformer_config ["architectures" ] = [arch ]
110
116
117
+ speculators_specific_fields : set = {"speculators_model_type" , "transformer_layer_config" ,
118
+ "layernorms" , "fusion_bias" , "architectures" ,
119
+ "draft_vocab_size" , "target_hidden_size" , "norm_before_residual" }
120
+
111
121
# Preserve any additional fields that might be needed
112
122
for key , value in speculators_config .items ():
113
- if key not in ["speculators_model_type" , "transformer_layer_config" ,
114
- "layernorms" , "fusion_bias" , "architectures" ,
115
- "draft_vocab_size" , "target_hidden_size" , "norm_before_residual" ]:
123
+ if key not in speculators_specific_fields :
116
124
vllm_config [key ] = value
117
-
118
125
return vllm_config
119
126
120
127
121
128
def is_speculators_eagle_config(config_path: Union[str, os.PathLike]) -> bool:
    """Check whether a config file is in speculators Eagle format.

    Loads the config dict via ``PretrainedConfig.get_config_dict`` (which
    resolves both local paths and Hugging Face Hub model IDs) and inspects
    the ``speculators_model_type`` field.

    Args:
        config_path: Local path or Hub ID of the model config to inspect.

    Returns:
        True if the config declares a supported speculators Eagle model type
        ("eagle" or "eagle3"); False otherwise, including when the config
        cannot be loaded at all.
    """
    supported_model_types = ("eagle", "eagle3")
    try:
        config_dict, _ = PretrainedConfig.get_config_dict(config_path)
    except (OSError, ValueError):
        # A format-detection predicate should not propagate load errors
        # (missing file, unreachable repo, malformed JSON); any failure to
        # load simply means "not a speculators Eagle config". Deliberately
        # narrow — unlike the old bare `except:`, this does not swallow
        # KeyboardInterrupt/SystemExit or unrelated bugs.
        return False
    return config_dict.get("speculators_model_type") in supported_model_types
0 commit comments