@@ -33,7 +33,8 @@ def from_pretrained(
         config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

         # Check if this is a speculators format config
-        if config_dict.get("speculators_model_type") != "eagle":
+        speculators_type = config_dict.get("speculators_model_type")
+        if speculators_type not in ["eagle", "eagle3"]:
             # Not a speculators config, use standard loading
             return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

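# --- Illustrative sketch, not part of the diff ---
# A minimal speculators-style config fragment of the kind the new check accepts.
# The field values below are assumptions for illustration, not taken from a real model.
example_config_dict = {
    "speculators_model_type": "eagle3",  # "eagle" or "eagle3" passes the check above
    "transformer_layer_config": {"hidden_size": 4096, "vocab_size": 128256},
    "draft_vocab_size": 32000,
}
# Mirrors the dispatch in from_pretrained(): a missing or unrecognized type means
# the config is handed to the standard PretrainedConfig loading path instead.
uses_speculators_path = example_config_dict.get("speculators_model_type") in ["eagle", "eagle3"]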
@@ -47,31 +48,56 @@ def _convert_speculators_to_vllm(cls, speculators_config: dict) -> dict:
         """
         Convert speculators Eagle config format to vLLM format.

-        Speculators format:
-        {
-            "speculators_model_type": "eagle",
-            "transformer_layer_config": {...},
-            "layernorms": true/false,
-            "fusion_bias": true/false
-        }
-
-        vLLM format:
-        {
-            "model_type": "eagle",
-            "model": {...},
-            "eagle_fc_bias": true/false,
-            "truncated_vocab_size": vocab_size
-        }
+        Supports both Eagle and Eagle-3 models based on speculators_model_type.
         """
+        speculators_type = speculators_config.get("speculators_model_type", "eagle")
+
         # Extract transformer config
         transformer_config = speculators_config.get("transformer_layer_config", {})

-        # Handle layernorms flag
-        if speculators_config.get("layernorms", False):
-            transformer_config["add_para_norm"] = True
-            # Ensure skip flags are set correctly for extra layernorms
-            transformer_config["skip_prenorm"] = False
-            transformer_config["skip_output_norm"] = False
+        # Build base vLLM config
+        vllm_config = {
+            "model_type": "eagle",
+            "model": transformer_config,
+            "method": speculators_type,  # Use speculators_model_type as method
+            "num_lookahead_tokens": 5,  # Default number of speculative tokens
+        }
+
+        # Handle version-specific config
+        if speculators_type == "eagle":
+            # Eagle-1 specific handling
+            # Handle layernorms flag
+            if speculators_config.get("layernorms", False):
+                transformer_config["add_para_norm"] = True
+                # Ensure skip flags are set correctly for extra layernorms
+                transformer_config["skip_prenorm"] = False
+                transformer_config["skip_output_norm"] = False
+
+            # Eagle-1 specific fields
+            vllm_config["eagle_fc_bias"] = speculators_config.get("fusion_bias", False)
+            vllm_config["truncated_vocab_size"] = transformer_config.get("vocab_size")
+            vllm_config["architectures"] = ["EAGLEModel"]
+
+        elif speculators_type == "eagle3":
+            # Eagle-3 specific handling
+            # Copy Eagle-3 specific fields from speculators config
+            if "draft_vocab_size" in speculators_config:
+                vllm_config["draft_vocab_size"] = speculators_config["draft_vocab_size"]
+
+            # Handle target_hidden_size - if not provided, it should be set by vLLM
+            # based on the target model, but we can try to infer from transformer config
+            if "target_hidden_size" in speculators_config and speculators_config["target_hidden_size"] is not None:
+                vllm_config["target_hidden_size"] = speculators_config["target_hidden_size"]
+            else:
+                # Use the draft model's hidden size as target_hidden_size
+                # This will be the same as the target model's hidden size
+                vllm_config["target_hidden_size"] = transformer_config.get("hidden_size", 4096)
+
+            if "norm_before_residual" in speculators_config:
+                vllm_config["norm_before_residual"] = speculators_config["norm_before_residual"]
+
+            # Eagle-3 uses different architecture
+            vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]

         # Ensure transformer config has required fields
         if "architectures" not in transformer_config:
@@ -82,25 +108,13 @@ def _convert_speculators_to_vllm(cls, speculators_config: dict) -> dict:
         else:
             transformer_config["architectures"] = [arch]

-        # Build vLLM config
-        vllm_config = {
-            "model_type": "eagle",
-            "model": transformer_config,
-            "eagle_fc_bias": speculators_config.get("fusion_bias", False),
-            "truncated_vocab_size": transformer_config.get("vocab_size"),
-            "method": speculators_config.get("speculators_model_type", "eagle"),  # Use speculators_model_type
-            "num_lookahead_tokens": 5,  # Default number of speculative tokens for Eagle
-        }
-
         # Preserve any additional fields that might be needed
         for key, value in speculators_config.items():
             if key not in ["speculators_model_type", "transformer_layer_config",
-                           "layernorms", "fusion_bias", "architectures"]:
+                           "layernorms", "fusion_bias", "architectures",
+                           "draft_vocab_size", "target_hidden_size", "norm_before_residual"]:
                 vllm_config[key] = value

-        # Set architectures for vLLM
-        vllm_config["architectures"] = ["EAGLEModel"]
-
         return vllm_config


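# --- Illustrative sketch, not part of the diff ---
# The passthrough loop above copies any top-level key outside the exclusion list into
# the vLLM config unchanged. "custom_field" is a hypothetical key used only to show this.
speculators_config = {"speculators_model_type": "eagle3", "custom_field": 123}
vllm_config = {}
excluded = ["speculators_model_type", "transformer_layer_config",
            "layernorms", "fusion_bias", "architectures",
            "draft_vocab_size", "target_hidden_size", "norm_before_residual"]
for key, value in speculators_config.items():
    if key not in excluded:
        vllm_config[key] = value
assert vllm_config == {"custom_field": 123}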
@@ -111,6 +125,8 @@ def is_speculators_eagle_config(config_path: Union[str, os.PathLike]) -> bool:
     try:
         # Use PretrainedConfig to load from both local and HF paths
         config_dict, _ = PretrainedConfig.get_config_dict(config_path)
-        return config_dict.get("speculators_model_type") == "eagle"
+        # Check for speculators format by looking for speculators_model_type key
+        return "speculators_model_type" in config_dict and \
+            config_dict.get("speculators_model_type") in ["eagle", "eagle3"]
     except:
         return False
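# --- Usage sketch, not part of the diff ---
# How a caller might branch on the detection helper. "path/to/model" is a placeholder,
# and SpeculatorsEagleConfig is assumed to be the class whose from_pretrained() the
# first hunk patches; it is not shown in this excerpt.
model_path = "path/to/model"
if is_speculators_eagle_config(model_path):
    config = SpeculatorsEagleConfig.from_pretrained(model_path)
else:
    config = PretrainedConfig.from_pretrained(model_path)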