
[0.9.1][PD][Perf] Avoid performing cpu all_reduce in disaggregated-prefill scenario. #1644

Status: Open. Wants to merge 1 commit into base branch v0.9.1-dev.
vllm_ascend/worker/model_runner_v1.py (16 changes: 15 additions & 1 deletion)
@@ -386,8 +386,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

         # kv role
         self.is_kv_producer = False
+        self.is_kv_consumer = False
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -603,8 +605,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
     def _get_forward_metadata_across_dp(
             self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
-        if self.dp_size == 1:
+        if self.dp_size == 1 or self.in_profile_run:
             return num_tokens, None, with_prefill, enable_dbo
+        if self.is_kv_producer:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
+        if self.is_kv_consumer:
+            max_num_decode_tokens = self.max_num_reqs * self.decode_token_per_req
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
+                                                self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo

         forward_metadata = torch.tensor(
             [num_tokens, with_prefill, not enable_dbo],
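
Why this avoids the CPU all_reduce (reader's note, hedged): with dp_size > 1, each DP rank needs the token counts and the prefill/DBO flags of every other rank, which the fallback path at the end of the hunk obtains by reducing a small metadata tensor across ranks on the CPU every step. In a disaggregated-prefill deployment, a rank's KV role already pins that metadata down: producer (prefill) ranks report their own count for all ranks, and consumer (decode) ranks pad everyone to the worst-case decode batch. The sketch below restates that fast-path logic as a self-contained function; the function name and keyword parameters are illustrative stand-ins for the method and instance attributes in the diff, and the mixed-role collective path is stubbed out.

from typing import Optional

import torch


def get_forward_metadata_across_dp(
    num_tokens: int,
    with_prefill: bool,
    enable_dbo: bool,
    *,
    dp_size: int,
    is_kv_producer: bool = False,
    is_kv_consumer: bool = False,
    max_num_reqs: int = 0,
    decode_token_per_req: int = 1,
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
    if dp_size == 1:
        # Single DP rank: nothing to synchronize.
        return num_tokens, None, with_prefill, enable_dbo

    if is_kv_producer:
        # Prefill-only rank: treat every DP rank as running prefill with this
        # rank's token count, so the cross-rank tensor is built locally.
        num_tokens_across_dp = torch.full((dp_size,), num_tokens,
                                          dtype=torch.int32, device="cpu")
        return num_tokens, num_tokens_across_dp, True, enable_dbo

    if is_kv_consumer:
        # Decode-only rank: pad every rank to the worst-case decode batch so
        # no rank has to ask the others for their actual counts.
        max_num_decode_tokens = max_num_reqs * decode_token_per_req
        num_tokens_across_dp = torch.full((dp_size,), max_num_decode_tokens,
                                          dtype=torch.int32, device="cpu")
        return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo

    # Mixed (non-disaggregated) deployment: the real code falls through to the
    # CPU all_reduce of [num_tokens, with_prefill, not enable_dbo]; elided here.
    raise NotImplementedError("collective fallback elided in this sketch")


# Decode-side example: 4 DP ranks, up to 16 requests, one token per request.
tokens, across_dp, prefill, dbo = get_forward_metadata_across_dp(
    5, False, False,
    dp_size=4, is_kv_consumer=True, max_num_reqs=16, decode_token_per_req=1)
print(tokens, across_dp.tolist(), prefill)  # 16 [16, 16, 16, 16] False

Note the trade-off in the consumer branch: every decode step is sized as if max_num_reqs * decode_token_per_req tokens were present, so some compute may go to padding, but the blocking per-step CPU collective disappears from the decode loop; that appears to be the perf win the PR title refers to.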