
Commit bb42211

[DP][V1] Fix rank set in DP scenario
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 1fd25e9 commit bb42211

File tree

3 files changed: +70 -24 lines changed

- tests/multicard/test_data_parallel.py
- vllm_ascend/patch/platform/patch_common/patch_distributed.py
- vllm_ascend/worker/worker_v1.py

tests/multicard/test_data_parallel.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Compare the outputs of vLLM with and without data parallel.
+Run `pytest tests/multicard/test_data_parallel.py`.
+"""
+
+import os
+
+import pytest
+
+from tests.conftest import VllmRunner
+from tests.model_utils import check_outputs_equal
+from vllm_ascend.utils import vllm_version_is
+
+MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
+
+
+@pytest.mark.skipif(vllm_version_is("0.9.0"),
+                    reason="Data parallel only support on >= 0.9.1")
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="Data parallel only support on v1")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [32])
+def test_data_parallel_correctness(
+    model: str,
+    max_tokens: int,
+) -> None:
+    example_prompts = [
+        "Hello, my name is", "The president of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+
+    with VllmRunner(model_name=model,
+                    max_model_len=1024,
+                    max_num_seqs=16,
+                    data_parallel_size=2,
+                    distributed_executor_backend="mp") as vllm_model:
+        vllm_dp_outputs = vllm_model.generate_greedy(example_prompts,
+                                                     max_tokens)
+
+    with VllmRunner(
+            model_name=model,
+            max_model_len=1024,
+            max_num_seqs=16,
+    ) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_outputs,
+        outputs_1_lst=vllm_dp_outputs,
+        name_0="vllm_outputs",
+        name_1="vllm_dp_outputs",
+    )
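
For context, the vLLM test runners' generate_greedy conventionally returns one (token_ids, text) pair per prompt, and check_outputs_equal (from tests.model_utils, not reproduced here) asserts that the two runs agree. A rough, hypothetical standalone equivalent of that final comparison, assuming that return shape, would be:

# Hypothetical, simplified stand-in for check_outputs_equal, assuming each
# runner returns a list of (token_ids, text) pairs, one per prompt.
from typing import List, Tuple


def assert_same_greedy_outputs(
        outputs_a: List[Tuple[List[int], str]],
        outputs_b: List[Tuple[List[int], str]]) -> None:
    assert len(outputs_a) == len(outputs_b)
    for i, ((ids_a, text_a), (ids_b, text_b)) in enumerate(
            zip(outputs_a, outputs_b)):
        # Greedy decoding is deterministic, so both the token ids and the
        # decoded text should match exactly between the DP and non-DP runs.
        assert ids_a == ids_b, f"token ids differ for prompt {i}"
        assert text_a == text_b, f"decoded text differs for prompt {i}"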

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 1 addition & 17 deletions
@@ -21,10 +21,9 @@
 import vllm.distributed
 import vllm.envs as envs
 from torch.distributed import ProcessGroup
-from vllm.config import ParallelConfig, VllmConfig
+from vllm.config import ParallelConfig
 from vllm.distributed.utils import \
     stateless_init_torch_distributed_process_group
-from vllm.v1.engine.core import DPEngineCoreProc


 def ascend_destroy_model_parallel():
@@ -79,21 +78,6 @@ def stateless_init_dp_group(self) -> "ProcessGroup":
     return dp_group


-def _init_data_parallel(self, vllm_config: VllmConfig):
-    # Configure NPUs and stateless process group for data parallel.
-    dp_rank = vllm_config.parallel_config.data_parallel_rank
-    dp_size = vllm_config.parallel_config.data_parallel_size
-    local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
-
-    assert dp_size > 1
-    assert 0 <= local_dp_rank <= dp_rank < dp_size
-
-    self.local_dp_rank = local_dp_rank
-    self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
-    self.current_wave = 0
-
-
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
-DPEngineCoreProc._init_data_parallel = _init_data_parallel
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
 ParallelConfig.stateless_init_dp_group = stateless_init_dp_group
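
As a hedged illustration (not part of the commit): after this change, importing the patch module should leave DPEngineCoreProc._init_data_parallel as vLLM's own implementation while still installing the two ParallelConfig overrides. A minimal smoke check along those lines, assuming vLLM >= 0.9.1 and that the patch module can be imported on its own, might look like:

# Hypothetical smoke check; the names mirror those visible in the diff above.
from vllm.config import ParallelConfig
from vllm.v1.engine.core import DPEngineCoreProc

from vllm_ascend.patch.platform.patch_common import patch_distributed as pd

# The Ascend overrides are still installed at import time ...
assert ParallelConfig.get_next_dp_init_port is pd.parallel_config_get_dp_port
assert ParallelConfig.stateless_init_dp_group is pd.stateless_init_dp_group
# ... but DP engine-core init now falls through to vLLM's upstream
# DPEngineCoreProc._init_data_parallel instead of an Ascend replacement.
assert "vllm_ascend" not in DPEngineCoreProc._init_data_parallel.__module__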

vllm_ascend/worker/worker_v1.py

Lines changed: 1 addition & 7 deletions
@@ -75,12 +75,6 @@ def __init__(
                          distributed_init_method=distributed_init_method,
                          is_driver_worker=is_driver_worker)

-        # NOTE(Yizhou): Since we do not set ASCEND_RT_VISIBLE_DEVICES in
-        # vllm_ascend, we need to set the device id manually.
-        local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
-        world_size = self.vllm_config.parallel_config.world_size
-        self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank
-
         # Try to import mindie_turbo to accelerate vLLM inference.
         try_register_lib(
             "mindie_turbo",
@@ -124,7 +118,7 @@ def initialize_cache(self, num_gpu_blocks: int,

     def init_device(self):
         if self.device_config.device.type == "npu":
-            self.device = torch.device(f"npu:{self.local_rank_across_dp}")
+            self.device = torch.device(f"npu:{self.local_rank}")
             NPUPlatform.set_device(self.device)
             NPUPlatform.empty_cache()
             self.init_npu_memory = NPUPlatform.mem_get_info()[0]
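
To make the one-line change concrete, here is a purely illustrative sketch of the device index each worker picks before and after the fix. The example values (two DP engines per node, world_size of 2) are assumptions for illustration, not taken from the commit:

# Illustrative only: NPU device index before vs. after this commit,
# using assumed example values.
world_size = 2        # workers (TP x PP) per DP engine -- assumed
local_dp_rank = 1     # this node's second DP engine -- assumed
local_rank = 0        # this worker's rank within its engine -- assumed

# Before: the worker offset its device id by the DP engine's position.
old_device = f"npu:{local_dp_rank * world_size + local_rank}"  # "npu:2"

# After: the worker simply uses its local rank within the engine.
new_device = f"npu:{local_rank}"                               # "npu:0"

print(old_device, "->", new_device)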
