Skip to content

Commit 52317f9

Browse files
authored
[DP] Tiny fix of dp and update example (#1273)
### What this PR does / why we need it? Add `max_num_tokens_across_dp` to AscendMetadata to fix dp This pr fixes the bug introduced by #1229, which add an arg `max_num_tokens_across_dp` when dp_size > 1. Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent c1c5d56 commit 52317f9

File tree

7 files changed

+327
-172
lines changed

7 files changed

+327
-172
lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,10 @@ jobs:
363363
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
364364
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
365365
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
366-
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py --ignore=tests/e2e/multicard/test_offline_inference_distributed.py
366+
pytest -sv tests/e2e/multicard/test_data_parallel.py
367+
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
368+
--ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
369+
--ignore=tests/e2e/multicard/test_data_parallel.py
367370
368371
- name: Run vllm-project/vllm-ascend test on V0 engine
369372
if: ${{ github.event_name == 'schedule' }}
@@ -380,4 +383,7 @@ jobs:
380383
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
381384
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
382385
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
383-
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py --ignore=tests/e2e/multicard/test_offline_inference_distributed.py
386+
pytest -sv tests/e2e/multicard/test_data_parallel.py
387+
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
388+
--ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
389+
--ignore=tests/e2e/multicard/test_data_parallel.py

examples/dp_offline/data_parallel.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

examples/dp_offline/run_dp.sh

Lines changed: 0 additions & 19 deletions
This file was deleted.

examples/offline_data_parallel.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# Copyright 2023 The vLLM team.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
# This file is a part of the vllm-ascend project.
17+
# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py
18+
#
19+
"""
20+
Usage:
21+
Single node:
22+
Dense models:
23+
python examples/offline_data_parallel.py \
24+
--model="Qwen/Qwen2.5-0.5B-Instruct" \
25+
--dp-size=2 \
26+
--tp-size=2
27+
MOE models:
28+
python examples/offline_data_parallel.py \
29+
--model="ibm-research/PowerMoE-3b" \
30+
--dp-size=2 \
31+
--tp-size=2 \
32+
--enable-expert-parallel
33+
34+
Multi-node:
35+
Node 0 (assume the node has ip of 10.99.48.128):
36+
python examples/offline_data_parallel.py \
37+
--model="ibm-research/PowerMoE-3b" \
38+
--dp-size=2 \
39+
--tp-size=2 \
40+
--node-size=2 \
41+
--node-rank=0 \
42+
--enable-expert-parallel \
43+
--master-addr=10.99.48.128 \
44+
--master-port=13345
45+
Node 1:
46+
python examples/offline_data_parallel.py \
47+
--model="ibm-research/PowerMoE-3b" \
48+
--dp-size=2 \
49+
--tp-size=2 \
50+
--node-size=2 \
51+
--node-rank=1 \
52+
--enable-expert-parallel \
53+
--master-addr=10.99.48.128 \
54+
--master-port=13345
55+
"""
56+
57+
import os
58+
from time import sleep
59+
60+
from vllm import LLM, SamplingParams
61+
from vllm.utils import get_open_port
62+
63+
64+
def parse_args():
65+
import argparse
66+
67+
parser = argparse.ArgumentParser(description="Data Parallel Inference")
68+
parser.add_argument(
69+
"--model",
70+
type=str,
71+
default="ibm-research/PowerMoE-3b",
72+
help="Model name or path",
73+
)
74+
parser.add_argument("--dp-size",
75+
type=int,
76+
default=2,
77+
help="Data parallel size")
78+
parser.add_argument("--tp-size",
79+
type=int,
80+
default=1,
81+
help="Tensor parallel size")
82+
parser.add_argument("--node-size",
83+
type=int,
84+
default=1,
85+
help="Total number of nodes")
86+
parser.add_argument("--node-rank",
87+
type=int,
88+
default=0,
89+
help="Rank of the current node")
90+
parser.add_argument("--master-addr",
91+
type=str,
92+
default="",
93+
help="Master node IP address")
94+
parser.add_argument("--master-port",
95+
type=int,
96+
default=0,
97+
help="Master node port")
98+
parser.add_argument("--enforce-eager",
99+
action="store_true",
100+
help="Enforce eager mode execution.")
101+
parser.add_argument("--trust-remote-code",
102+
action="store_true",
103+
help="Trust remote code.")
104+
parser.add_argument("--enable-expert-parallel",
105+
action="store_true",
106+
help="Enable expert parallel, used in MOE models.")
107+
return parser.parse_args()
108+
109+
110+
def main(
111+
model,
112+
dp_size,
113+
local_dp_rank,
114+
global_dp_rank,
115+
dp_master_ip,
116+
dp_master_port,
117+
GPUs_per_dp_rank,
118+
enable_expert_parallel,
119+
enforce_eager,
120+
trust_remote_code,
121+
):
122+
# DP only support on V1 engine
123+
os.environ["VLLM_USE_V1"] = "1"
124+
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
125+
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
126+
os.environ["VLLM_DP_SIZE"] = str(dp_size)
127+
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
128+
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
129+
130+
# CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
131+
# engine processes.
132+
133+
# Sample prompts.
134+
prompts = [
135+
"Hello, my name is",
136+
"The president of the United States is",
137+
"The capital of France is",
138+
"The future of AI is",
139+
] * 100
140+
141+
# with DP, each rank should process different prompts.
142+
# usually all the DP ranks process a full dataset,
143+
# and each rank processes a different part of the dataset.
144+
floor = len(prompts) // dp_size
145+
remainder = len(prompts) % dp_size
146+
147+
# Distribute prompts into even groups.
148+
def start(rank):
149+
return rank * floor + min(rank, remainder)
150+
151+
prompts = prompts[start(global_dp_rank):start(global_dp_rank + 1)]
152+
if len(prompts) == 0:
153+
# if any rank has no prompts to process,
154+
# we need to set a placeholder prompt
155+
prompts = ["Placeholder"]
156+
print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")
157+
158+
# Create a sampling params object.
159+
# since we are doing data parallel, every rank can have different
160+
# sampling params. here we set different max_tokens for different
161+
# ranks for demonstration.
162+
sampling_params = SamplingParams(temperature=0.8,
163+
top_p=0.95,
164+
max_tokens=[16, 20][global_dp_rank % 2])
165+
166+
# Create an LLM.
167+
llm = LLM(
168+
model=model,
169+
tensor_parallel_size=GPUs_per_dp_rank,
170+
enforce_eager=enforce_eager,
171+
enable_expert_parallel=enable_expert_parallel,
172+
trust_remote_code=trust_remote_code,
173+
)
174+
outputs = llm.generate(prompts, sampling_params)
175+
# Print the outputs.
176+
for i, output in enumerate(outputs):
177+
if i >= 5:
178+
# print only 5 outputs
179+
break
180+
prompt = output.prompt
181+
generated_text = output.outputs[0].text
182+
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
183+
f"Generated text: {generated_text!r}")
184+
185+
# Give engines time to pause their processing loops before exiting.
186+
sleep(1)
187+
188+
189+
if __name__ == "__main__":
190+
args = parse_args()
191+
192+
dp_size = args.dp_size
193+
tp_size = args.tp_size
194+
node_size = args.node_size
195+
node_rank = args.node_rank
196+
197+
if node_size == 1:
198+
dp_master_ip = "127.0.0.1"
199+
dp_master_port = get_open_port()
200+
else:
201+
dp_master_ip = args.master_addr
202+
dp_master_port = args.master_port
203+
204+
assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
205+
dp_per_node = dp_size // node_size
206+
207+
from multiprocessing import Process
208+
209+
procs = []
210+
for local_dp_rank, global_dp_rank in enumerate(
211+
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
212+
proc = Process(
213+
target=main,
214+
args=(
215+
args.model,
216+
dp_size,
217+
local_dp_rank,
218+
global_dp_rank,
219+
dp_master_ip,
220+
dp_master_port,
221+
tp_size,
222+
args.enable_expert_parallel,
223+
args.enforce_eager,
224+
args.trust_remote_code,
225+
),
226+
)
227+
proc.start()
228+
procs.append(proc)
229+
exit_code = 0
230+
for proc in procs:
231+
proc.join(timeout=300)
232+
if proc.exitcode is None:
233+
print(
234+
f"Killing process {proc.pid} that didn't stop within 5 minutes."
235+
)
236+
proc.kill()
237+
exit_code = 1
238+
elif proc.exitcode:
239+
exit_code = proc.exitcode
240+
241+
exit(exit_code)

0 commit comments

Comments
 (0)