From d8e315969bf91fbe42e9283b81cf88c39a512e41 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Wed, 2 Jul 2025 19:51:56 +0800
Subject: [PATCH 01/11] feat: optimize forward metadata collection across dp ranks

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/model_runner_v1.py | 36 ++++++++++++++++-----------
 vllm_ascend/worker/worker_v1.py       | 25 ++++++++++++++-----
 2 files changed, 40 insertions(+), 21 deletions(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index e74ece3aea..0c123337ea 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -628,17 +628,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         if batch_changed:
             self.input_batch.refresh_sampling_metadata()
 
-    def _get_forward_metadata_across_dp(
-            self, total_num_scheduled_tokens: int,
-            with_prefill: bool) -> tuple[int, bool]:
-        forward_metadata = torch.tensor(
-            [total_num_scheduled_tokens, with_prefill],
-            device="cpu",
-            dtype=torch.int32)
-        dist.all_reduce(forward_metadata,
-                        op=ReduceOp.MAX,
-                        group=get_dp_group().cpu_group)
-        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+    def _get_forward_metadata_across_dp(self, num_tokens: int,
+                                        with_prefill: bool) -> tuple[int, bool]:
+        local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                              device="npu", dtype=torch.int32)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+        return num_tokens_across_dp, with_prefill
 
     def get_eagle_atten_dict(
         self,
@@ -1107,9 +1105,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
+        num_tokens_across_dp = None
         if self.dp_size > 1:
-            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
-                total_num_scheduled_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                self._get_forward_metadata_across_dp(num_input_tokens,
+                                                     with_prefill)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
             extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
             extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
 
@@ -1118,6 +1119,8 @@ def _process_reqs(
             if self.dp_size > 1:
                 padded_batch_size = self.select_torchair_padded_batch_size(
                     max_num_tokens)
+                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                                  padded_batch_size)
             else:
                 padded_batch_size = self.select_torchair_padded_batch_size(
                     total_num_scheduled_tokens)
@@ -1196,7 +1199,8 @@ def _process_reqs(
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             with ProfileExecuteDuration().capture_async("forward"):
                 model_kwargs = {}
                 if self.torchair_graph_enabled:
@@ -1819,6 +1823,7 @@ def _dummy_run(
         is_compile: bool = False,
         with_prefill: bool = True,
         skip_attn: bool = True,
+        num_tokens_across_dp: Optional[int] = None,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1873,7 +1878,8 @@ def _dummy_run(
 
             with set_forward_context(None,
                                      self.vllm_config,
-                                     num_tokens=num_tokens):
+                                     num_tokens=num_tokens,
+                                     num_tokens_across_dp=num_tokens_across_dp):
                 if self.torchair_graph_enabled and not with_prefill:
                     attn_metadata = self.attn_metadata_builder.build_dummy(
                         num_reqs=num_tokens, num_actual_tokens=1)
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 02094f5c58..5f63cd6024 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -277,16 +277,29 @@ def pin_lora(self, lora_id: int) -> bool:
 
     def execute_dummy_batch(self) -> None:
         runner = self.model_runner
-        max_num_tokens = 1
+
+        # If torchair graph is enabled, notify the other DP ranks that this is a
+        # dummy run by using '-1' as a flag for num_tokens. This will be
+        # replaced with the final determined graph size before the forward pass.
+        num_tokens = (-1 if runner.torchair_graph_enabled and not with_prefill
+                      else 1)
+        num_tokens_across_dp = None
         with_prefill = False
+
        if runner.dp_size > 1:
-            max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
-                max_num_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                runner._get_forward_metadata_across_dp(num_tokens, with_prefill)
+            num_tokens = int(num_tokens_across_dp.max().item())
+
         if runner.torchair_graph_enabled and not with_prefill:
-            max_num_tokens = runner.select_torchair_padded_batch_size(
-                max_num_tokens)
-        runner._dummy_run(max_num_tokens,
+            num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
+            if num_tokens_across_dp is not None:
+                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                                  num_tokens)
+
+        runner._dummy_run(num_tokens,
                           is_compile=False,
+                          num_tokens_across_dp=num_tokens_across_dp,
                           with_prefill=with_prefill)
 
     def _init_worker_distributed_environment(self) -> None:

From fe1f5c0741cdd192af78007b6bf9ad885ec54542 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Wed, 2 Jul 2025 20:08:48 +0800
Subject: [PATCH 02/11] fix: change num_tokens_across_dp type from int to torch.Tensor

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/model_runner_v1.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 0c123337ea..341fb5f720 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1823,7 +1823,7 @@ def _dummy_run(
         is_compile: bool = False,
         with_prefill: bool = True,
         skip_attn: bool = True,
-        num_tokens_across_dp: Optional[int] = None,
+        num_tokens_across_dp: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively

From 7f1da12cff819f4cae67c9e36511671813254cfd Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Wed, 2 Jul 2025 22:20:00 +0800
Subject: [PATCH 03/11] refactor: remove unused imports from model_runner_v1.py

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/model_runner_v1.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 341fb5f720..7588754de0 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -30,9 +30,7 @@
 import numpy.typing as npt
 import torch
 import torch._dynamo.cache_size
-import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed import ReduceOp
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig

From 1cffb5237facf53f04758c45f0bf52a46feada0f Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Wed, 2 Jul 2025 22:22:55 +0800
Subject: [PATCH 04/11] fix: correct handling the num_tokens for dummy run

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/worker_v1.py | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 5f63cd6024..f883e586b6 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -277,26 +277,25 @@ def pin_lora(self, lora_id: int) -> bool:
 
     def execute_dummy_batch(self) -> None:
         runner = self.model_runner
+        assert runner.dp_size > 1, "Dummy batch execution should only be " \
+            "performed with data parallelism enabled, but got " \
+            f"dp_size={runner.dp_size}."
 
         # If torchair graph is enabled, notify the other DP ranks that this is a
         # dummy run by using '-1' as a flag for num_tokens. This will be
         # replaced with the final determined graph size before the forward pass.
-        num_tokens = (-1 if runner.torchair_graph_enabled and not with_prefill
-                      else 1)
-        num_tokens_across_dp = None
-        with_prefill = False
-
-        if runner.dp_size > 1:
-            num_tokens_across_dp, with_prefill = \
-                runner._get_forward_metadata_across_dp(num_tokens, with_prefill)
-            num_tokens = int(num_tokens_across_dp.max().item())
+        num_tokens_across_dp, with_prefill = \
+            runner._get_forward_metadata_across_dp(-1, False)
 
         if runner.torchair_graph_enabled and not with_prefill:
-            num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
-            if num_tokens_across_dp is not None:
-                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
-                                                  num_tokens)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
+            num_tokens = runner.select_torchair_padded_batch_size(
+                max_num_tokens)
+        else:
+            num_tokens = 1
+        num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                          num_tokens)
 
         runner._dummy_run(num_tokens,
                           is_compile=False,
                           num_tokens_across_dp=num_tokens_across_dp,
                           with_prefill=with_prefill)

From b9c7f3022a3efcd84ef33886f177baf0f16cc8a7 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Wed, 2 Jul 2025 22:30:51 +0800
Subject: [PATCH 05/11] chore: lint

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/model_runner_v1.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 7588754de0..d3994af868 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -626,10 +626,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         if batch_changed:
             self.input_batch.refresh_sampling_metadata()
 
-    def _get_forward_metadata_across_dp(self, num_tokens: int,
-                                        with_prefill: bool) -> tuple[int, bool]:
+    def _get_forward_metadata_across_dp(
+            self, num_tokens: int,
+            with_prefill: bool) -> tuple[torch.Tensor, bool]:
         local_forward_metadata = torch.tensor([num_tokens, with_prefill],
-                                              device="npu", dtype=torch.int32)
+                                              device="npu",
+                                              dtype=torch.int32)
         global_forward_metadata = get_dp_group().all_gather(
             local_forward_metadata)
         num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
         with_prefill = bool(global_forward_metadata[:, 1].any())
@@ -1874,10 +1876,11 @@ def _dummy_run(
                     for k, v in self.intermediate_tensors.items()
                 })
 
-            with set_forward_context(None,
-                                     self.vllm_config,
-                                     num_tokens=num_tokens,
-                                     num_tokens_across_dp=num_tokens_across_dp):
+            with set_forward_context(
+                    None,
+                    self.vllm_config,
+                    num_tokens=num_tokens,
+                    num_tokens_across_dp=num_tokens_across_dp):
                 if self.torchair_graph_enabled and not with_prefill:
                     attn_metadata = self.attn_metadata_builder.build_dummy(
                         num_reqs=num_tokens, num_actual_tokens=1)

From e7064e0c2fab2c949738111d4a8da4c6a6db9b9e Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Wed, 2 Jul 2025 22:41:53 +0800
Subject: [PATCH 06/11] fix: improve handling of max_num_tokens

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/model_runner_v1.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index d3994af868..e151a84055 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1116,17 +1116,18 @@ def _process_reqs(
 
         # Add graph_pad_size here
         if self.torchair_graph_enabled and not with_prefill:
-            if self.dp_size > 1:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    max_num_tokens)
-                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
-                                                  padded_batch_size)
-            else:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    total_num_scheduled_tokens)
+            max_num_tokens = (max_num_tokens
+                              if self.dp_size > 1 else num_input_tokens)
+            padded_batch_size = self.select_torchair_padded_batch_size(
+                max_num_tokens)
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              padded_batch_size)
             graph_pad_size = padded_batch_size - total_num_scheduled_tokens
-
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
+        else:
+            # If torchair graph is not enabled, or if with_prefill is True, the
+            # dummy run batch size is set to 1.
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1)
 
         if self.vllm_config.model_config.use_mla:
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore

From d425f5be4dc4030ae680e849861ad12279f97849 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Wed, 2 Jul 2025 22:56:01 +0800
Subject: [PATCH 07/11] fix: update dummy run batch size handling

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/model_runner_v1.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index e151a84055..1e1356d645 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1120,14 +1120,19 @@ def _process_reqs(
                               if self.dp_size > 1 else num_input_tokens)
             padded_batch_size = self.select_torchair_padded_batch_size(
                 max_num_tokens)
-            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
-                                              padded_batch_size)
             graph_pad_size = padded_batch_size - total_num_scheduled_tokens
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
+            # If torchair graph is enabled and in decode mode, the dummy run
+            # batch size is set to the selected graph size.
+            dummy_num_tokens = padded_batch_size
         else:
             # If torchair graph is not enabled, or if with_prefill is True, the
             # dummy run batch size is set to 1.
-            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1)
+            dummy_num_tokens = 1
+
+        if self.dp_size > 1:
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              dummy_num_tokens)
 
         if self.vllm_config.model_config.use_mla:
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore

From d0b8fd382aa87982663796c8be9ce814f2fc3e06 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Wed, 2 Jul 2025 23:16:39 +0800
Subject: [PATCH 08/11] fix: add assertion for num_tokens_across_dp in NPUModelRunner

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/model_runner_v1.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 1e1356d645..6a7080c6ba 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1131,6 +1131,7 @@ def _process_reqs(
             dummy_num_tokens = 1
 
         if self.dp_size > 1:
+            assert num_tokens_across_dp is not None
             num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
                                               dummy_num_tokens)
 
         if self.vllm_config.model_config.use_mla:
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore

From 05d8b6a948a4a00bca14270ba4890f9bde1fdb21 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Wed, 2 Jul 2025 23:17:18 +0800
Subject: [PATCH 09/11] fix: change assertion to exception for dummy batch execution in NPUWorker

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/worker_v1.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index f883e586b6..649395c2cb 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -277,9 +277,10 @@ def pin_lora(self, lora_id: int) -> bool:
 
     def execute_dummy_batch(self) -> None:
         runner = self.model_runner
-        assert runner.dp_size > 1, "Dummy batch execution should only be " \
-            "performed with data parallelism enabled, but got " \
-            f"dp_size={runner.dp_size}."
+        if runner.dp_size <= 1:
+            raise ValueError("Dummy batch execution should only be "
+                             "performed with data parallelism enabled, but got "
+                             f"dp_size={runner.dp_size}.")
 
         # If torchair graph is enabled, notify the other DP ranks that this is a
         # dummy run by using '-1' as a flag for num_tokens. This will be

From e3a9cd5b911a2c86ba713092298a5725adcc797d Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Wed, 2 Jul 2025 23:39:18 +0800
Subject: [PATCH 10/11] chore: lint

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/worker_v1.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 649395c2cb..903873d86a 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -278,9 +278,10 @@ def pin_lora(self, lora_id: int) -> bool:
     def execute_dummy_batch(self) -> None:
         runner = self.model_runner
         if runner.dp_size <= 1:
-            raise ValueError("Dummy batch execution should only be "
-                             "performed with data parallelism enabled, but got "
-                             f"dp_size={runner.dp_size}.")
+            raise ValueError(
+                "Dummy batch execution should only be "
+                "performed with data parallelism enabled, but got "
+                f"dp_size={runner.dp_size}.")
 
         # If torchair graph is enabled, notify the other DP ranks that this is a
         # dummy run by using '-1' as a flag for num_tokens. This will be

From 77261469acbb7afb655f581b5a99c3d1f10a1c4f Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Tue, 8 Jul 2025 14:46:06 +0800
Subject: [PATCH 11/11] Update vllm_ascend/worker/model_runner_v1.py

Co-authored-by: Angazenn <92204292+Angazenn@users.noreply.github.com>
Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/model_runner_v1.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 6a7080c6ba..8ead896048 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -631,9 +631,9 @@ def _get_forward_metadata_across_dp(
             with_prefill: bool) -> tuple[torch.Tensor, bool]:
         local_forward_metadata = torch.tensor([num_tokens, with_prefill],
                                               device="npu",
-                                              dtype=torch.int32)
+                                              dtype=torch.int32).unsqueeze(0)
         global_forward_metadata = get_dp_group().all_gather(
-            local_forward_metadata)
+            local_forward_metadata, dim=0)
         num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
         with_prefill = bool(global_forward_metadata[:, 1].any())
         return num_tokens_across_dp, with_prefill
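
The series converges on a simple protocol: every DP rank all-gathers a small [num_tokens, with_prefill] pair, a dummy run advertises num_tokens = -1, and the -1 entries are later overwritten with the batch size that rank actually launches. The sketch below is a minimal illustration of that exchange only, not code from the patches: it is written against plain torch.distributed instead of vLLM's get_dp_group()/NPU helpers, and names such as dp_group and get_forward_metadata_across_dp are assumed for the example.

    from typing import Optional

    import torch
    import torch.distributed as dist


    def get_forward_metadata_across_dp(
            num_tokens: int, with_prefill: bool,
            dp_group: Optional[dist.ProcessGroup] = None
    ) -> tuple[torch.Tensor, bool]:
        """All-gather (num_tokens, with_prefill) from every DP rank.

        A dummy run advertises num_tokens = -1; the caller later replaces the
        -1 entries with the padded batch size it actually runs.
        """
        world_size = dist.get_world_size(group=dp_group)
        # Each rank contributes one row of shape [1, 2].
        local = torch.tensor([num_tokens, int(with_prefill)],
                             dtype=torch.int32).unsqueeze(0)
        gathered = [torch.empty_like(local) for _ in range(world_size)]
        dist.all_gather(gathered, local, group=dp_group)
        global_metadata = torch.cat(gathered, dim=0)  # shape: [dp_size, 2]
        num_tokens_across_dp = global_metadata[:, 0]
        with_prefill = bool(global_metadata[:, 1].any())
        return num_tokens_across_dp, with_prefill


    # Dummy-run side (mirrors execute_dummy_batch after PATCH 04):
    # num_tokens_across_dp, with_prefill = get_forward_metadata_across_dp(-1, False)
    # num_tokens = 1  # or a padded graph batch size when graph mode is active
    # num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, num_tokens)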