
Commit 6ec64a3

[bugfix] fix some bugs that may cause failures to run (#896)
### What this PR does / why we need it?
Fixes the bug where graph mode was handled the same for P (prefill) and D (decode), along with some other bugs.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested by running the end-to-end tests.

Signed-off-by: ningbenzhe1 <ningbenzhe@huawei.com>
1 parent 92bc557 commit 6ec64a3

File tree (7 files changed: +15 -7 lines changed)

- tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
- vllm_ascend/attention/attention_v1.py
- vllm_ascend/distributed/parallel_state.py
- vllm_ascend/ops/fused_moe.py
- vllm_ascend/patch/platform/patch_common/patch_distributed.py
- vllm_ascend/worker/mtp_proposer_v1.py
- vllm_ascend/worker/worker_v1.py

tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py

Lines changed: 1 addition & 1 deletion
@@ -89,4 +89,4 @@ def test_mtp_correctness(
     # Heuristic: expect at least 66% of the prompts to match exactly
     # Upon failure, inspect the outputs to check for inaccuracy.
     assert matches > int(0.66 * len(ref_outputs))
-    del spec_llm
+    del spec_llm

vllm_ascend/attention/attention_v1.py

Lines changed: 3 additions & 0 deletions
@@ -129,6 +129,9 @@ class AscendMetadata:
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
     attn_mask: Optional[torch.Tensor] = None
 
+    # For logging.
+    num_input_tokens: int = 0  # Number of tokens including padding.
+
 
 class AscendAttentionMetadataBuilder:
 
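
The new field only records the padded token count for logging; nothing else in the metadata changes. Below is a minimal, illustrative sketch of how such a field sits in a dataclass; every name other than `num_input_tokens` is a hypothetical stand-in, not the real `AscendMetadata`:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyAttentionMetadata:
    # Stand-in for the real metadata fields.
    attn_mask: Optional[torch.Tensor] = None
    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.


meta = ToyAttentionMetadata()
meta.num_input_tokens = 128  # e.g. 100 real tokens padded up to a 128-token batch
```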

vllm_ascend/distributed/parallel_state.py

Lines changed: 6 additions & 0 deletions
@@ -21,12 +21,18 @@ def get_etp_group() -> GroupCoordinator:
     return _ETP
 
 
+def model_parallel_initialized():
+    return (_ETP is not None and _EP is not None)
+
+
 def init_ascend_model_parallel(
     expert_parallel_size: int = 1,
     expert_tensor_parallel_size: int = 1,
     world_size: Optional[int] = None,
     backend: Optional[str] = None,
 ):
+    if model_parallel_initialized():
+        return
     assert torch.distributed.is_initialized()
     world_size = world_size or torch.distributed.get_world_size()
     backend = backend or torch.distributed.get_backend(
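
The guard makes `init_ascend_model_parallel` idempotent: once the EP and ETP coordinators exist, a repeated call returns immediately instead of re-creating the groups. A minimal sketch of the pattern, with the real group construction replaced by placeholders (illustrative only):

```python
from typing import Optional

_EP: Optional[object] = None
_ETP: Optional[object] = None


def model_parallel_initialized() -> bool:
    return _ETP is not None and _EP is not None


def init_ascend_model_parallel() -> None:
    global _EP, _ETP
    if model_parallel_initialized():
        return  # already set up; repeated calls are harmless no-ops
    _EP = object()   # placeholder for the real expert-parallel coordinator
    _ETP = object()  # placeholder for the real expert-tensor-parallel coordinator


init_ascend_model_parallel()
init_ascend_model_parallel()  # second call returns early
```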

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 2 deletions
@@ -66,8 +66,7 @@ def fused_experts_with_mc2(
     local_rank = torch.distributed.get_rank(group=ep_group)
     all_to_all_group_size = torch.distributed.get_world_size(ep_group)
 
-    world_szie = torch.distributed.get_world_size()
-    tp_size = world_szie // all_to_all_group_size
+    tp_size = get_etp_group().world_size
     tp_rank = rank % tp_size
 
     stage1_kwargs = {
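
Besides dropping the `world_szie` typo, the change stops deriving the tensor-parallel size by dividing the global world size by the EP group size, which over-counts whenever ranks belong to parallel dimensions other than EP and ETP; it now reads the expert-tensor-parallel group's size directly. The sizes below are hypothetical and only illustrate the arithmetic:

```python
# Hypothetical layout: 16 ranks split into EP=4, ETP=2, plus a leftover
# dimension of 2 (e.g. data parallelism).
world_size = 16
ep_size = 4
etp_size = 2
other_size = 2
assert ep_size * etp_size * other_size == world_size

old_tp_size = world_size // ep_size  # 4 -- folds the extra dimension into TP
new_tp_size = etp_size               # 2 -- what get_etp_group().world_size reports
print(old_tp_size, new_tp_size)
```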

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 2 additions & 2 deletions
@@ -20,6 +20,7 @@
 import torch
 import vllm
 import vllm.distributed
+import vllm.envs as envs
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
@@ -164,10 +165,9 @@ def parallel_config_get_dp_port(self) -> int:
     """
     answer = self.data_parallel_master_port
     self.data_parallel_master_port += 1
-    import os
 
     # NOTE: Get port from envs directly when using torchrun
-    port = int(os.environ.get("MASTER_PORT", answer))  # type: ignore
+    port = envs.VLLM_DP_MASTER_PORT if envs.VLLM_DP_MASTER_PORT else answer
     return port
 
 
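
The port lookup now goes through vLLM's own env accessor instead of reading `MASTER_PORT` from `os.environ`, so the dedicated `VLLM_DP_MASTER_PORT` variable wins when it is set (e.g. under torchrun) and the rolling counter port is used otherwise. A sketch of the fallback, assuming the accessor evaluates falsy when the variable is unset; `pick_dp_master_port` is a hypothetical helper, not part of the patch:

```python
import vllm.envs as envs


def pick_dp_master_port(counter_port: int) -> int:
    # Prefer an explicitly configured data-parallel master port; otherwise
    # fall back to the incrementing port derived from the parallel config.
    return envs.VLLM_DP_MASTER_PORT if envs.VLLM_DP_MASTER_PORT else counter_port
```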

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 1 addition & 1 deletion
@@ -219,4 +219,4 @@ def prepare_input_kernel(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor,
 
     global_indices_flat = global_indices[mask]
     values_flat = values[mask]
-    out_ptr[global_indices_flat] = values_flat
+    out_ptr[global_indices_flat] = values_flat

vllm_ascend/worker/worker_v1.py

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         output = self.model_runner.execute_model(scheduler_output)
-        return output if self.rank == 0 else None
+        return output if self.is_driver_worker else None
 
     def load_model(self) -> None:
         self.model_runner.load_model()
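
Gating the return on `is_driver_worker` rather than `rank == 0` matters once there is more than one worker group: global rank 0 exists exactly once, while each group has its own driver that must hand results back to its executor. A hypothetical illustration, assuming the driver is rank 0 within each tensor-parallel group (the real flag is set by the executor):

```python
tp_size = 4  # hypothetical tensor-parallel group size
for global_rank in range(8):  # two groups of four ranks
    is_driver_worker = (global_rank % tp_size == 0)  # ranks 0 and 4 drive their groups
    returned_before = (global_rank == 0)             # only global rank 0 ever returned output
    print(f"rank {global_rank}: driver={is_driver_worker}, old_behaviour_returned={returned_before}")
```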
