
Commit f55250a

remove cpu all_reduce in disaggregated-prefill scenario
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent e1d282d commit f55250a

File tree

1 file changed: +11 -0 lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 0 deletions
@@ -386,8 +386,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
 
         # kv role
         self.is_kv_producer = False
+        self.is_kv_consumer = False
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -605,6 +607,15 @@ def _get_forward_metadata_across_dp(
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
             return num_tokens, None, with_prefill, enable_dbo
+        if self.is_kv_producer:
+            return num_tokens, None, True, enable_dbo
+        if self.is_kv_consumer:
+            max_num_decode_tokens = self.max_num_reqs * self.decode_token_per_req
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
+                                                self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
 
         forward_metadata = torch.tensor(
             [num_tokens, with_prefill, not enable_dbo],
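For readers skimming the diff, below is a minimal, self-contained sketch of the decision logic the second hunk adds to _get_forward_metadata_across_dp: in a disaggregated-prefill deployment the KV producer (prefill instance) and KV consumer (decode instance) can answer the per-step metadata question locally, so the CPU all_reduce that the rest of the method performs (per the commit title) is skipped. The branch conditions and return values mirror the diff; the standalone function signature, the demo call, and the NotImplementedError stub standing in for the all_reduce fallback are illustrative assumptions, not the actual vllm-ascend code.

from typing import Optional

import torch


def forward_metadata_across_dp(
        num_tokens: int,
        with_prefill: bool,
        enable_dbo: bool,
        dp_size: int,
        is_kv_producer: bool,
        is_kv_consumer: bool,
        max_num_reqs: int,
        decode_token_per_req: int,
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
    # Single DP rank: nothing to synchronize, return local metadata.
    if dp_size == 1:
        return num_tokens, None, with_prefill, enable_dbo
    # KV producer (prefill instance): every rank runs prefill, so the
    # result is known locally and no cross-rank exchange is needed.
    if is_kv_producer:
        return num_tokens, None, True, enable_dbo
    # KV consumer (decode instance): pad every rank to the worst-case
    # decode batch, again avoiding the collective.
    if is_kv_consumer:
        max_num_decode_tokens = max_num_reqs * decode_token_per_req
        num_tokens_across_dp = torch.tensor([max_num_decode_tokens] * dp_size,
                                            device="cpu",
                                            dtype=torch.int32)
        return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
    # Mixed (non-disaggregated) deployment: the real method falls through to
    # a CPU all_reduce over a small metadata tensor; omitted in this sketch.
    raise NotImplementedError("all_reduce fallback not modelled here")


# Example: a decode-only (consumer) rank in a dp_size=2 deployment with at
# most 4 requests and 1 decode token per request.
print(forward_metadata_across_dp(3, False, False, 2, False, True, 4, 1))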
