from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union

import vllm.envs as envs
from transformers import PretrainedConfig
from vllm.config import ModelConfig, SpeculativeConfig

if TYPE_CHECKING:
    from _typeshed import DataclassInstance

    ConfigType = type[DataclassInstance]
else:
    ConfigType = type

ConfigT = TypeVar("ConfigT", bound=ConfigType)

TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                     "score", "reward", "transcription"]

RunnerType = Literal["generate", "pooling", "draft", "transcription"]

HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
                                             PretrainedConfig]]


def __post_init__(self):

    # Note: "method" is a new parameter that extends the configuration to
    # non-model-based proposers, while the "model" parameter is used to set
    # the draft model, EAGLE head, or additional weights when needed. If
    # users do not specify "method", the speculative method is detected
    # automatically where possible; if it cannot be detected, it falls back
    # to "draft_model".
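    # For example (illustrative, not exhaustive): model="ngram" with no
    # explicit method resolves to method="ngram", and a draft model whose
    # name contains "eagle-" resolves to method="eagle".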

    if self.model is None and self.num_speculative_tokens is not None:
        # TODO(Shangming): Refactor MTP configuration logic when supporting
        # MTP acceleration for more models besides deepseek_v3.
        if self.target_model_config and \
                self.target_model_config.hf_text_config.model_type \
                in ("deepseek_v3", "mimo"):
            # Reuse the target model as the draft model.
            self.model = self.target_model_config.model
        elif self.method in ("ngram", "[ngram]"):
            self.model = "ngram"
        else:
            raise ValueError("num_speculative_tokens was provided without "
                             "a speculative model.")

    # Automatically configure the method for ngram when "model" is used
    # instead of "method".
    if self.method is None and (self.model is not None
                                and self.model in ("ngram", "[ngram]")):
        self.method = "ngram"

    if self.method in ("ngram", "[ngram]"):
        # Unified to "ngram" internally.
        self.method = "ngram"
        # Set default values if not provided.
        if (self.prompt_lookup_min is None
                and self.prompt_lookup_max is None):
            # TODO(woosuk): Tune these values. They are arbitrarily chosen.
            self.prompt_lookup_min = 5
            self.prompt_lookup_max = 5
        elif self.prompt_lookup_min is None:
            assert self.prompt_lookup_max is not None
            self.prompt_lookup_min = self.prompt_lookup_max
        elif self.prompt_lookup_max is None:
            assert self.prompt_lookup_min is not None
            self.prompt_lookup_max = self.prompt_lookup_min
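        # Worked example: passing only prompt_lookup_max=4 yields
        # prompt_lookup_min == prompt_lookup_max == 4; passing neither
        # yields the default window of 5/5.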

        # Validate values.
        if self.prompt_lookup_min < 1:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
        if self.prompt_lookup_max < 1:
            raise ValueError(
                f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
        if self.prompt_lookup_min > self.prompt_lookup_max:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must "
                f"be <= prompt_lookup_max={self.prompt_lookup_max}")

        # TODO: Currently we still need to extract vocab_size from the
        # target model config; in the future we may refactor this out and
        # set the draft-related config to None here.
        self.draft_model_config = self.target_model_config
        self.draft_parallel_config = self.target_parallel_config
    else:
        self.prompt_lookup_max = 0
        self.prompt_lookup_min = 0

        if self.model is not None:
            self.draft_model_config = ModelConfig(
                model=self.model,
                task="draft",
                tokenizer=self.target_model_config.tokenizer,
                tokenizer_mode=self.target_model_config.tokenizer_mode,
                trust_remote_code=self.target_model_config.trust_remote_code,
                allowed_local_media_path=self.target_model_config.
                allowed_local_media_path,
                dtype=self.target_model_config.dtype,
                seed=self.target_model_config.seed,
                revision=self.revision,
                code_revision=self.code_revision,
                tokenizer_revision=self.target_model_config.
                tokenizer_revision,
                spec_target_max_model_len=self.target_model_config.
                max_model_len,
                quantization=self.quantization,
                enforce_eager=self.target_model_config.enforce_eager,
                max_seq_len_to_capture=self.target_model_config.
                max_seq_len_to_capture,
                max_logprobs=self.target_model_config.max_logprobs,
                hf_overrides=SpeculativeConfig.hf_config_override,
            )
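            # Note: the draft model inherits its tokenizer, dtype, seed,
            # and eager/capture settings from the target model config, so
            # draft and target run under matching runtime conditions.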

            # Automatically detect the method.
            if self.method in ("eagle", "eagle3"):
                pass
            elif "eagle-" in self.draft_model_config.model.lower() or \
                    "eagle3-" in self.draft_model_config.model.lower():
                self.method = "eagle"
            elif self.draft_model_config.hf_config.model_type == "medusa":
                self.method = "medusa"
            elif (self.draft_model_config.hf_config.model_type ==
                  "mlp_speculator"):
                self.method = "mlp_speculator"
            elif (self.draft_model_config.hf_config.model_type ==
                  "deepseek_mtp"):
                self.method = "mtp"
            else:
                self.method = "draft_model"
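            # Detection precedence: an explicit "eagle"/"eagle3" method is
            # kept as-is, then the model name is checked, then the HF
            # model_type, and finally "draft_model" is used as the fallback.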

            # Replace hf_config for EAGLE draft_model.
            if self.method in ("eagle", "eagle3"):
                if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
                    raise ValueError(
                        "Chunked prefill and EAGLE are not compatible "
                        "when using V0.")

                from vllm.platforms import current_platform
                from vllm.transformers_utils.configs.eagle import EAGLEConfig
                if isinstance(self.draft_model_config.hf_config,
                              EAGLEConfig) or current_platform.is_neuron():
                    pass
                else:
                    eagle_config = EAGLEConfig(
                        self.draft_model_config.hf_config,
                        method=self.method)
                    self.draft_model_config.hf_config = eagle_config

            if (self.num_speculative_tokens is not None
                    and hasattr(self.draft_model_config.hf_config,
                                "num_lookahead_tokens")):
                self.draft_model_config.hf_config.num_lookahead_tokens = \
                    self.num_speculative_tokens

            n_predict = getattr(self.draft_model_config.hf_config,
                                "n_predict", None)
            if n_predict is not None:
                if self.num_speculative_tokens is None:
                    # Default to the max value defined in the draft model
                    # config.
                    self.num_speculative_tokens = n_predict
                elif self.num_speculative_tokens > n_predict and \
                        self.num_speculative_tokens % n_predict != 0:
                    # Ensure divisibility for MTP module reuse.
                    raise ValueError(
                        f"num_speculative_tokens:"
                        f"{self.num_speculative_tokens}"
                        f" must be divisible by {n_predict=}")

            self.draft_tensor_parallel_size = \
                SpeculativeConfig._verify_and_get_draft_tp(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size,
                    self.draft_model_config.hf_config)

            self.draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    self.max_model_len,
                    self.draft_model_config.max_model_len,
                    self.target_model_config.max_model_len,
                ))

            self.draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size))

    if self.acceptance_method == "typical_acceptance_sampler":
        if self.posterior_threshold is None:
            self.posterior_threshold = 0.09
        if self.posterior_alpha is None:
            self.posterior_alpha = 0.3
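        # These are fallback defaults for the typical acceptance sampler;
        # each is applied only if the corresponding field was left unset.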

    self._verify_args()


SpeculativeConfig.__post_init__ = __post_init__
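
# A minimal usage sketch (hypothetical: SpeculativeConfig's constructor
# arguments vary across vLLM versions, so the field names below are
# assumptions, not a stable API). With model="ngram" and no explicit
# method, the patched __post_init__ resolves method="ngram" and fills in
# the default prompt-lookup window:
#
#     spec_config = SpeculativeConfig(
#         model="ngram",
#         num_speculative_tokens=3,
#         target_model_config=target_model_config,
#         target_parallel_config=target_parallel_config,
#     )
#     assert spec_config.method == "ngram"
#     assert spec_config.prompt_lookup_min == 5
#     assert spec_config.prompt_lookup_max == 5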