@@ -527,24 +527,27 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                self.input_batch.num_tokens[req_index] = end_token_index
        else:
            req_data = scheduler_output.scheduled_cached_reqs
+           is_last_rank = get_pp_group().is_last_rank
            for i, req_id in enumerate(req_data.req_ids):
                req_state = self.requests[req_id]
                num_computed_tokens = req_data.num_computed_tokens[i]
-               new_token_ids = req_data.new_token_ids[i]
                new_block_ids = req_data.new_block_ids[i]
                resumed_from_preemption = req_data.resumed_from_preemption[i]

                req_state.num_computed_tokens = num_computed_tokens
-               # Add the sampled token(s) from the previous step (if any).
-               # This doesn't include "unverified" tokens like spec decode tokens.
-               num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                                 req_state.num_tokens)
-               if num_new_tokens == 1:
-                   # Avoid slicing list in most common case.
-                   req_state.output_token_ids.append(new_token_ids[-1])
-               elif num_new_tokens > 0:
-                   req_state.output_token_ids.extend(
-                       new_token_ids[-num_new_tokens:])
+               if not is_last_rank:
+                   new_token_ids = req_data.new_token_ids[i]
+                   # Add the sampled token(s) from the previous step (if any).
+                   # This doesn't include "unverified" tokens like spec decode tokens.
+                   num_new_tokens = (num_computed_tokens +
+                                     len(new_token_ids) -
+                                     req_state.num_tokens)
+                   if num_new_tokens == 1:
+                       # Avoid slicing list in most common case.
+                       req_state.output_token_ids.append(new_token_ids[-1])
+                   elif num_new_tokens > 0:
+                       req_state.output_token_ids.extend(
+                           new_token_ids[-num_new_tokens:])

                # Update the block IDs.
                if not resumed_from_preemption:
                    # Append the new blocks to the existing block IDs.
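Context for this hunk (not part of the diff): on ranks other than the last PP rank, the scheduler still echoes the previously sampled tokens in new_token_ids, and the num_new_tokens arithmetic picks out only the tail that this runner has not recorded yet. A minimal standalone sketch of that bookkeeping, using made-up values rather than anything from the PR:

# Standalone illustration of the num_new_tokens arithmetic above (toy values).
num_computed_tokens = 10          # tokens the scheduler reports as computed
new_token_ids = [42, 43, 44]      # token ids echoed back by the scheduler
req_num_tokens = 12               # tokens this runner already tracks (req_state.num_tokens)

# Only the part of new_token_ids extending past req_num_tokens is new output.
num_new_tokens = num_computed_tokens + len(new_token_ids) - req_num_tokens  # -> 1
output_token_ids: list[int] = []
if num_new_tokens == 1:
    # Most common case: exactly one freshly sampled token.
    output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
    output_token_ids.extend(new_token_ids[-num_new_tokens:])
assert output_token_ids == [44]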
@@ -570,25 +573,27 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                self.input_batch.block_table.append_row(
                    new_block_ids, req_index)
-               # Add new_token_ids to token_ids_cpu.
-               start_token_index = num_computed_tokens
-               end_token_index = num_computed_tokens + len(new_token_ids)
-               self.input_batch.token_ids_cpu[
-                   req_index,
-                   start_token_index:end_token_index] = new_token_ids
-               self.input_batch.num_tokens_no_spec[
-                   req_index] = end_token_index
-               # Add spec_token_ids to token_ids_cpu.
-               spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                   req_id, ())
-               if spec_token_ids:
-                   start_index = end_token_index
-                   end_token_index += len(spec_token_ids)
-                   self.input_batch.token_ids_cpu[
-                       req_index,
-                       start_index:end_token_index] = spec_token_ids
-               # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
-               self.input_batch.num_tokens[req_index] = end_token_index
+
+               if not is_last_rank:
+                   # Add new_token_ids to token_ids_cpu.
+                   start_token_index = num_computed_tokens
+                   end_token_index = num_computed_tokens + len(new_token_ids)
+                   self.input_batch.token_ids_cpu[
+                       req_index,
+                       start_token_index:end_token_index] = new_token_ids
+                   self.input_batch.num_tokens_no_spec[
+                       req_index] = end_token_index
+                   # Add spec_token_ids to token_ids_cpu.
+                   spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                       req_id, ())
+                   if spec_token_ids:
+                       start_index = end_token_index
+                       end_token_index += len(spec_token_ids)
+                       self.input_batch.token_ids_cpu[
+                           req_index,
+                           start_index:end_token_index] = spec_token_ids
+                   # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
+                   self.input_batch.num_tokens[req_index] = end_token_index

        # Check if the batch has changed. If not, we can skip copying the
        # sampling metadata from CPU to GPU.
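For context (outside the diff): on ranks that still take tokens from the scheduler, the per-request row of token_ids_cpu is filled in two segments, verified new_token_ids first and then any scheduled draft tokens, while num_tokens_no_spec deliberately stops before the speculative segment. A rough sketch with toy buffers, assuming numpy arrays in place of the real InputBatch fields:

import numpy as np

# Toy single-request buffers standing in for self.input_batch (illustrative only).
token_ids_cpu = np.zeros((1, 32), dtype=np.int64)
req_index = 0
num_computed_tokens = 8
new_token_ids = [101, 102]    # verified tokens sent by the scheduler
spec_token_ids = [7, 8, 9]    # unverified draft tokens for the next step

start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids)
token_ids_cpu[req_index, start_token_index:end_token_index] = new_token_ids
num_tokens_no_spec = end_token_index          # 10: excludes draft tokens

if spec_token_ids:
    start_index = end_token_index
    end_token_index += len(spec_token_ids)
    token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
num_tokens = end_token_index                  # 13: may include draft tokens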
@@ -1641,6 +1646,30 @@ def execute_model(
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()
+       if not vllm_version_is("0.9.1"):
+           # Cache the sampled tokens in the model runner, so that the scheduler
+           # doesn't need to send them back.
+           # NOTE(woosuk): As an exception, when using PP, the scheduler sends
+           # the sampled tokens back, because there's no direct communication
+           # between the first-stage worker and the last-stage worker.
+           for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+               if not sampled_ids:
+                   continue
+
+               start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+               end_idx = start_idx + len(sampled_ids)
+               assert end_idx <= self.model_config.max_model_len, (
+                   "Sampled token IDs exceed the max model length. "
+                   f"Total number of tokens: {end_idx} > max_model_len: "
+                   f"{self.model_config.max_model_len}")
+
+               self.input_batch.token_ids_cpu[
+                   req_idx, start_idx:end_idx] = sampled_ids
+               self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+               self.input_batch.num_tokens[req_idx] = end_idx
+               req_id = self.input_batch.req_ids[req_idx]
+               req_state = self.requests[req_id]
+               req_state.output_token_ids.extend(sampled_ids)

        spec_token_ids = self._get_spec_token_ids(
            valid_sampled_token_ids,
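The added block above is the other half of the change: on the last PP rank the runner writes its own sampled IDs into the same CPU buffers that _update_states fills on the earlier ranks, so the next step starts from consistent state without the scheduler echoing tokens back. A minimal sketch of that update for a single request, using toy stand-ins for the InputBatch fields rather than the real objects:

import numpy as np

# Toy stand-ins for self.input_batch / self.requests (illustrative only).
token_ids_cpu = np.zeros((1, 64), dtype=np.int64)
num_tokens_no_spec = [10]          # verified tokens currently in the row
num_tokens = [13]                  # may still include last step's draft tokens
output_token_ids = [[]]
max_model_len = 64

valid_sampled_token_ids = [[5, 6]]  # e.g. two draft tokens accepted this step
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
    if not sampled_ids:
        continue
    start_idx = num_tokens_no_spec[req_idx]
    end_idx = start_idx + len(sampled_ids)
    assert end_idx <= max_model_len, "Sampled token IDs exceed the max model length."
    token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
    num_tokens_no_spec[req_idx] = end_idx
    num_tokens[req_idx] = end_idx   # stale draft slots are simply overwritten/dropped
    output_token_ids[req_idx].extend(sampled_ids)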