diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index e74ece3aea..8ead896048 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -30,9 +30,7 @@
 import numpy.typing as npt
 import torch
 import torch._dynamo.cache_size
-import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed import ReduceOp
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
@@ -629,16 +627,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()
 
     def _get_forward_metadata_across_dp(
-            self, total_num_scheduled_tokens: int,
-            with_prefill: bool) -> tuple[int, bool]:
-        forward_metadata = torch.tensor(
-            [total_num_scheduled_tokens, with_prefill],
-            device="cpu",
-            dtype=torch.int32)
-        dist.all_reduce(forward_metadata,
-                        op=ReduceOp.MAX,
-                        group=get_dp_group().cpu_group)
-        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+            self, num_tokens: int,
+            with_prefill: bool) -> tuple[torch.Tensor, bool]:
+        local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                              device="npu",
+                                              dtype=torch.int32).unsqueeze(0)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata, dim=0)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+        return num_tokens_across_dp, with_prefill
 
     def get_eagle_atten_dict(
         self,
@@ -1107,23 +1105,35 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
+        num_tokens_across_dp = None
         if self.dp_size > 1:
-            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
-                total_num_scheduled_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                self._get_forward_metadata_across_dp(num_input_tokens,
+                                                     with_prefill)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
             extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
             extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
 
         # Add graph_pad_size here
         if self.torchair_graph_enabled and not with_prefill:
-            if self.dp_size > 1:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    max_num_tokens)
-            else:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    total_num_scheduled_tokens)
+            max_num_tokens = (max_num_tokens
+                              if self.dp_size > 1 else num_input_tokens)
+            padded_batch_size = self.select_torchair_padded_batch_size(
+                max_num_tokens)
             graph_pad_size = padded_batch_size - total_num_scheduled_tokens
-
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
+            # If torchair graph is enabled and in decode mode, the dummy run
+            # batch size is set to the selected graph size.
+            dummy_num_tokens = padded_batch_size
+        else:
+            # If torchair graph is not enabled, or if with_prefill is True, the
+            # dummy run batch size is set to 1.
+            dummy_num_tokens = 1
+
+        if self.dp_size > 1:
+            assert num_tokens_across_dp is not None
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              dummy_num_tokens)
 
         if self.vllm_config.model_config.use_mla:
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore
@@ -1196,7 +1206,8 @@ def _process_reqs(
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             with ProfileExecuteDuration().capture_async("forward"):
                 model_kwargs = {}
                 if self.torchair_graph_enabled:
@@ -1819,6 +1830,7 @@ def _dummy_run(
         is_compile: bool = False,
         with_prefill: bool = True,
         skip_attn: bool = True,
+        num_tokens_across_dp: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1871,9 +1883,11 @@ def _dummy_run(
                     for k, v in self.intermediate_tensors.items()
                 })
 
-            with set_forward_context(None,
-                                     self.vllm_config,
-                                     num_tokens=num_tokens):
+            with set_forward_context(
+                    None,
+                    self.vllm_config,
+                    num_tokens=num_tokens,
+                    num_tokens_across_dp=num_tokens_across_dp):
                 if self.torchair_graph_enabled and not with_prefill:
                     attn_metadata = self.attn_metadata_builder.build_dummy(
                         num_reqs=num_tokens, num_actual_tokens=1)
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 02094f5c58..903873d86a 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -277,16 +277,30 @@ def pin_lora(self, lora_id: int) -> bool:
 
     def execute_dummy_batch(self) -> None:
         runner = self.model_runner
-        max_num_tokens = 1
-        with_prefill = False
-        if runner.dp_size > 1:
-            max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
-                max_num_tokens, with_prefill)
+        if runner.dp_size <= 1:
+            raise ValueError(
+                "Dummy batch execution should only be "
+                "performed with data parallelism enabled, but got "
+                f"dp_size={runner.dp_size}.")
+
+        # If torchair graph is enabled, notify the other DP ranks that this is a
+        # dummy run by using '-1' as a flag for num_tokens. This will be
+        # replaced with the final determined graph size before the forward pass.
+        num_tokens_across_dp, with_prefill = \
+            runner._get_forward_metadata_across_dp(-1, False)
+
         if runner.torchair_graph_enabled and not with_prefill:
-            max_num_tokens = runner.select_torchair_padded_batch_size(
+            max_num_tokens = int(num_tokens_across_dp.max().item())
+            num_tokens = runner.select_torchair_padded_batch_size(
                 max_num_tokens)
-        runner._dummy_run(max_num_tokens,
+        else:
+            num_tokens = 1
+
+        num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                          num_tokens)
+        runner._dummy_run(num_tokens,
                           is_compile=False,
+                          num_tokens_across_dp=num_tokens_across_dp,
                           with_prefill=with_prefill)
 
     def _init_worker_distributed_environment(self) -> None:
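
Note: the sketch below illustrates the coordination pattern this patch introduces, outside the vLLM codebase. Each DP rank all-gathers its (num_tokens, with_prefill) pair; a rank that only has a dummy batch advertises -1, and that sentinel is later overwritten with the locally chosen batch size via masked_fill_ before the forward pass. The standalone helper names, the gloo process group, the port, and the single-process setup are illustrative assumptions; the patch itself goes through get_dp_group() with NPU tensors.

# Illustrative sketch (not part of the patch): the all-gather + "-1 sentinel"
# handshake, using a plain torch.distributed gloo group instead of the
# patch's get_dp_group() helper and NPU tensors.
import torch
import torch.distributed as dist


def get_forward_metadata_across_dp(
        num_tokens: int, with_prefill: bool) -> tuple[torch.Tensor, bool]:
    # Pack (num_tokens, with_prefill) and gather one row per DP rank.
    local = torch.tensor([[num_tokens, int(with_prefill)]], dtype=torch.int32)
    gathered = [torch.empty_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    meta = torch.cat(gathered, dim=0)
    return meta[:, 0].clone(), bool(meta[:, 1].any())


def fill_dummy_tokens(num_tokens_across_dp: torch.Tensor,
                      dummy_num_tokens: int) -> torch.Tensor:
    # Ranks that advertised -1 are running a dummy batch; substitute the
    # locally selected (padded) batch size before the forward pass.
    return num_tokens_across_dp.masked_fill(num_tokens_across_dp == -1,
                                            dummy_num_tokens)


if __name__ == "__main__":
    # Single-process gloo group so the sketch runs standalone.
    dist.init_process_group("gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0,
                            world_size=1)
    tokens_across_dp, with_prefill = get_forward_metadata_across_dp(-1, False)
    print(fill_dummy_tokens(tokens_across_dp, dummy_num_tokens=8), with_prefill)
    dist.destroy_process_group()

Gathering the per-rank counts instead of all-reducing a single max is what lets each rank both detect idle peers (the -1 flag) and pass the full num_tokens_across_dp vector into set_forward_context.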