 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.triton_utils import triton
 from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                     FlashAttentionMetadata)
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
+from vllm.v1.spec_decode.utils import (advance_state_kernel,
+                                        prepare_eagle_input_kernel)

 logger = init_logger(__name__)

@@ -75,6 +77,14 @@ def __init__(
                                     device=device,
                                     dtype=torch.int32)

+        # Used to store precomputed values from load_inputs() so they can be used in propose()
+        self.last_token_indices = torch.zeros(self.max_num_tokens,
+                                              dtype=torch.int32,
+                                              device=device)
+        self.seq_lens = torch.zeros(self.max_num_tokens,
+                                    dtype=torch.int32,
+                                    device=device)
+
     def propose(
         self,
         # [num_tokens]
@@ -92,40 +102,21 @@ def propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
+        num_tokens: int,
+        max_num_tokens: int,
+        max_seq_len: int,
     ) -> torch.Tensor:
-        num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
-        last_token_indices = cu_num_tokens[1:] - 1
-
-        if self.method == "eagle3":
-            assert isinstance(self.model, Eagle3LlamaForCausalLM)
-            target_hidden_states = self.model.combine_hidden_states(
-                target_hidden_states)
-            assert target_hidden_states.shape[-1] == self.hidden_size
-
-        # Shift the input ids by one token.
-        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
-        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
-        # Replace the last token with the next token.
-        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-        self.input_ids[last_token_indices] = next_token_ids
-
-        # FA requires seq_len to have dtype int32.
-        seq_lens = (target_positions[last_token_indices] + 1).int()

         if self.method in ["eagle", "eagle3"]:
-            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-            max_seq_len = seq_lens.max().item()
-            max_num_tokens = (cu_num_tokens[1:] -
-                              cu_num_tokens[:-1]).max().item()
             attn_metadata = FlashAttentionMetadata(
                 num_actual_tokens=num_tokens,
                 max_query_len=max_num_tokens,
                 query_start_loc=cu_num_tokens,
                 max_seq_len=max_seq_len,
-                seq_lens=seq_lens,
+                seq_lens=self.seq_lens,
                 block_table=block_table,
-                slot_mapping=target_slot_mapping,
+                slot_mapping=target_slot_mapping[:num_tokens],
                 # TODO(woosuk): Support cascade attention.
                 use_cascade=False,
                 common_prefix_len=0,
@@ -134,15 +125,12 @@ def propose(
                 suffix_kv_lens=None,
             )
         elif self.method == "deepseek_mtp":
-            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
-            max_query_len = query_lens.max().item()
-
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=cu_num_tokens,
-                seq_lens=seq_lens,
+                seq_lens=self.seq_lens,
                 num_reqs=batch_size,
                 num_actual_tokens=num_tokens,
-                max_query_len=max_query_len,
+                max_query_len=self.max_num_tokens,
             )

         assert self.runner is not None
@@ -165,9 +153,6 @@ def propose(
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
         else:
             num_input_tokens = num_tokens
-        # copy inputs to buffer for cudagraph
-        self.positions[:num_tokens] = target_positions
-        self.hidden_states[:num_tokens] = target_hidden_states

         with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
@@ -181,7 +166,7 @@ def propose(
             last_hidden_states = ret_hidden_states
         else:
             last_hidden_states, hidden_states = ret_hidden_states
-        sample_hidden_states = last_hidden_states[last_token_indices]
+        sample_hidden_states = last_hidden_states[self.last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)
@@ -197,8 +182,8 @@ def propose(
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]

-        positions = target_positions[last_token_indices]
-        hidden_states = hidden_states[last_token_indices]
+        positions = target_positions[self.last_token_indices]
+        hidden_states = hidden_states[self.last_token_indices]
         if self.use_cuda_graph and \
                 batch_size <= self.cudagraph_batch_sizes[-1]:
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
@@ -208,52 +193,12 @@ def propose(
             attn_metadata.max_query_len = 1
             attn_metadata.query_start_loc = self.arange[:batch_size + 1]
         for _ in range(self.num_speculative_tokens - 1):
-            # Update the inputs.
-            # cast to int32 is crucial when eagle model is compiled.
-            # tensor.argmax() returns int64 by default.
-            input_ids = draft_token_ids_list[-1].int()
-            positions += 1
-
-            # NOTE(woosuk): We should handle the case where the draft model
-            # generates tokens beyond the max model length. Since it is complex
-            # to remove such requests from the batch, we keep them in the batch
-            # but adjust the position ids and slot mappings to avoid the
-            # out-of-range access during the model execution. The draft tokens
-            # generated with this adjustment should be ignored.
-            exceeds_max_model_len = positions >= self.max_model_len
-            # Mask out the position ids that exceed the max model length.
-            # Otherwise, we may get out-of-range error in RoPE.
-            clamped_positions = torch.where(exceeds_max_model_len, 0,
-                                            positions)
-
-            # Increment the sequence lengths.
-            attn_metadata.max_seq_len += 1
-            attn_metadata.seq_lens += 1
-            # Consider max model length.
-            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
-                                            self.max_model_len)
-            # For the requests that exceed the max model length, we set the
-            # sequence length to 1 to minimize their overheads in attention.
-            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
-
-            # Compute the slot mapping.
-            block_numbers = clamped_positions // self.block_size
-            block_ids = block_table.gather(dim=1,
-                                           index=block_numbers.view(-1, 1))
-            block_ids = block_ids.view(-1)
-            attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          clamped_positions % self.block_size)
-            # Mask out the slot mappings that exceed the max model length.
-            # Otherwise, the KV cache will be inadvertently updated with the
-            # padding tokens.
-            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
-                                                    PADDING_SLOT_ID)

-            # copy inputs to buffer for cudagraph
-            self.input_ids[:batch_size] = input_ids
-            self.positions[:batch_size] = clamped_positions
-            self.hidden_states[:batch_size] = hidden_states
+            self.advance_speculative_state(draft_token_ids_list[-1], positions,
+                                           hidden_states, attn_metadata,
+                                           batch_size)

+            # copy inputs to buffer for cudagraph
             # Run the model.
             with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
@@ -275,6 +220,58 @@ def propose(
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
         return draft_token_ids

+    def advance_speculative_state(self, draft_token_ids: torch.Tensor,
+                                  positions: torch.Tensor,
+                                  hidden_states: torch.Tensor,
+                                  attn_metadata: FlashAttentionMetadata,
+                                  batch_size: int):
+        """
+        Advances the speculative decoding state and metadata by one step.
+
+        Parameters
+        ----------
+        draft_token_ids (torch.Tensor): Token IDs generated by the draft model
+        positions (torch.Tensor): Position indices for the draft tokens
+        hidden_states (torch.Tensor): Corresponding hidden states for the tokens
+        attn_metadata (FlashAttentionMetadata): Metadata required for FlashAttention (e.g., sequence lengths, block table)
+        batch_size (int): Number of sequences to update
+        """
+
+        # Calculate the number of thread blocks for the Triton launch.
+        grid = lambda meta: (triton.cdiv(batch_size, meta['BLOCK_SIZE']), )
+        attn_metadata.slot_mapping = torch.empty_like(positions)
+        advance_state_kernel[grid](
+            # === Input tensors ===
+            draft_token_ids,
+            positions,
+
+            # === Model input buffers to be updated ===
+            self.input_ids[:batch_size],
+            self.positions[:batch_size],
+
+            # === Metadata tensors ===
+            attn_metadata.seq_lens,
+            attn_metadata.block_table,
+            attn_metadata.slot_mapping,
+
+            # === Scalar configuration ===
+            self.max_model_len,
+            self.block_size,
+            self.max_model_len // self.block_size,
+
+            # === Execution control ===
+            batch_size,
+            BLOCK_SIZE=1024,
+            PADDING_SLOT_ID=PADDING_SLOT_ID)
+
+        self.hidden_states[:batch_size] = hidden_states
+
+        # Increment the maximum sequence length by one step.
+        attn_metadata.max_seq_len += 1
+        # Consider the max model length.
+        attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
+                                        self.max_model_len)
+
     @staticmethod
     def prepare_inputs(
         # [batch_size + 1]
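Note (not part of the diff): the body of `advance_state_kernel` lives in `vllm/v1/spec_decode/utils.py` and is not shown in this file. As a rough sketch only, reconstructed from the Python logic removed from the drafting loop above, the fused kernel presumably does something like the following. The `_sketch` name, the parameter names, and the assumption that the block table is row-contiguous with a row stride of `max_model_len // block_size` are guesses, not the actual implementation.

```python
import triton
import triton.language as tl


@triton.jit
def advance_state_kernel_sketch(
        draft_token_ids_ptr,   # [batch_size] last sampled draft token per request
        positions_ptr,         # [batch_size] running positions, advanced in place
        input_ids_out_ptr,     # [batch_size] slice of self.input_ids (int32)
        positions_out_ptr,     # [batch_size] slice of self.positions
        seq_lens_ptr,          # [batch_size] attn_metadata.seq_lens, updated in place
        block_table_ptr,       # [batch_size, blocks_per_req] KV-cache block table
        slot_mapping_out_ptr,  # [batch_size] output slot mapping
        max_model_len,
        block_size,
        blocks_per_req,        # assumed row stride of the block table
        batch_size,
        BLOCK_SIZE: tl.constexpr,
        PADDING_SLOT_ID: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < batch_size

    # Advance the position by one step; cast the sampled token to int32
    # (argmax returns int64, which breaks the compiled draft model).
    pos = tl.load(positions_ptr + offs, mask=mask, other=0) + 1
    tok = tl.load(draft_token_ids_ptr + offs, mask=mask, other=0).to(tl.int32)

    # Requests that run past the max model length stay in the batch, but their
    # positions are clamped (for RoPE) and their KV writes go to a padding slot.
    exceeds = pos >= max_model_len
    clamped_pos = tl.where(exceeds, 0, pos)

    # Per-request sequence length grows by one, pinned to 1 when exceeded.
    seq_len = tl.load(seq_lens_ptr + offs, mask=mask, other=0) + 1
    seq_len = tl.where(exceeds, 1, seq_len)

    # Slot mapping: block id from the block table plus the offset in the block.
    block_number = clamped_pos // block_size
    block_id = tl.load(block_table_ptr + offs * blocks_per_req + block_number,
                       mask=mask, other=0)
    slot = block_id * block_size + clamped_pos % block_size
    slot = tl.where(exceeds, PADDING_SLOT_ID, slot)

    tl.store(positions_ptr + offs, pos, mask=mask)
    tl.store(input_ids_out_ptr + offs, tok, mask=mask)
    tl.store(positions_out_ptr + offs, clamped_pos, mask=mask)
    tl.store(seq_lens_ptr + offs, seq_len, mask=mask)
    tl.store(slot_mapping_out_ptr + offs, slot, mask=mask)
```

With `BLOCK_SIZE=1024`, typical batch sizes launch a single program, so one kernel call replaces the dozen small elementwise ops, the gather, and the masked fills that the old loop body issued on the host per speculative step.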
@@ -301,7 +298,7 @@ def prepare_inputs(
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
         cu_num_tokens = torch.zeros_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-        token_indices = torch.empty(
+        token_indices = torch.zeros(
             num_tokens,
             dtype=torch.int32,
             device=cu_target_query_lens.device,
@@ -316,6 +313,54 @@ def prepare_inputs(
         )
         return cu_num_tokens, token_indices

+    def load_inputs(self, target_token_ids: torch.Tensor,
+                    target_positions: torch.Tensor,
+                    target_hidden_states: torch.Tensor,
+                    next_token_ids_gpu: torch.Tensor,
+                    cu_num_tokens: torch.Tensor, num_scheduled_tokens: int):
+        """
+        Loads token ids, positions, hidden states, etc. into the proposer's input buffers.
+
+        Logic moved here from EagleProposer.propose().
+
+        Parameters
+        ----------
+        target_token_ids (torch.Tensor): Draft-step token IDs
+        target_positions (torch.Tensor): Position indices for the tokens
+        target_hidden_states (torch.Tensor): Corresponding hidden states for the tokens
+        next_token_ids_gpu (torch.Tensor): Sampled next token IDs to overwrite each request's final token
+        cu_num_tokens (torch.Tensor): Cumulative number of tokens from prepare_inputs()
+        num_scheduled_tokens (int): Total number of tokens scheduled
+        """
+
+        self.last_token_indices = cu_num_tokens[1:] - 1
+
+        # FA requires seq_len to have dtype int32.
+        self.seq_lens = (target_positions[self.last_token_indices] + 1).int()
+
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size
+
+        # Shift the input ids by one token.
+        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
+        self.input_ids[:num_scheduled_tokens -
+                       1] = target_token_ids[:num_scheduled_tokens][1:]
+
+        # Replace the last token with the next token.
+        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
+        self.input_ids[self.last_token_indices] = next_token_ids_gpu
+
+        # copy inputs to buffer for cudagraph
+        self.positions[:
+                       num_scheduled_tokens] = target_positions[:
+                                                                num_scheduled_tokens]
+        self.hidden_states[:
+                           num_scheduled_tokens] = target_hidden_states[:
+                                                                        num_scheduled_tokens]
+
     def load_model(self, target_model: nn.Module) -> None:
         draft_model_config = \
             self.vllm_config.speculative_config.draft_model_config
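Aside (not part of the diff): the shift-and-overwrite trick documented in `load_inputs()` is easy to check in isolation. The snippet below is a standalone illustration with made-up token values, not vllm code.

```python
# Standalone illustration of the input shift performed by load_inputs().
# Three requests with 1, 2 and 3 scheduled tokens: [a1 | b1 b2 | c1 c2 c3].
import torch

target_token_ids = torch.tensor([10, 20, 21, 30, 31, 32], dtype=torch.int32)
next_token_ids = torch.tensor([11, 22, 33], dtype=torch.int32)  # a2, b3, c4
cu_num_tokens = torch.tensor([0, 1, 3, 6])  # cumulative tokens per request

num_scheduled_tokens = int(cu_num_tokens[-1])
last_token_indices = cu_num_tokens[1:] - 1  # [0, 2, 5]

input_ids = torch.zeros(num_scheduled_tokens, dtype=torch.int32)
# Shift left by one token: [b1, b2, c1, c2, c3, <stale>]
input_ids[:num_scheduled_tokens - 1] = target_token_ids[1:]
# Overwrite each request's last slot with the token the target model sampled:
# [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids
print(input_ids.tolist())  # [11, 21, 22, 31, 32, 33]
```

Each request's draft input thus starts from the token the target model just accepted, which is exactly what the next draft step needs.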