from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union

import vllm.envs as envs
from transformers import PretrainedConfig

from vllm.config import ModelConfig, SpeculativeConfig

if TYPE_CHECKING:
    from _typeshed import DataclassInstance

    ConfigType = type[DataclassInstance]
else:
    ConfigType = type

ConfigT = TypeVar("ConfigT", bound=ConfigType)

TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                     "score", "reward", "transcription"]

RunnerType = Literal["generate", "pooling", "draft", "transcription"]

HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
                                             PretrainedConfig]]


def __post_init__(self):

    # Note: "method" is a new parameter that helps to extend the
    # configuration of non-model-based proposers, and the "model" parameter
    # is used to set the draft model, EAGLE head, or additional weights when
    # needed. If the user does not specify "method", the speculative method
    # is detected automatically where possible; if it cannot be detected,
    # it defaults to "draft_model".

    if self.model is None and self.num_speculative_tokens is not None:
        # TODO(Shangming): Refactor mtp configuration logic when supporting
        # mtp acceleration for more models besides deepseek_v3
        if self.target_model_config and \
                (self.target_model_config.hf_text_config.model_type
                 == "deepseek_v3" or
                 self.target_model_config.hf_text_config.model_type
                 == "mimo"):
            # Use the draft model from the same model:
            self.model = self.target_model_config.model
        elif self.method in ("ngram", "[ngram]"):
            self.model = "ngram"
        else:
            raise ValueError(
                "num_speculative_tokens was provided without a "
                "speculative model.")
51
+
52
+ # Automatically configure the method for ngram when "model" is used
53
+ # instead of "method"
54
+ if self .method is None and (self .model is not None
55
+ and self .model in ("ngram" , "[ngram]" )):
56
+ self .method = "ngram"
57
+
58
+ if self .method in ("ngram" , "[ngram]" ):
59
+ # Unified to "ngram" internally
60
+ self .method = "ngram"
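        # prompt_lookup_min/max bound the n-gram window sizes used when
        # matching against the prompt to propose draft tokens
        # (prompt-lookup decoding).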
        # Set default values if not provided.
        if (self.prompt_lookup_min is None
                and self.prompt_lookup_max is None):
            # TODO(woosuk): Tune these values. They are arbitrarily chosen.
            self.prompt_lookup_min = 5
            self.prompt_lookup_max = 5
        elif self.prompt_lookup_min is None:
            assert self.prompt_lookup_max is not None
            self.prompt_lookup_min = self.prompt_lookup_max
        elif self.prompt_lookup_max is None:
            assert self.prompt_lookup_min is not None
            self.prompt_lookup_max = self.prompt_lookup_min

        # Validate values.
        if self.prompt_lookup_min < 1:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
        if self.prompt_lookup_max < 1:
            raise ValueError(
                f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
        if self.prompt_lookup_min > self.prompt_lookup_max:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must "
                f"be <= prompt_lookup_max={self.prompt_lookup_max}")

        # TODO: Currently we still need to extract vocab_size from the
        # target model config; in the future this may be refactored out so
        # the draft-related config can be set to None here.
        self.draft_model_config = self.target_model_config
        self.draft_parallel_config = self.target_parallel_config
    else:
        self.prompt_lookup_max = 0
        self.prompt_lookup_min = 0

        if self.model is not None:
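            # Build the draft ModelConfig, inheriting tokenizer, dtype, seed,
            # and related settings from the target model.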
            self.draft_model_config = ModelConfig(
                model=self.model,
                task="draft",
                tokenizer=self.target_model_config.tokenizer,
                tokenizer_mode=self.target_model_config.tokenizer_mode,
                trust_remote_code=self.target_model_config.trust_remote_code,
                allowed_local_media_path=self.target_model_config.
                allowed_local_media_path,
                dtype=self.target_model_config.dtype,
                seed=self.target_model_config.seed,
                revision=self.revision,
                code_revision=self.code_revision,
                tokenizer_revision=self.target_model_config.tokenizer_revision,
                spec_target_max_model_len=self.target_model_config.
                max_model_len,
                quantization=self.quantization,
                enforce_eager=self.target_model_config.enforce_eager,
                max_seq_len_to_capture=self.target_model_config.
                max_seq_len_to_capture,
                max_logprobs=self.target_model_config.max_logprobs,
                hf_overrides=SpeculativeConfig.hf_config_override,
            )

            # Automatically detect the method.
            if self.method in ("eagle", "eagle3"):
                pass
            elif "eagle-" in self.draft_model_config.model.lower() or \
                    "eagle3-" in self.draft_model_config.model.lower():
                self.method = "eagle"
            elif self.draft_model_config.hf_config.model_type == "medusa":
                self.method = "medusa"
            elif (self.draft_model_config.hf_config.model_type ==
                  "mlp_speculator"):
                self.method = "mlp_speculator"
            elif (self.draft_model_config.hf_config.model_type ==
                  "deepseek_mtp"):
                self.method = "mtp"
            else:
                self.method = "draft_model"

            # Replace hf_config for EAGLE draft_model.
            if self.method in ("eagle", "eagle3"):
                if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
                    raise ValueError(
                        "Chunked prefill and EAGLE are not compatible "
                        "when using V0.")

                from vllm.platforms import current_platform
                from vllm.transformers_utils.configs.eagle import EAGLEConfig
                if isinstance(self.draft_model_config.hf_config,
                              EAGLEConfig) or current_platform.is_neuron():
                    pass
                else:
                    eagle_config = EAGLEConfig(
                        self.draft_model_config.hf_config,
                        method=self.method)
                    self.draft_model_config.hf_config = eagle_config

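            # If the draft HF config exposes num_lookahead_tokens, keep it
            # in sync with the requested number of speculative tokens.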
            if (self.num_speculative_tokens is not None
                    and hasattr(self.draft_model_config.hf_config,
                                "num_lookahead_tokens")):
                self.draft_model_config.hf_config.num_lookahead_tokens = \
                    self.num_speculative_tokens

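            # Some draft configs (e.g. MLP-speculator-style heads) declare
            # n_predict, the number of tokens the head can propose per step;
            # use it as the default and enforce divisibility below.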
            n_predict = getattr(self.draft_model_config.hf_config,
                                "n_predict", None)
            if n_predict is not None:
                if self.num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    self.num_speculative_tokens = n_predict
                elif self.num_speculative_tokens > n_predict and \
                        self.num_speculative_tokens % n_predict != 0:
                    # Ensure divisibility for MTP module reuse.
                    raise ValueError(
                        f"num_speculative_tokens:"
                        f"{self.num_speculative_tokens} must be divisible "
                        f"by {n_predict=}")

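            # Resolve the draft model's tensor-parallel size, max_model_len,
            # and parallel config from the target configuration.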
            self.draft_tensor_parallel_size = \
                SpeculativeConfig._verify_and_get_draft_tp(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size,
                    self.draft_model_config.hf_config)

            self.draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    self.max_model_len,
                    self.draft_model_config.max_model_len,
                    self.target_model_config.max_model_len,
                ))

            self.draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size))

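    # Default posterior threshold/alpha for the typical acceptance sampler
    # when the user has not set them.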
    if self.acceptance_method == "typical_acceptance_sampler":
        if self.posterior_threshold is None:
            self.posterior_threshold = 0.09
        if self.posterior_alpha is None:
            self.posterior_alpha = 0.3

    self._verify_args()


SpeculativeConfig.__post_init__ = __post_init__
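
# Illustrative sketch (not part of the patch above): with this override
# installed, an ngram-based SpeculativeConfig could be built roughly as
# follows. The exact constructor signature is an assumption; only fields
# referenced in __post_init__ above are shown.
#
#     spec_config = SpeculativeConfig(
#         method="ngram",
#         num_speculative_tokens=3,
#         prompt_lookup_min=2,
#         prompt_lookup_max=5,
#     )
#     # __post_init__ normalizes the method to "ngram", fills in or checks
#     # the prompt-lookup window, and reuses the target model config as the
#     # draft config.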