
Commit 5305a2c

[Bugfix] Tweak distributed process group initialization and add dummy… (#816)

fix batch execution method to enable DP in V1

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>

1 parent 4df1e99

File tree

2 files changed (+20, -7 lines)


vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 17 additions & 7 deletions
@@ -47,7 +47,7 @@ def ascend_destroy_model_parallel():
     destory_ascend_model_parallel()
 
 
-def ascend_stateless_init_torch_distributed_process_group(
+def stateless_init_torch_distributed_process_group(
         host: str, port: int, rank: int, world_size: int,
         backend: str) -> ProcessGroup:
     """
@@ -96,10 +96,16 @@ def ascend_stateless_init_torch_distributed_process_group(
     # different systems (e.g. RPC) in case the store is multi-tenant.
     prefix_store = PrefixStore(init_method, store)
 
+    # TODO(Yizhou): The reason we need to set options while vllm does not
+    # seems to be related to the version of PyTorch. In the latest version,
+    # there is no need to set options. While in the older version, 2.5.1
+    # specifically, we need to set options.
+    options = ProcessGroup.Options(backend=backend)
     pg: ProcessGroup = ProcessGroup(
         prefix_store,
         group_rank,
         group_size,
+        options,
     )
     if backend == "gloo":
         from torch.distributed.distributed_c10d import ProcessGroupGloo
@@ -136,7 +142,10 @@ def ascend_stateless_init_torch_distributed_process_group(
     else:
         raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
 
-    pg._set_default_backend(backend_type)
+    # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
+    # implemented in the 2.5.1 version of PyTorch. But we need to set it
+    # after the latest version is released.
+    # pg._set_default_backend(backend_type)
     backend_class._set_sequence_number_for_group()
 
     pg._register_backend(device, backend_type, backend_class)
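
Note on the two TODOs above: both trace back to the same PyTorch version gap. On 2.5.1 the bare ProcessGroup constructor still expects an Options object and does not expose _set_default_backend, while newer releases drop the Options argument and provide the setter. A minimal sketch of a version-agnostic variant is below; the hasattr probe and the helper name _construct_process_group are assumptions for illustration, not part of this commit.

# Sketch only, not part of the commit: pick the construction path based on the
# PyTorch build instead of hard-coding the 2.5.1 behaviour.
from torch.distributed import ProcessGroup

# Assumption: newer PyTorch exposes _set_default_backend and no longer needs Options.
_HAS_SET_DEFAULT_BACKEND = hasattr(ProcessGroup, "_set_default_backend")


def _construct_process_group(prefix_store, group_rank, group_size, backend):
    # Hypothetical helper; the patched function above inlines this logic.
    if _HAS_SET_DEFAULT_BACKEND:
        # Newer PyTorch: no Options argument at construction time.
        return ProcessGroup(prefix_store, group_rank, group_size)
    # PyTorch 2.5.1: pass Options explicitly, as the patched code does.
    options = ProcessGroup.Options(backend=backend)
    return ProcessGroup(prefix_store, group_rank, group_size, options)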
@@ -163,20 +172,21 @@ def parallel_config_get_dp_port(self) -> int:
 
 
 def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
-    from vllm.distributed.utils import \
-        stateless_init_torch_distributed_process_group
-
+    # TODO(Yizhou): Currently we have to set the backend to gloo
+    # because in vllm.config.ParallelConfig.has_unfinished_dp the
+    # device is set to cpu. We need to fix this in the future.
+    # We need to compare the performance of gloo and hccl and then
+    # decide which one to use.
     dp_group = stateless_init_torch_distributed_process_group(
         self.data_parallel_master_ip,
         self.get_next_dp_init_port(),
         self.data_parallel_rank,
         self.data_parallel_size,
-        backend="hccl")
+        backend="gloo")
 
     return dp_group
 
 
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
-vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
 ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
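
For context on the backend switch from "hccl" to "gloo": the DP group created by ascend_stateless_init_dp_group is only used for control-plane collectives such as vllm.config.ParallelConfig.has_unfinished_dp, which the TODO notes runs on CPU tensors, and gloo handles CPU tensors natively. A rough, illustrative sketch of that kind of check follows; the exact signature and dtype in vLLM may differ.

# Illustrative only: the kind of CPU-side collective the gloo DP group serves.
import torch
import torch.distributed as dist


def has_unfinished_dp(dp_group: dist.ProcessGroup, has_unfinished: bool) -> bool:
    # Each DP rank contributes a flag on CPU; gloo supports CPU tensors,
    # which is why the group above is created with backend="gloo".
    flag = torch.tensor([int(has_unfinished)], dtype=torch.int32, device="cpu")
    # MAX acts as a logical OR across data-parallel ranks.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=dp_group)
    return bool(flag.item())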

vllm_ascend/worker/worker_v1.py

Lines changed: 3 additions & 0 deletions
@@ -216,6 +216,9 @@ def profile(self, is_start: bool = True):
         else:
             self.profiler.stop()
 
+    def execute_dummy_batch(self) -> None:
+        self.model_runner._dummy_run(1)
+
     def _init_worker_distributed_environment(self) -> None:
         """Initialize the distributed environment."""
         additional_config = self.vllm_config.additional_config
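
The new execute_dummy_batch hook is what the commit title refers to: with data parallelism, every DP rank has to join the same collectives each step, so a rank that happens to have no scheduled requests still runs a one-token dummy forward (_dummy_run(1)) to stay in lockstep with the busy ranks. A hedged sketch of how a coordinating engine loop might drive the hook is below; the names workers, per_rank_requests, and the execute_model call here are illustrative, not vLLM's actual API.

# Sketch only: an illustrative DP-aware step loop, not vLLM's engine code.
def dp_engine_step(workers, per_rank_requests):
    for worker, requests in zip(workers, per_rank_requests):
        if requests:
            worker.execute_model(requests)    # real work on ranks with requests
        else:
            worker.execute_dummy_batch()      # idle ranks still join the collectives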
