From 47443e7b13a32878cf446bfe977872f7ff4e6248 Mon Sep 17 00:00:00 2001
From: ningbenzhe1
Date: Tue, 3 Jun 2025 09:33:22 +0800
Subject: [PATCH] fix some bugs

Signed-off-by: ningbenzhe1
---
 tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py | 2 +-
 vllm_ascend/attention/attention_v1.py                      | 3 +++
 vllm_ascend/distributed/parallel_state.py                  | 6 ++++++
 vllm_ascend/ops/fused_moe.py                               | 3 +--
 .../patch/platform/patch_common/patch_distributed.py       | 4 ++--
 vllm_ascend/worker/mtp_proposer_v1.py                      | 2 +-
 vllm_ascend/worker/worker_v1.py                            | 2 +-
 7 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py b/tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
index 46b5d66cea..2219a6f552 100644
--- a/tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
+++ b/tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
@@ -89,4 +89,4 @@ def test_mtp_correctness(
     # Heuristic: expect at least 66% of the prompts to match exactly
     # Upon failure, inspect the outputs to check for inaccuracy.
     assert matches > int(0.66 * len(ref_outputs))
-    del spec_llm
\ No newline at end of file
+    del spec_llm
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 675318ee5c..b00573a94e 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -129,6 +129,9 @@ class AscendMetadata:
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
     attn_mask: Optional[torch.Tensor] = None
 
+    # For logging.
+    num_input_tokens: int = 0  # Number of tokens including padding.
+
 
 class AscendAttentionMetadataBuilder:
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 016dd6c1eb..2778a6ef27 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -21,12 +21,18 @@ def get_etp_group() -> GroupCoordinator:
     return _ETP
 
 
+def model_parallel_initialized():
+    return (_ETP is not None and _EP is not None)
+
+
 def init_ascend_model_parallel(
     expert_parallel_size: int = 1,
     expert_tensor_parallel_size: int = 1,
     world_size: Optional[int] = None,
     backend: Optional[str] = None,
 ):
+    if model_parallel_initialized():
+        return
     assert torch.distributed.is_initialized()
     world_size = world_size or torch.distributed.get_world_size()
     backend = backend or torch.distributed.get_backend(
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 74a292d576..2f5cea06ec 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -66,8 +66,7 @@ def fused_experts_with_mc2(
     local_rank = torch.distributed.get_rank(group=ep_group)
     all_to_all_group_size = torch.distributed.get_world_size(ep_group)
 
-    world_szie = torch.distributed.get_world_size()
-    tp_size = world_szie // all_to_all_group_size
+    tp_size = get_etp_group().world_size
     tp_rank = rank % tp_size
 
     stage1_kwargs = {
diff --git a/vllm_ascend/patch/platform/patch_common/patch_distributed.py b/vllm_ascend/patch/platform/patch_common/patch_distributed.py
index ac46ab0bf2..0b88264b4d 100644
--- a/vllm_ascend/patch/platform/patch_common/patch_distributed.py
+++ b/vllm_ascend/patch/platform/patch_common/patch_distributed.py
@@ -20,6 +20,7 @@
 import torch
 import vllm
 import vllm.distributed
+import vllm.envs as envs
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                  _get_default_timeout,
@@ -164,10 +165,9 @@ def parallel_config_get_dp_port(self) -> int:
     """
     answer = self.data_parallel_master_port
     self.data_parallel_master_port += 1
-    import os
 
     # NOTE: Get port from envs directly when using torchrun
-    port = int(os.environ.get("MASTER_PORT", answer))  # type: ignore
+    port = envs.VLLM_DP_MASTER_PORT if envs.VLLM_DP_MASTER_PORT else answer
     return port
 
 
diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py
index 3a270597e7..8782df181f 100644
--- a/vllm_ascend/worker/mtp_proposer_v1.py
+++ b/vllm_ascend/worker/mtp_proposer_v1.py
@@ -219,4 +219,4 @@ def prepare_input_kernel(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor,
 
     global_indices_flat = global_indices[mask]
     values_flat = values[mask]
-    out_ptr[global_indices_flat] = values_flat
\ No newline at end of file
+    out_ptr[global_indices_flat] = values_flat
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 69476e256f..21c9955b4e 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -173,7 +173,7 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         output = self.model_runner.execute_model(scheduler_output)
-        return output if self.rank == 0 else None
+        return output if self.is_driver_worker else None
 
     def load_model(self) -> None:
         self.model_runner.load_model()
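For reference, below is a minimal, dependency-free sketch of the re-entrancy guard pattern that the parallel_state.py hunk introduces: module-level singletons plus an "already initialized" check so that a second call becomes a no-op instead of re-creating the groups. The placeholder strings and the trimmed signature are illustrative assumptions only; the real init_ascend_model_parallel builds GroupCoordinator groups on top of torch.distributed.

# Standalone sketch of the guard pattern added in
# vllm_ascend/distributed/parallel_state.py. It mirrors the structure of the
# patch (module-level _EP/_ETP plus model_parallel_initialized()) but uses
# plain strings instead of GroupCoordinator objects so it runs anywhere.
from typing import Optional

_EP: Optional[str] = None   # stand-in for the expert-parallel group
_ETP: Optional[str] = None  # stand-in for the expert-tensor-parallel group


def model_parallel_initialized() -> bool:
    # Same condition as the patch: both groups must already exist.
    return _ETP is not None and _EP is not None


def init_ascend_model_parallel(expert_parallel_size: int = 1,
                               expert_tensor_parallel_size: int = 1) -> None:
    global _EP, _ETP
    if model_parallel_initialized():
        # Second and later calls are no-ops rather than re-initializing.
        return
    _EP = f"EP(size={expert_parallel_size})"
    _ETP = f"ETP(size={expert_tensor_parallel_size})"


if __name__ == "__main__":
    init_ascend_model_parallel(expert_parallel_size=4)
    first = (_EP, _ETP)
    init_ascend_model_parallel(expert_parallel_size=8)  # ignored: already set up
    assert (_EP, _ETP) == first
    print("repeated init is a no-op:", first)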