@@ -528,19 +528,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
528
528
start_token_index :end_token_index ] = new_token_ids
529
529
self .input_batch .num_tokens_no_spec [
530
530
req_index ] = end_token_index
531
- # Add spec_token_ids to token_ids_cpu.
532
- spec_token_ids = (
533
- scheduler_output .scheduled_spec_decode_tokens .get (
534
- req_id , ()))
535
- if spec_token_ids :
536
- start_index = end_token_index
537
- end_token_index += len (spec_token_ids )
538
- self .input_batch .token_ids_cpu [
539
- req_index ,
540
- start_index :end_token_index ] = spec_token_ids
541
- # NOTE(woosuk): `num_tokens` here may include spec tokens.
542
531
self .input_batch .num_tokens [req_index ] = end_token_index
543
532
533
+ # Add spec_token_ids to token_ids_cpu.
534
+ spec_token_ids = (
535
+ scheduler_output .scheduled_spec_decode_tokens .get (req_id , ()))
536
+ if spec_token_ids :
537
+ num_spec_tokens = len (spec_token_ids )
538
+ start_index = self .input_batch .num_tokens_no_spec [req_index ]
539
+ end_token_index = start_index + num_spec_tokens
540
+ self .input_batch .token_ids_cpu [
541
+ req_index , start_index :end_token_index ] = spec_token_ids
542
+ # NOTE(woosuk): `num_tokens` here may include spec tokens.
543
+ self .input_batch .num_tokens [req_index ] += num_spec_tokens
544
+
544
545
# Add the new or resumed requests to the persistent batch.
545
546
# The smaller empty indices are filled first.
546
547
for req_id in req_ids_to_add :
0 commit comments