@@ -104,6 +104,7 @@ def _prepare_adjusted_tensors(
         cu_num_tokens: torch.Tensor,
         decode_mask: torch.Tensor,
         full_prefill_mask: torch.Tensor,
+        partial_prefill_mask: torch.Tensor,
         prefill_first_hiddens: torch.Tensor,
         block_table: torch.Tensor,
         batch_size: int,
@@ -131,6 +132,34 @@ def _prepare_adjusted_tensors(
             tuple: (target_positions, target_hidden_states, target_slot_mapping,
                 cu_num_tokens, current_pos, partial_prefill_mask)

+        Algorithm design:
+        - Suppose target tokens are [1, 2, 3, ... N], next token is N+1
+        - Positions are [0, 1, 2, ... N-1]
+        - And hidden states are [h1, h2, h3, ... hN]
+        - Suppose partial prefill is [Nm, Nm+1, ... Nm+M-1]
+        -- For normal shifting:
+        --- draft prefill is [2, 3, ... N+1], positions are the same as target
+        --- stacked hidden is [h1, h2, h3, ... hN]
+        --- decode tokens are [N+2, N+3, ...], hidden is [hN+1, hN+2, ...]
+        --- decode positions are [N, N+1, ...]
+        --- draft partial prefill is [Nm+1, Nm+2, ... Nm+M]
+        -- For non-shifting:
+        --- draft full prefill is [1, 2, 3, ... N+1], positions are [0, 1, 2, ... N]
+        --- stacked hidden is [hN, h1, h2, h3, ... hN]
+        --- decode tokens are [N+2, N+3, ...], hidden is [hN+1, hN+2, ...]
+        --- decode positions are [N+1, N+2, ...]
+        --- draft partial prefill is [Nm, Nm+1, ... Nm+M-1]
+        --- draft hidden is [hNm-1, hNm, ... hNm+M]
+            (hNm-1 is the hidden from the last round)
+        -- For KV sharing (non-shifting required):
+            All target prefill tokens do not need to be processed in the
+            drafting prefill step, since we do not need the KV from the draft.
+        --- draft full prefill is [N+1], position is [N]
+        --- stacked hidden is [hN]
+        --- decode is the same as non-shifting decode
+        --- draft partial prefill is skipped entirely
+        All other metadata, such as the slot mapping, should be regenerated
+        from the adjusted positions and tokens.
         """
         # Count total number of full prefill requests to determine the
         # size needed for adjusted tensors
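To make the shifting vs. non-shifting layouts in the docstring above concrete, here is a minimal standalone sketch (not part of the diff; all tensor names are illustrative) for a single request with N = 4 target tokens:

```python
import torch

# Toy example of the layouts described in the docstring (N = 4).
N = 4
target_tokens = torch.arange(1, N + 1)            # [1, 2, 3, 4]
next_token = torch.tensor([N + 1])                # [5], sampled by the target model
positions = torch.arange(N)                       # [0, 1, 2, 3]
hiddens = torch.stack([torch.full((8,), float(i)) for i in range(1, N + 1)])  # h1..hN
prev_hidden = torch.zeros(1, 8)                   # placeholder for the last-round hidden

# Normal shifting: drop the first token, append the next token; positions unchanged.
shift_tokens = torch.cat([target_tokens[1:], next_token])    # [2, 3, 4, 5]
shift_positions = positions                                   # [0, 1, 2, 3]
shift_hiddens = hiddens                                       # [h1 .. h4]

# Non-shifting: keep all target tokens plus the next token, extend positions by one,
# and prepend the previous-round hidden so tokens and hiddens stay aligned.
noshift_tokens = torch.cat([target_tokens, next_token])       # [1, 2, 3, 4, 5]
noshift_positions = torch.arange(N + 1)                       # [0, 1, 2, 3, 4]
noshift_hiddens = torch.cat([prev_hidden, hiddens])           # [h_prev, h1 .. h4]

# KV sharing (non-shifting): the draft prefill only needs the next token.
kv_share_tokens = next_token                                  # [5]
kv_share_positions = torch.tensor([N])                        # [4]
kv_share_hiddens = hiddens[-1:]                               # [h4]
```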
@@ -184,21 +213,6 @@ def _prepare_adjusted_tensors(
         # Create updated cumulative token counts
         updated_cu_num_tokens = torch.zeros_like(cu_num_tokens)

-        # Track which requests are partial prefill (no decode tokens)
-        partial_prefill_mask = torch.zeros_like(full_prefill_mask)
-
-        # Create masks for each category
-        has_decode_mask = torch.zeros(batch_size,
-                                      dtype=torch.bool,
-                                      device=decode_mask.device)
-        for i in range(batch_size):
-            start_idx = cu_num_tokens[i].item()
-            end_idx = cu_num_tokens[i + 1].item()
-            has_decode_mask[i] = decode_mask[start_idx:end_idx].any().item()
-
-        # Category 1: Partial prefill (no decode tokens)
-        partial_prefill_mask = ~has_decode_mask
-
         # Process batched operations using masks
         current_pos = 0
         cu_num_tokens_index = 0
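With the per-request loop removed here, `partial_prefill_mask` is expected to be computed by the caller and passed into `_prepare_adjusted_tensors`. One vectorized way a caller might build it, assuming the same `decode_mask`/`cu_num_tokens` semantics (this helper is hypothetical, not part of the PR):

```python
import torch

def build_partial_prefill_mask(decode_mask: torch.Tensor,
                               cu_num_tokens: torch.Tensor) -> torch.Tensor:
    """A request is partial prefill iff none of its tokens are decode tokens."""
    batch_size = cu_num_tokens.numel() - 1
    # Map each flattened token index to the request that owns it.
    req_ids = torch.repeat_interleave(
        torch.arange(batch_size, device=decode_mask.device),
        cu_num_tokens[1:] - cu_num_tokens[:-1],
    )
    # A request "has decode" if at least one of its tokens is a decode token.
    has_decode = torch.bincount(req_ids[decode_mask.bool()],
                                minlength=batch_size) > 0
    return ~has_decode

# Usage sketch:
# partial_prefill_mask = build_partial_prefill_mask(decode_mask, cu_num_tokens)
```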
@@ -368,6 +382,7 @@ def propose(
         mm_embeds: Optional[list[torch.Tensor]] = None,
         decode_mask: torch.Tensor = None,
         full_prefill_mask: torch.Tensor = None,
+        partial_prefill_mask: torch.Tensor = None,
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
@@ -388,6 +403,17 @@ def propose(
             prefill_shift_tokens = False

         if not prefill_shift_tokens and has_prefill:
+            if (partial_prefill_mask.all()
+                    and self.draft_prefill_kv_sharing_from_base):
+                # All requests are partial prefill and
+                # KV cache sharing is enabled
+                # Skip the rest of the function
+                # and return dummy draft tokens
+                return torch.zeros(
+                    (batch_size, self.num_speculative_tokens),
+                    dtype=target_token_ids.dtype,
+                    device=target_token_ids.device,
+                )
             # Adjust the tensors for full prefill requests
             (
                 target_positions,
@@ -404,22 +430,12 @@ def propose(
                 cu_num_tokens,
                 decode_mask,
                 full_prefill_mask,
+                partial_prefill_mask,
                 prefill_first_hiddens,
                 block_table,
                 batch_size,
                 num_tokens,
             )
-            if (partial_prefill_mask.all()
-                    and self.draft_prefill_kv_sharing_from_base):
-                # All requests are partial prefill and
-                # KV cache sharing is enabled
-                # Skip the rest of the function
-                # and return dummy draft tokens
-                return torch.zeros(
-                    (batch_size, self.num_speculative_tokens),
-                    dtype=target_token_ids.dtype,
-                    device=target_token_ids.device,
-                )
             batch_size = cu_num_tokens.shape[0] - 1
         else:
             # Original behavior: shift all tokens by one
@@ -451,6 +467,9 @@ def propose(
         if not prefill_shift_tokens and has_prefill:
             # Replace the last token with the next token under non-shifting,
             # but only for non-partial prefill requests
+            # For partial prefill under non-shifting, we simply keep the target
+            # prefill tokens, since they already match the positions and hidden
+            # states, so there is no need to append the next round's token.
             mask = ~partial_prefill_mask
             # if we enable copy kv then all of the partial prefills
             # are completely skipped so they won't be in last_token_indices
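A toy illustration of the masked last-token replacement described in these comments (the tensors below are made up for the example; only the indexing pattern mirrors the diff):

```python
import torch

# Two flattened requests of 3 tokens each; request 1 is still a partial prefill.
input_ids = torch.tensor([11, 12, 13, 21, 22, 23])
last_token_indices = torch.tensor([2, 5])      # index of each request's last token
next_token_ids = torch.tensor([14, 24])        # next token sampled for each request
partial_prefill_mask = torch.tensor([False, True])

# Only non-partial-prefill requests get their last token swapped for the next token.
mask = ~partial_prefill_mask
input_ids[last_token_indices[mask]] = next_token_ids[mask]
print(input_ids)  # tensor([11, 12, 14, 21, 22, 23]) -- only request 0 changed
```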