@@ -456,78 +456,147 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
456
456
req_ids_to_add .append (req_id )
457
457
458
458
# Update the states of the running/resumed requests.
459
- for req_data in scheduler_output .scheduled_cached_reqs :
460
- req_id = req_data .req_id
461
- req_state = self .requests [req_id ]
459
+ if vllm_version_is ("0.9.1" ):
460
+ for req_data in scheduler_output .scheduled_cached_reqs :
461
+ req_id = req_data .req_id
462
+ req_state = self .requests [req_id ]
462
463
463
- # Update the cached states.
464
- num_computed_tokens = req_data .num_computed_tokens
465
- req_state .num_computed_tokens = num_computed_tokens
466
- # Add the sampled token(s) from the previous step (if any).
467
- # This doesn't include "unverified" tokens like spec decode tokens.
468
- num_new_tokens = (num_computed_tokens +
469
- len (req_data .new_token_ids ) -
470
- req_state .num_tokens )
471
- if num_new_tokens == 1 :
472
- # Avoid slicing list in most common case.
473
- req_state .output_token_ids .append (req_data .new_token_ids [- 1 ])
474
- elif num_new_tokens > 0 :
475
- req_state .output_token_ids .extend (
476
- req_data .new_token_ids [- num_new_tokens :])
477
- # Update the block IDs.
478
- if not req_data .resumed_from_preemption :
479
- # Append the new blocks to the existing block IDs.
480
- for block_ids , new_block_ids in zip ( # type: ignore[call-overload]
481
- req_state .block_ids ,
482
- req_data .new_block_ids ,
483
- strict = True ):
484
- block_ids .extend (new_block_ids )
485
- else :
486
- # The request is resumed from preemption.
487
- # Replace the existing block IDs with the new ones.
488
- req_state .block_ids = req_data .new_block_ids
489
-
490
- req_index = self .input_batch .req_id_to_index .get (req_id )
491
- if req_index is None :
492
- # The request is not in the persistent batch.
493
- # The request was either preempted and resumed later, or was not
494
- # scheduled in the previous step and needs to be added again.
495
- req_ids_to_add .append (req_id )
496
- continue
464
+ # Update the cached states.
465
+ num_computed_tokens = req_data .num_computed_tokens
466
+ req_state .num_computed_tokens = num_computed_tokens
467
+ # Add the sampled token(s) from the previous step (if any).
468
+ # This doesn't include "unverified" tokens like spec decode tokens.
469
+ num_new_tokens = (num_computed_tokens +
470
+ len (req_data .new_token_ids ) -
471
+ req_state .num_tokens )
472
+ if num_new_tokens == 1 :
473
+ # Avoid slicing list in most common case.
474
+ req_state .output_token_ids .append (
475
+ req_data .new_token_ids [- 1 ])
476
+ elif num_new_tokens > 0 :
477
+ req_state .output_token_ids .extend (
478
+ req_data .new_token_ids [- num_new_tokens :])
479
+ # Update the block IDs.
480
+ if not req_data .resumed_from_preemption :
481
+ # Append the new blocks to the existing block IDs.
482
+ for block_ids , new_block_ids in zip ( # type: ignore[call-overload]
483
+ req_state .block_ids ,
484
+ req_data .new_block_ids ,
485
+ strict = True ):
486
+ block_ids .extend (new_block_ids )
487
+ else :
488
+ # The request is resumed from preemption.
489
+ # Replace the existing block IDs with the new ones.
490
+ req_state .block_ids = req_data .new_block_ids
491
+
492
+ req_index = self .input_batch .req_id_to_index .get (req_id )
493
+ if req_index is None :
494
+ # The request is not in the persistent batch.
495
+ # The request was either preempted and resumed later, or was not
496
+ # scheduled in the previous step and needs to be added again.
497
+ req_ids_to_add .append (req_id )
498
+ continue
499
+
500
+ # Update the persistent batch.
501
+ self .input_batch .num_computed_tokens_cpu [req_index ] = (
502
+ num_computed_tokens )
503
+
504
+ start_index = (len (req_state .block_ids ) -
505
+ len (req_data .new_block_ids ))
506
+ self .input_batch .block_table .append_row (
507
+ req_data .new_block_ids , req_index )
508
+ # Add new_token_ids to token_ids_cpu.
509
+ start_token_index = num_computed_tokens
510
+ end_token_index = num_computed_tokens + len (
511
+ req_data .new_token_ids )
512
+ self .input_batch .token_ids_cpu [
513
+ req_index ,
514
+ start_token_index :end_token_index ] = req_data .new_token_ids
515
+ self .input_batch .num_tokens_no_spec [
516
+ req_index ] = end_token_index
517
+ # Add spec_token_ids to token_ids_cpu.
518
+ spec_token_ids = scheduler_output .scheduled_spec_decode_tokens .get (
519
+ req_id , ())
520
+ if spec_token_ids :
521
+ start_index = end_token_index
522
+ end_token_index += len (spec_token_ids )
523
+ self .input_batch .token_ids_cpu [
524
+ req_index ,
525
+ start_index :end_token_index ] = spec_token_ids
526
+ # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
527
+ self .input_batch .num_tokens [req_index ] = end_token_index
528
+ else :
529
+ req_data = scheduler_output .scheduled_cached_reqs
530
+ for i , req_id in enumerate (req_data .req_ids ):
531
+ req_state = self .requests [req_id ]
532
+ num_computed_tokens = req_data .num_computed_tokens [i ]
533
+ new_token_ids = req_data .new_token_ids [i ]
534
+ new_block_ids = req_data .new_block_ids [i ]
535
+ resumed_from_preemption = req_data .resumed_from_preemption [i ]
536
+
537
+ req_state .num_computed_tokens = num_computed_tokens
538
+ # Add the sampled token(s) from the previous step (if any).
539
+ # This doesn't include "unverified" tokens like spec decode tokens.
540
+ num_new_tokens = (num_computed_tokens + len (new_token_ids ) -
541
+ req_state .num_tokens )
542
+ if num_new_tokens == 1 :
543
+ # Avoid slicing list in most common case.
544
+ req_state .output_token_ids .append (new_token_ids [- 1 ])
545
+ elif num_new_tokens > 0 :
546
+ req_state .output_token_ids .extend (
547
+ new_token_ids [- num_new_tokens :])
548
+ # Update the block IDs.
549
+ if not resumed_from_preemption :
550
+ # Append the new blocks to the existing block IDs.
551
+ for block_ids , new_ids in zip ( # type: ignore[call-overload]
552
+ req_state .block_ids , new_block_ids ):
553
+ block_ids .extend (new_ids )
554
+ else :
555
+ # The request is resumed from preemption.
556
+ # Replace the existing block IDs with the new ones.
557
+ req_state .block_ids = new_block_ids
558
+
559
+ req_index = self .input_batch .req_id_to_index .get (req_id )
560
+ if req_index is None :
561
+ # The request is not in the persistent batch.
562
+ # The request was either preempted and resumed later, or was not
563
+ # scheduled in the previous step and needs to be added again.
564
+ req_ids_to_add .append (req_id )
565
+ continue
566
+
567
+ # Update the persistent batch.
568
+ self .input_batch .num_computed_tokens_cpu [req_index ] = (
569
+ num_computed_tokens )
497
570
498
- # Update the persistent batch.
499
- self .input_batch .num_computed_tokens_cpu [req_index ] = (
500
- num_computed_tokens )
501
-
502
- start_index = (len (req_state .block_ids ) -
503
- len (req_data .new_block_ids ))
504
- self .input_batch .block_table .append_row (req_data .new_block_ids ,
505
- req_index )
506
- # Add new_token_ids to token_ids_cpu.
507
- start_token_index = num_computed_tokens
508
- end_token_index = num_computed_tokens + len (req_data .new_token_ids )
509
- self .input_batch .token_ids_cpu [
510
- req_index ,
511
- start_token_index :end_token_index ] = req_data .new_token_ids
512
- self .input_batch .num_tokens_no_spec [req_index ] = end_token_index
513
- # Add spec_token_ids to token_ids_cpu.
514
- spec_token_ids = scheduler_output .scheduled_spec_decode_tokens .get (
515
- req_id , ())
516
- if spec_token_ids :
517
- start_index = end_token_index
518
- end_token_index += len (spec_token_ids )
571
+ self .input_batch .block_table .append_row (
572
+ new_block_ids , req_index )
573
+ # Add new_token_ids to token_ids_cpu.
574
+ start_token_index = num_computed_tokens
575
+ end_token_index = num_computed_tokens + len (new_token_ids )
519
576
self .input_batch .token_ids_cpu [
520
- req_index , start_index :end_token_index ] = spec_token_ids
521
- # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
522
- self .input_batch .num_tokens [req_index ] = end_token_index
577
+ req_index ,
578
+ start_token_index :end_token_index ] = new_token_ids
579
+ self .input_batch .num_tokens_no_spec [
580
+ req_index ] = end_token_index
581
+ # Add spec_token_ids to token_ids_cpu.
582
+ spec_token_ids = scheduler_output .scheduled_spec_decode_tokens .get (
583
+ req_id , ())
584
+ if spec_token_ids :
585
+ start_index = end_token_index
586
+ end_token_index += len (spec_token_ids )
587
+ self .input_batch .token_ids_cpu [
588
+ req_index ,
589
+ start_index :end_token_index ] = spec_token_ids
590
+ # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
591
+ self .input_batch .num_tokens [req_index ] = end_token_index
523
592
524
593
# Check if the batch has changed. If not, we can skip copying the
525
594
# sampling metadata from CPU to GPU.
526
595
batch_changed = len (removed_req_indices ) > 0 or len (req_ids_to_add ) > 0
527
596
528
597
# Add the new or resumed requests to the persistent batch.
529
598
# The smaller empty indices are filled first.
530
- removed_req_indices = sorted ( removed_req_indices , reverse = True )
599
+ removed_req_indices . sort ( reverse = True )
531
600
for req_id in req_ids_to_add :
532
601
req_state = self .requests [req_id ]
533
602
if removed_req_indices :
0 commit comments