1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Optional , Union
8
+
9
+ from transformers import PretrainedConfig
10
+
11
+ from vllm .transformers_utils .configs .eagle import EAGLEConfig
12
+
13
+
14
+ class SpeculatorsEagleConfig (EAGLEConfig ):
15
+ """
16
+ Adapter for speculators Eagle configs to make them compatible with vLLM.
17
+
18
+ This class handles the conversion between speculators config format and
19
+ vLLM's expected Eagle config format.
20
+ """
21
+
22
+ @classmethod
23
+ def from_pretrained (
24
+ cls ,
25
+ pretrained_model_name_or_path : Union [str , os .PathLike ],
26
+ ** kwargs ,
27
+ ) -> "SpeculatorsEagleConfig" :
28
+ """
29
+ Load a speculators Eagle config and convert it to vLLM format.
30
+ """
31
+ config_path = Path (pretrained_model_name_or_path ) / "config.json"
32
+
33
+ if not config_path .exists ():
34
+ # Fall back to standard loading if not a local path
35
+ return super ().from_pretrained (pretrained_model_name_or_path , ** kwargs )
36
+
37
+ with open (config_path , "r" ) as f :
38
+ config_dict = json .load (f )
39
+
40
+ # Check if this is a speculators format config
41
+ if "speculators_model_type" not in config_dict :
42
+ # Not a speculators config, use standard loading
43
+ return super ().from_pretrained (pretrained_model_name_or_path , ** kwargs )
44
+
45
+ # Convert speculators format to vLLM format
46
+ vllm_config = cls ._convert_speculators_to_vllm (config_dict )
47
+
48
+ return cls (** vllm_config )
49
+
50
+ @classmethod
51
+ def _convert_speculators_to_vllm (cls , speculators_config : dict ) -> dict :
52
+ """
53
+ Convert speculators Eagle config format to vLLM format.
54
+
55
+ Speculators format:
56
+ {
57
+ "speculators_model_type": "eagle",
58
+ "transformer_layer_config": {...},
59
+ "layernorms": true/false,
60
+ "fusion_bias": true/false
61
+ }
62
+
63
+ vLLM format:
64
+ {
65
+ "model_type": "eagle",
66
+ "model": {...},
67
+ "eagle_fc_bias": true/false,
68
+ "truncated_vocab_size": vocab_size
69
+ }
70
+ """
71
+ # Extract transformer config
72
+ transformer_config = speculators_config .get ("transformer_layer_config" , {})
73
+
74
+ # Handle layernorms flag
75
+ if speculators_config .get ("layernorms" , False ):
76
+ transformer_config ["add_para_norm" ] = True
77
+ # Ensure skip flags are set correctly for extra layernorms
78
+ transformer_config ["skip_prenorm" ] = False
79
+ transformer_config ["skip_output_norm" ] = False
80
+
81
+ # Ensure transformer config has required fields
82
+ if "architectures" not in transformer_config :
83
+ # Infer from transformer_layer_architecture
84
+ arch = speculators_config .get ("transformer_layer_architecture" , "LlamaDecoderLayer" )
85
+ if arch == "LlamaDecoderLayer" :
86
+ transformer_config ["architectures" ] = ["LlamaForCausalLM" ]
87
+ else :
88
+ transformer_config ["architectures" ] = [arch ]
89
+
90
+ # Build vLLM config
91
+ vllm_config = {
92
+ "model_type" : "eagle" ,
93
+ "model" : transformer_config ,
94
+ "eagle_fc_bias" : speculators_config .get ("fusion_bias" , False ),
95
+ "truncated_vocab_size" : transformer_config .get ("vocab_size" ),
96
+ }
97
+
98
+ # Preserve any additional fields that might be needed
99
+ for key , value in speculators_config .items ():
100
+ if key not in ["speculators_model_type" , "transformer_layer_config" ,
101
+ "layernorms" , "fusion_bias" , "architectures" ]:
102
+ vllm_config [key ] = value
103
+
104
+ # Set architectures for vLLM
105
+ vllm_config ["architectures" ] = ["EAGLEModel" ]
106
+
107
+ return vllm_config
108
+
109
+
110
+ def is_speculators_eagle_config (config_path : Union [str , os .PathLike ]) -> bool :
111
+ """
112
+ Check if a config file is in speculators Eagle format.
113
+ """
114
+ config_file = Path (config_path ) / "config.json"
115
+ if not config_file .exists ():
116
+ return False
117
+
118
+ try :
119
+ with open (config_file , "r" ) as f :
120
+ config = json .load (f )
121
+ return config .get ("speculators_model_type" ) == "eagle"
122
+ except :
123
+ return False
0 commit comments