@@ -386,8 +386,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
 
         # kv role
         self.is_kv_producer = False
+        self.is_kv_consumer = False
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -605,6 +607,15 @@ def _get_forward_metadata_across_dp(
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
             return num_tokens, None, with_prefill, enable_dbo
+        if self.is_kv_producer:
+            return num_tokens, None, True, enable_dbo
+        if self.is_kv_consumer:
+            max_num_decode_tokens = self.max_num_reqs * self.decode_token_per_req
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
+                                                self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
 
         forward_metadata = torch.tensor(
             [num_tokens, with_prefill, not enable_dbo],
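
The two added early returns skip the cross-DP synchronization when a worker has a dedicated KV-transfer role in disaggregated prefill: a KV producer only serves prefill, so it reports its local token count and forces with_prefill=True, while a KV consumer only serves decode, so every DP rank can assume the same worst-case decode budget (max_num_reqs * decode_token_per_req) and no all-reduce is needed. Below is a minimal standalone sketch of that decision logic; the attribute names mirror the diff, but the free function, its signature, and the fallback are illustrative assumptions, not the actual model-runner code.

from typing import Optional

import torch


def forward_metadata_across_dp_sketch(
        num_tokens: int,
        with_prefill: bool,
        enable_dbo: bool,
        dp_size: int,
        is_kv_producer: bool,
        is_kv_consumer: bool,
        max_num_reqs: int,
        decode_token_per_req: int,
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
    # Single DP rank: nothing to synchronize.
    if dp_size == 1:
        return num_tokens, None, with_prefill, enable_dbo
    # KV producer (prefill-only worker): report local tokens, force prefill.
    if is_kv_producer:
        return num_tokens, None, True, enable_dbo
    # KV consumer (decode-only worker): pad every rank to the same
    # worst-case decode token count so the all-reduce can be skipped.
    if is_kv_consumer:
        max_num_decode_tokens = max_num_reqs * decode_token_per_req
        num_tokens_across_dp = torch.tensor([max_num_decode_tokens] * dp_size,
                                            device="cpu",
                                            dtype=torch.int32)
        return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
    # Mixed workers would fall through to the original cross-DP
    # all-reduce path, which is omitted from this sketch.
    raise NotImplementedError("cross-DP all-reduce path not sketched here")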