
Commit 6003afa
[BugFix] Fix data parallel (#940)
### What this PR does / why we need it?

With this PR, we can migrate to the native `data_parallel.py` in the vllm examples and remove the version in vllm-ascend. At present, `ASCEND_RT_VISIBLE_DEVICES` introduces considerable difficulties; therefore, we must employ a temporary workaround and manually specify the device.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent eec6068 commit 6003afa

File tree

5 files changed: +191 −115 lines changed


vllm_ascend/patch/platform/patch_0_9_0/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -14,3 +14,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import vllm_ascend.patch.platform.patch_0_9_0.patch_distributed  # noqa
```
vllm_ascend/patch/platform/patch_0_9_0/patch_distributed.py

Lines changed: 116 additions & 0 deletions
```diff
@@ -0,0 +1,116 @@
+import torch
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import (Backend, PrefixStore,
+                                                _get_default_timeout,
+                                                is_nccl_available)
+from torch.distributed.rendezvous import rendezvous
+from vllm.distributed import utils
+
+
+def stateless_init_torch_distributed_process_group(
+        host: str, port: int, rank: int, world_size: int,
+        backend: str) -> ProcessGroup:
+    """
+    A replacement for `torch.distributed.init_process_group` that does not
+    pollute the global state. The created ProcessGroup object can be used for
+    some operations such as `allreduce`, because it does not depend on the
+    global rank. However, some operations such as `broadcast` cannot be used
+    because it depends on the global rank.
+
+    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.
+
+    This function is useful when we are not sure about the total number of
+    processes in the process group. For example, we may have process
+    1, 2, ..., 8 who want to communicate, and process 9 might be the same
+    process as process 1, or it might be a different process; process 10
+    might be the same process as process 5, or it might be a different process.
+    In this case, how can we reliably form a communication channel within
+    process 9 and 10, without affecting the communication channel within
+    process 1, 2, ..., 8?
+
+    One possible solution is to figure out if process 9 and 10 are the same
+    as process 1 and 5 beforehand, and then form a communication channel
+    based on the information, adjusting the ranks and world_size etc. However,
+    figuring out the information is not always easy, and it will interfere
+    with the main communication channel.
+
+    Our solution is to always form a communication channel with process 1, 2,
+    ..., 8, and then use this function to form another communication channel
+    with process 9 and 10. This way, regardless of whether process 9 and 10
+    are the same as process 1 and 5, the main communication channel is
+    always formed with process 1, 2, ..., 8, and the additional communication
+    channel is formed with process 9 and 10.
+    """
+    init_method = f"tcp://{host}:{port}"
+    backend = Backend(backend)  # it is basically string
+    timeout = _get_default_timeout(backend)
+
+    store, rank, world_size = next(
+        rendezvous(init_method, rank, world_size, timeout=timeout))
+    store.set_timeout(timeout)
+
+    group_rank = rank
+    group_size = world_size
+
+    # Use a PrefixStore to avoid accidental overrides of keys used by
+    # different systems (e.g. RPC) in case the store is multi-tenant.
+    prefix_store = PrefixStore(init_method, store)
+
+    # TODO(Yizhou): The reason we need to set options while vllm does not
+    # seems to be related to the version of PyTorch. In the latest version,
+    # there is no need to set options. While in the older version, 2.5.1
+    # specifically, we need to set options.
+    options = ProcessGroup.Options(backend=backend)
+    pg: ProcessGroup = ProcessGroup(
+        prefix_store,
+        group_rank,
+        group_size,
+        options,
+    )
+    if backend == "gloo":
+        from torch.distributed.distributed_c10d import ProcessGroupGloo
+        backend_class = ProcessGroupGloo(prefix_store,
+                                         group_rank,
+                                         group_size,
+                                         timeout=timeout)
+        backend_type = ProcessGroup.BackendType.GLOO
+        device = torch.device("cpu")
+    elif backend == "nccl":
+        assert is_nccl_available()
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+    elif backend == "hccl":
+        from torch.distributed import is_hccl_available
+        assert is_hccl_available()
+        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
+        backend_options = ProcessGroupHCCL.Options()
+        backend_options._timeout = timeout
+        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        device = torch.device("npu")
+        backend_class._set_sequence_number_for_group()
+        backend_type = ProcessGroup.BackendType.CUSTOM
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
+    else:
+        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+
+    # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
+    # implemented in the 2.5.1 version of PyTorch. But we need to set it
+    # after the latest version is released.
+    # pg._set_default_backend(backend_type)
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+
+    return pg
+
+
+utils.stateless_init_torch_distributed_process_group = stateless_init_torch_distributed_process_group
```
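As the docstring notes, a group created this way supports collectives such as `allreduce` that do not depend on a global rank. Below is a minimal usage sketch that is not part of this commit: the host, port, and tensor values are placeholders, and it uses the CPU-only `gloo` branch so it can run without an NPU.

```python
# Hypothetical sketch: build a side-channel gloo group between two processes
# without touching the default torch.distributed state. Run one copy with
# RANK=0 and one with RANK=1; the host and port are placeholders.
import os

import torch
import torch.distributed as dist
from vllm.distributed.utils import \
    stateless_init_torch_distributed_process_group

rank = int(os.environ["RANK"])  # 0 or 1
pg = stateless_init_torch_distributed_process_group(
    host="127.0.0.1",
    port=29999,
    rank=rank,
    world_size=2,
    backend="gloo",  # CPU-only backend, no NPU required
)

# `allreduce` works because it does not rely on a global rank.
t = torch.ones(1)
dist.all_reduce(t, group=pg)  # with two ranks, t becomes tensor([2.])
print(rank, t)
```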

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 21 additions & 114 deletions
```diff
@@ -17,16 +17,14 @@
 # Adapted from vllm/model_executor/models/qwen2_vl.py
 # This file is a part of the vllm-ascend project.
 
-import torch
 import vllm
 import vllm.distributed
 import vllm.envs as envs
 from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import (Backend, PrefixStore,
-                                                _get_default_timeout,
-                                                is_nccl_available)
-from torch.distributed.rendezvous import rendezvous
-from vllm.config import ParallelConfig
+from vllm.config import ParallelConfig, VllmConfig
+from vllm.distributed.utils import \
+    stateless_init_torch_distributed_process_group
+from vllm.v1.engine.core import DPEngineCoreProc
 
 
 def ascend_destroy_model_parallel():
@@ -48,112 +46,6 @@ def ascend_destroy_model_parallel():
     destory_ascend_model_parallel()
 
 
-def stateless_init_torch_distributed_process_group(
-        host: str, port: int, rank: int, world_size: int,
-        backend: str) -> ProcessGroup:
-    """
-    A replacement for `torch.distributed.init_process_group` that does not
-    pollute the global state. The created ProcessGroup object can be used for
-    some operations such as `allreduce`, because it does not depend on the
-    global rank. However, some operations such as `broadcast` cannot be used
-    because it depends on the global rank.
-
-    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.
-
-    This function is useful when we are not sure about the total number of
-    processes in the process group. For example, we may have process
-    1, 2, ..., 8 who want to communicate, and process 9 might be the same
-    process as process 1, or it might be a different process; process 10
-    might be the same process as process 5, or it might be a different process.
-    In this case, how can we reliably form a communication channel within
-    process 9 and 10, without affecting the communication channel within
-    process 1, 2, ..., 8?
-
-    One possible solution is to figure out if process 9 and 10 are the same
-    as process 1 and 5 beforehand, and then form a communication channel
-    based on the information, adjusting the ranks and world_size etc. However,
-    figuring out the information is not always easy, and it will interfere
-    with the main communication channel.
-
-    Our solution is to always form a communication channel with process 1, 2,
-    ..., 8, and then use this function to form another communication channel
-    with process 9 and 10. This way, regardless of whether process 9 and 10
-    are the same as process 1 and 5, the main communication channel is
-    always formed with process 1, 2, ..., 8, and the additional communication
-    channel is formed with process 9 and 10.
-    """
-    init_method = f"tcp://{host}:{port}"
-    backend = Backend(backend)  # it is basically string
-    timeout = _get_default_timeout(backend)
-
-    store, rank, world_size = next(
-        rendezvous(init_method, rank, world_size, timeout=timeout))
-    store.set_timeout(timeout)
-
-    group_rank = rank
-    group_size = world_size
-
-    # Use a PrefixStore to avoid accidental overrides of keys used by
-    # different systems (e.g. RPC) in case the store is multi-tenant.
-    prefix_store = PrefixStore(init_method, store)
-
-    # TODO(Yizhou): The reason we need to set options while vllm does not
-    # seems to be related to the version of PyTorch. In the latest version,
-    # there is no need to set options. While in the older version, 2.5.1
-    # specifically, we need to set options.
-    options = ProcessGroup.Options(backend=backend)
-    pg: ProcessGroup = ProcessGroup(
-        prefix_store,
-        group_rank,
-        group_size,
-        options,
-    )
-    if backend == "gloo":
-        from torch.distributed.distributed_c10d import ProcessGroupGloo
-        backend_class = ProcessGroupGloo(prefix_store,
-                                         group_rank,
-                                         group_size,
-                                         timeout=timeout)
-        backend_type = ProcessGroup.BackendType.GLOO
-        device = torch.device("cpu")
-    elif backend == "nccl":
-        assert is_nccl_available()
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-    elif backend == "hccl":
-        from torch.distributed import is_hccl_available
-        assert is_hccl_available()
-        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
-        backend_options = ProcessGroupHCCL.Options()
-        backend_options._timeout = timeout
-        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        device = torch.device("npu")
-        backend_class._set_sequence_number_for_group()
-        backend_type = ProcessGroup.BackendType.CUSTOM
-        pg._register_backend(device, backend_type, backend_class)
-        return pg
-    else:
-        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
-
-    # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
-    # implemented in the 2.5.1 version of PyTorch. But we need to set it
-    # after the latest version is released.
-    # pg._set_default_backend(backend_type)
-    backend_class._set_sequence_number_for_group()
-
-    pg._register_backend(device, backend_type, backend_class)
-
-    return pg
-
-
 def parallel_config_get_dp_port(self) -> int:
     """
     We might need to initialize process groups in multiple
@@ -171,7 +63,7 @@ def parallel_config_get_dp_port(self) -> int:
     return port
 
 
-def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
+def stateless_init_dp_group(self) -> "ProcessGroup":
     # TODO(Yizhou): Currently we have to set the backend to gloo
     # because in vllm.config.ParallelConfig.has_unfinished_dp the
    # device is set to cpu. We need to fix this in the future.
@@ -187,6 +79,21 @@ def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
     return dp_group
 
 
+def _init_data_parallel(self, vllm_config: VllmConfig):
+    # Configure NPUs and stateless process group for data parallel.
+    dp_rank = vllm_config.parallel_config.data_parallel_rank
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
+
+    assert dp_size > 1
+    assert 0 <= local_dp_rank <= dp_rank < dp_size
+
+    self.local_dp_rank = local_dp_rank
+    self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
+    self.current_wave = 0
+
+
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
+DPEngineCoreProc._init_data_parallel = _init_data_parallel
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
-ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
+ParallelConfig.stateless_init_dp_group = stateless_init_dp_group
```
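To make the assertion in the `_init_data_parallel` override concrete, here is a small self-contained sketch; the two-node, two-ranks-per-node layout is hypothetical and only illustrates how a node-local DP rank relates to the global one.

```python
# Hypothetical layout: 2 nodes, 2 data-parallel ranks per node -> dp_size = 4.
# The global dp_rank counts across nodes, while data_parallel_rank_local
# restarts on each node, so 0 <= local_dp_rank <= dp_rank < dp_size holds.
nodes = 2
dp_per_node = 2
dp_size = nodes * dp_per_node

for node in range(nodes):
    for local_dp_rank in range(dp_per_node):
        dp_rank = node * dp_per_node + local_dp_rank
        assert 0 <= local_dp_rank <= dp_rank < dp_size
        print(f"node={node} dp_rank={dp_rank} local_dp_rank={local_dp_rank}")
```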

vllm_ascend/platform.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -18,10 +18,13 @@
 import gc
 import logging
 import os
+from datetime import timedelta
 from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
 import vllm.envs as envs
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import PrefixStore
 from vllm.logger import logger
 from vllm.platforms import Platform, PlatformEnum
 
@@ -262,3 +265,45 @@ def get_piecewise_backend_cls(cls) -> str:
         Get piecewise backend class for piecewise graph.
         """
         return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend"  # noqa
+
+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        from torch.distributed import is_hccl_available
+        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
+
+        assert is_hccl_available()
+
+        # TODO(Yizhou): The reason we need to set options while vllm does not
+        # seems to be related to the version of PyTorch. In the latest version,
+        # there is no need to set options. While in the older version, 2.5.1
+        # specifically, we need to set options.
+        options = ProcessGroup.Options(backend=backend)
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+            options,
+        )
+
+        backend_options = ProcessGroupHCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        device = torch.device("npu")
+        # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
+        # implemented in the 2.5.1 version of PyTorch. But we need to set it
+        # after the latest version is released.
+        # pg._set_default_backend(backend_type)
+        backend_class._set_sequence_number_for_group()
+        backend_type = ProcessGroup.BackendType.CUSTOM
+
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
```
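The new `stateless_init_device_torch_dist_pg` hook builds an HCCL-backed group from a pre-existing store. A rough sketch of driving it directly is shown below; it assumes an Ascend host with `torch_npu` installed, and the address, port, and store prefix are placeholders rather than values used by vLLM. Each participating process would run the same code with its own rank.

```python
# Hypothetical sketch: exercise the platform hook by hand with a TCPStore.
from datetime import timedelta

import torch.distributed as dist
from vllm_ascend.platform import NPUPlatform

rank, world_size = 0, 2  # run once per process with its own rank
store = dist.TCPStore("127.0.0.1", 29998, world_size, is_master=(rank == 0))
prefix_store = dist.PrefixStore("dp_side_channel", store)

pg = NPUPlatform.stateless_init_device_torch_dist_pg(
    backend="hccl",
    prefix_store=prefix_store,
    group_rank=rank,
    group_size=world_size,
    timeout=timedelta(minutes=5),
)
```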

vllm_ascend/worker/worker_v1.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -74,6 +74,13 @@ def __init__(
                          rank=rank,
                          distributed_init_method=distributed_init_method,
                          is_driver_worker=is_driver_worker)
+
+        # NOTE(Yizhou): Since we do not set ASCEND_RT_VISIBLE_DEVICES in
+        # vllm_ascend, we need to set the device id manually.
+        local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
+        world_size = self.vllm_config.parallel_config.world_size
+        self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank
+
         # Try to import mindie_turbo to accelerate vLLM inference.
         try_register_lib(
             "mindie_turbo",
@@ -112,7 +119,7 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
 
     def init_device(self):
         if self.device_config.device.type == "npu":
-            self.device = torch.device(f"npu:{self.local_rank}")
+            self.device = torch.device(f"npu:{self.local_rank_across_dp}")
             NPUPlatform.set_device(self.device)
             NPUPlatform.empty_cache()
             self.init_npu_memory = NPUPlatform.mem_get_info()[0]
```
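The `local_rank_across_dp` computation above assigns each worker a distinct NPU index without relying on `ASCEND_RT_VISIBLE_DEVICES`. A quick sketch of the arithmetic for a hypothetical single-node run with two data-parallel ranks and a per-rank world size of two:

```python
# Hypothetical single node: 2 local DP ranks, world_size (TP x PP) = 2.
# Each DP rank owns `world_size` consecutive NPUs, so the four workers
# land on npu:0 .. npu:3 with no overlap.
world_size = 2
for local_dp_rank in range(2):
    for local_rank in range(world_size):
        local_rank_across_dp = local_dp_rank * world_size + local_rank
        print(f"dp={local_dp_rank} worker={local_rank} -> npu:{local_rank_across_dp}")
```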
