diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index e0ab79be45..db8e5f546f 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -386,8 +386,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

         # kv role
         self.is_kv_producer = False
+        self.is_kv_consumer = False
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -603,8 +605,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:

     def _get_forward_metadata_across_dp(
             self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
-        if self.dp_size == 1:
+        if self.dp_size == 1 or self.in_profile_run:
             return num_tokens, None, with_prefill, enable_dbo
+        if self.is_kv_producer:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
+        if self.is_kv_consumer:
+            max_num_decode_tokens = self.max_num_reqs * self.decode_token_per_req
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
+                                                self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
         forward_metadata = torch.tensor(
             [num_tokens, with_prefill, not enable_dbo],
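
For context, a minimal standalone sketch of the two fast paths this patch adds, with the collective fall-through path elided. The free-function form and parameter names are illustrative only; they mirror the instance attributes used in the diff (dp_size, in_profile_run, is_kv_producer, is_kv_consumer, max_num_reqs, decode_token_per_req) but are not part of the real API:

    from typing import Optional

    import torch

    def get_forward_metadata_across_dp(
        num_tokens: int,
        with_prefill: bool,
        enable_dbo: bool,
        dp_size: int,
        in_profile_run: bool,
        is_kv_producer: bool,
        is_kv_consumer: bool,
        max_num_reqs: int,
        decode_token_per_req: int,
    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
        # Single DP rank, or a profile run: no cross-rank coordination needed.
        if dp_size == 1 or in_profile_run:
            return num_tokens, None, with_prefill, enable_dbo
        # In a disaggregated-prefill setup, producer ranks only ever run
        # prefill, so each rank can fill in the per-rank token counts locally
        # and assume with_prefill=True, skipping the all-reduce.
        if is_kv_producer:
            num_tokens_across_dp = torch.tensor([num_tokens] * dp_size,
                                                device="cpu",
                                                dtype=torch.int32)
            return num_tokens, num_tokens_across_dp, True, enable_dbo
        # Consumer ranks only ever decode; padding every rank to the maximum
        # possible decode batch keeps the ranks in lockstep without a sync.
        if is_kv_consumer:
            max_num_decode_tokens = max_num_reqs * decode_token_per_req
            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] * dp_size,
                                                device="cpu",
                                                dtype=torch.int32)
            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
        # Mixed-role ranks would still take the collective path built from
        # forward_metadata (omitted in this sketch).
        raise NotImplementedError("collective path elided in this sketch")

The design point the early returns exploit is that a rank's KV-transfer role fixes its workload shape in advance: producers know with_prefill is True and consumers know it is False, so the torch.tensor holding num_tokens_across_dp can be constructed locally instead of being gathered across the DP group.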