# Adapted from vllm/model_executor/models/qwen2_vl.py
# This file is a part of the vllm-ascend project.

+import torch
import vllm
import vllm.distributed
import vllm.envs as envs
from torch.distributed import ProcessGroup
-from vllm.config import ParallelConfig, VllmConfig
-from vllm.distributed.utils import \
-    stateless_init_torch_distributed_process_group
-from vllm.v1.engine.core import DPEngineCoreProc
+from torch.distributed.distributed_c10d import (Backend, PrefixStore,
+                                                _get_default_timeout,
+                                                is_nccl_available)
+from torch.distributed.rendezvous import rendezvous
+from vllm.config import ParallelConfig


def ascend_destroy_model_parallel():
@@ -46,6 +48,112 @@ def ascend_destroy_model_parallel():
    destory_ascend_model_parallel()


+def stateless_init_torch_distributed_process_group(
+        host: str, port: int, rank: int, world_size: int,
+        backend: str) -> ProcessGroup:
+    """
+    A replacement for `torch.distributed.init_process_group` that does not
+    pollute the global state. The created ProcessGroup object can be used for
+    some operations such as `allreduce`, because it does not depend on the
+    global rank. However, some operations such as `broadcast` cannot be used
+    because they depend on the global rank.
+
+    # TODO: ask the PyTorch team for help if we need the `broadcast` operation.
+
+    This function is useful when we are not sure about the total number of
+    processes in the process group. For example, we may have processes
+    1, 2, ..., 8 that want to communicate, and process 9 might be the same
+    process as process 1, or it might be a different process; process 10
+    might be the same process as process 5, or it might be a different process.
+    In this case, how can we reliably form a communication channel between
+    processes 9 and 10 without affecting the communication channel among
+    processes 1, 2, ..., 8?
+
+    One possible solution is to figure out beforehand whether processes 9 and
+    10 are the same as processes 1 and 5, and then form a communication
+    channel based on that information, adjusting the ranks, world_size, etc.
+    However, figuring this out is not always easy, and it would interfere
+    with the main communication channel.
+
+    Our solution is to always form a communication channel with processes
+    1, 2, ..., 8, and then use this function to form another communication
+    channel with processes 9 and 10. This way, regardless of whether processes
+    9 and 10 are the same as processes 1 and 5, the main communication channel
+    is always formed with processes 1, 2, ..., 8, and the additional
+    communication channel is formed with processes 9 and 10.
+    """
+    init_method = f"tcp://{host}:{port}"
+    backend = Backend(backend)  # it is basically a string
+    timeout = _get_default_timeout(backend)
+
+    store, rank, world_size = next(
+        rendezvous(init_method, rank, world_size, timeout=timeout))
+    store.set_timeout(timeout)
+
+    group_rank = rank
+    group_size = world_size
+
+    # Use a PrefixStore to avoid accidental overrides of keys used by
+    # different systems (e.g. RPC) in case the store is multi-tenant.
+    prefix_store = PrefixStore(init_method, store)
+
+    # TODO(Yizhou): The reason we need to set options while vllm does not
+    # seems to be related to the PyTorch version. In the latest version,
+    # there is no need to set options, while in the older version (2.5.1
+    # specifically) we need to set options.
+    options = ProcessGroup.Options(backend=backend)
+    pg: ProcessGroup = ProcessGroup(
+        prefix_store,
+        group_rank,
+        group_size,
+        options,
+    )
+    if backend == "gloo":
+        from torch.distributed.distributed_c10d import ProcessGroupGloo
+        backend_class = ProcessGroupGloo(prefix_store,
+                                         group_rank,
+                                         group_size,
+                                         timeout=timeout)
+        backend_type = ProcessGroup.BackendType.GLOO
+        device = torch.device("cpu")
+    elif backend == "nccl":
+        assert is_nccl_available()
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+    elif backend == "hccl":
+        from torch.distributed import is_hccl_available
+        assert is_hccl_available()
+        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
+        backend_options = ProcessGroupHCCL.Options()
+        backend_options._timeout = timeout
+        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        device = torch.device("npu")
+        backend_class._set_sequence_number_for_group()
+        backend_type = ProcessGroup.BackendType.CUSTOM
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
+    else:
+        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+
+    # TODO(Yizhou): As mentioned above, _set_default_backend is not
+    # implemented in the 2.5.1 version of PyTorch, so the call below stays
+    # commented out until we move to a newer release.
+    # pg._set_default_backend(backend_type)
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+
+    return pg


def parallel_config_get_dp_port(self) -> int:
    """
    We might need to initialize process groups in multiple
@@ -63,7 +171,7 @@ def parallel_config_get_dp_port(self) -> int:
    return port


-def stateless_init_dp_group(self) -> "ProcessGroup":
+def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
    # TODO(Yizhou): Currently we have to set the backend to gloo
    # because in vllm.config.ParallelConfig.has_unfinished_dp the
    # device is set to cpu. We need to fix this in the future.
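For context on the gloo requirement noted in the TODO above: upstream vllm's `ParallelConfig.has_unfinished_dp` reduces a CPU tensor over this DP group, roughly as in the sketch below (a paraphrase for illustration, not code from this patch), which is why the group's backend must support CPU collectives.

```python
# Paraphrased sketch of upstream vllm's ParallelConfig.has_unfinished_dp,
# included only to motivate the gloo/CPU constraint above; not part of the patch.
import torch
import torch.distributed
from torch.distributed import ProcessGroup, ReduceOp


def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
    # The tensor lives on CPU, so the backend behind dp_group must support
    # CPU collectives (gloo does; nccl/hccl operate on device memory).
    tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
    # MAX-reduce across all DP ranks: True if any rank still has work.
    torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
    return bool(tensor.item())
```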
@@ -79,21 +187,6 @@ def stateless_init_dp_group(self) -> "ProcessGroup":
    return dp_group


-def _init_data_parallel(self, vllm_config: VllmConfig):
-    # Configure NPUs and stateless process group for data parallel.
-    dp_rank = vllm_config.parallel_config.data_parallel_rank
-    dp_size = vllm_config.parallel_config.data_parallel_size
-    local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
-
-    assert dp_size > 1
-    assert 0 <= local_dp_rank <= dp_rank < dp_size
-
-    self.local_dp_rank = local_dp_rank
-    self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
-    self.current_wave = 0
-
-
vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
-DPEngineCoreProc._init_data_parallel = _init_data_parallel
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
-ParallelConfig.stateless_init_dp_group = stateless_init_dp_group
+ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
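To see the patched pieces end to end, here is a minimal usage sketch of the new stateless helper. It is illustrative only: the import path, host, port, and rank values are assumptions rather than part of this change, and the `group=` pattern simply mirrors how vllm passes the returned ProcessGroup to collectives.

```python
# Minimal sketch (not part of the patch): form a two-rank stateless gloo group
# with the new helper and run an allreduce through it. The module path, host,
# port, and world size below are illustrative assumptions.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# Assumed import location of this patch module; adjust to the real path.
from vllm_ascend.patch.platform.patch_distributed import (
    stateless_init_torch_distributed_process_group)


def _worker(rank: int) -> None:
    pg = stateless_init_torch_distributed_process_group(host="127.0.0.1",
                                                        port=29600,
                                                        rank=rank,
                                                        world_size=2,
                                                        backend="gloo")
    t = torch.ones(1, dtype=torch.int32)
    # The returned group is independent of any default process group, so it
    # is passed explicitly via `group=`.
    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=pg)
    assert t.item() == 2


if __name__ == "__main__":
    # Two local processes, each joining the same stateless group.
    mp.spawn(_worker, nprocs=2)
```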