
Commit 625a8be

[DP][V1] Fix rank set in DP scenario
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 21d4fb5 commit 625a8be

File tree

4 files changed: +72 -26 lines changed


tests/multicard/test_data_parallel.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Compare the outputs of vLLM with and without data parallelism.
+Run `pytest tests/multicard/test_data_parallel.py`.
+"""
+
+import os
+
+import pytest
+
+from tests.conftest import VllmRunner
+from tests.model_utils import check_outputs_equal
+from vllm_ascend.utils import vllm_version_is
+
+MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
+
+
+@pytest.mark.skipif(vllm_version_is("0.9.0"),
+                    reason="Data parallel is only supported on vLLM >= 0.9.1")
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="Data parallel is only supported on V1")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [32])
+def test_data_parallel_correctness(
+    model: str,
+    max_tokens: int,
+) -> None:
+    example_prompts = [
+        "Hello, my name is", "The president of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+
+    with VllmRunner(model_name=model,
+                    max_model_len=1024,
+                    max_num_seqs=16,
+                    data_parallel_size=2,
+                    distributed_executor_backend="mp") as vllm_model:
+        vllm_dp_outputs = vllm_model.generate_greedy(example_prompts,
+                                                     max_tokens)
+
+    with VllmRunner(
+            model_name=model,
+            max_model_len=1024,
+            max_num_seqs=16,
+    ) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_outputs,
+        outputs_1_lst=vllm_dp_outputs,
+        name_0="vllm_outputs",
+        name_1="vllm_dp_outputs",
+    )
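
The VllmRunner and check_outputs_equal helpers above come from the test suite. Outside of it, a roughly equivalent check can be sketched against the public vLLM API. This is a minimal sketch, assuming vllm.LLM forwards data_parallel_size and distributed_executor_backend to the engine configuration in this vLLM version and that at least two NPUs are visible:

```python
# Minimal sketch (not part of this commit): compare greedy outputs generated
# with and without data parallelism. Assumes the public vllm.LLM entry point
# accepts data_parallel_size / distributed_executor_backend here and that the
# host exposes at least two NPUs.
from vllm import LLM, SamplingParams

PROMPTS = [
    "Hello, my name is", "The president of the United States is",
    "The capital of France is", "The future of AI is"
]
GREEDY = SamplingParams(temperature=0.0, max_tokens=32)


def greedy_texts(**engine_kwargs) -> list[str]:
    llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct",
              max_model_len=1024,
              max_num_seqs=16,
              **engine_kwargs)
    return [out.outputs[0].text for out in llm.generate(PROMPTS, GREEDY)]


if __name__ == "__main__":
    dp_texts = greedy_texts(data_parallel_size=2,
                            distributed_executor_backend="mp")
    ref_texts = greedy_texts()
    assert dp_texts == ref_texts, "DP and non-DP greedy outputs diverged"
```

Greedy decoding (temperature 0) keeps the comparison deterministic, mirroring generate_greedy in the test.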

vllm_ascend/patch/platform/patch_0_9_0/patch_distributed.py

Lines changed: 2 additions & 2 deletions
@@ -1,10 +1,10 @@
 import torch
+import vllm
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
                                                 is_nccl_available)
 from torch.distributed.rendezvous import rendezvous
-from vllm.distributed import utils


 def stateless_init_torch_distributed_process_group(
@@ -113,4 +113,4 @@ def stateless_init_torch_distributed_process_group(
     return pg


-utils.stateless_init_torch_distributed_process_group = stateless_init_torch_distributed_process_group
+vllm.distributed.utils.stateless_init_torch_distributed_process_group = stateless_init_torch_distributed_process_group
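
The behavioural change in this file is only where the monkey patch lands: the attribute is now rebound through the fully qualified vllm.distributed.utils path instead of a locally imported utils alias. As background, a generic Python sketch (illustrative only, not vLLM code) of the module-attribute rebinding this kind of patch relies on:

```python
# Illustrative only: rebinding a module attribute affects call sites that look
# the name up on the module at call time, but not call sites that captured the
# original function object directly.
import types

mod = types.ModuleType("mod")
mod.greet = lambda: "original"


def via_attribute():
    return mod.greet()          # resolved on the module at every call


direct_ref = mod.greet          # function object captured up front

mod.greet = lambda: "patched"   # the monkey patch

print(via_attribute())          # -> patched
print(direct_ref())             # -> original
```

This is why such patches target the module attribute itself rather than a previously imported function reference.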

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 1 addition & 17 deletions
@@ -21,10 +21,9 @@
 import vllm.distributed
 import vllm.envs as envs
 from torch.distributed import ProcessGroup
-from vllm.config import ParallelConfig, VllmConfig
+from vllm.config import ParallelConfig
 from vllm.distributed.utils import \
     stateless_init_torch_distributed_process_group
-from vllm.v1.engine.core import DPEngineCoreProc


 def ascend_destroy_model_parallel():
@@ -79,21 +78,6 @@ def stateless_init_dp_group(self) -> "ProcessGroup":
     return dp_group


-def _init_data_parallel(self, vllm_config: VllmConfig):
-    # Configure NPUs and stateless process group for data parallel.
-    dp_rank = vllm_config.parallel_config.data_parallel_rank
-    dp_size = vllm_config.parallel_config.data_parallel_size
-    local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
-
-    assert dp_size > 1
-    assert 0 <= local_dp_rank <= dp_rank < dp_size
-
-    self.local_dp_rank = local_dp_rank
-    self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
-    self.current_wave = 0
-
-
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
-DPEngineCoreProc._init_data_parallel = _init_data_parallel
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
 ParallelConfig.stateless_init_dp_group = stateless_init_dp_group
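
With the override removed, the V1 DP engine falls back to vLLM's own DPEngineCoreProc._init_data_parallel. The invariant the deleted code asserted, 0 <= local_dp_rank <= dp_rank < dp_size, relates the node-local DP rank to the global one; a small illustrative sketch of that relationship (the two-node layout below is assumed purely for illustration):

```python
# Illustrative only: how global and node-local DP ranks typically relate when
# every node hosts the same number of DP engines. The 2-node x 2-engine layout
# is an assumption for the example, not taken from this commit.
NUM_NODES = 2
DP_PER_NODE = 2
DP_SIZE = NUM_NODES * DP_PER_NODE

for node in range(NUM_NODES):
    for local_dp_rank in range(DP_PER_NODE):
        dp_rank = node * DP_PER_NODE + local_dp_rank
        # Invariant checked by the removed override:
        assert 0 <= local_dp_rank <= dp_rank < DP_SIZE
        print(f"node={node} dp_rank={dp_rank} local_dp_rank={local_dp_rank}")
```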

vllm_ascend/worker/worker_v1.py

Lines changed: 1 addition & 7 deletions
@@ -75,12 +75,6 @@ def __init__(
                          distributed_init_method=distributed_init_method,
                          is_driver_worker=is_driver_worker)

-        # NOTE(Yizhou): Since we do not set ASCEND_RT_VISIBLE_DEVICES in
-        # vllm_ascend, we need to set the device id manually.
-        local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
-        world_size = self.vllm_config.parallel_config.world_size
-        self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank
-
         # Try to import mindie_turbo to accelerate vLLM inference.
         try_register_lib(
             "mindie_turbo",
@@ -119,7 +113,7 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:

     def init_device(self):
         if self.device_config.device.type == "npu":
-            self.device = torch.device(f"npu:{self.local_rank_across_dp}")
+            self.device = torch.device(f"npu:{self.local_rank}")
             NPUPlatform.set_device(self.device)
             NPUPlatform.empty_cache()
             self.init_npu_memory = NPUPlatform.mem_get_info()[0]
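
The functional change in this file is the NPU index each worker binds to: the removed code computed local_dp_rank * world_size + local_rank, while the fix uses local_rank directly. A small sketch contrasting the two mappings (the single-node layout with 2 DP engines of 2 workers each is hypothetical, chosen only to make the difference visible):

```python
# Hypothetical layout: 2 DP engines on one node, each with 2 workers
# (e.g. tensor parallel degree 2). Device numbers are illustrative only.
WORLD_SIZE = 2   # workers per DP engine
DP_SIZE = 2      # DP engines on the node

for local_dp_rank in range(DP_SIZE):
    for local_rank in range(WORLD_SIZE):
        old_device = local_dp_rank * WORLD_SIZE + local_rank  # removed formula
        new_device = local_rank                               # this commit
        print(f"dp={local_dp_rank} worker={local_rank}: "
              f"old npu:{old_device} -> new npu:{new_device}")
```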
