 # Adapted from vllm/model_executor/models/qwen2_vl.py
 # This file is a part of the vllm-ascend project.
 
 import torch
 import vllm
 import vllm.distributed
 import vllm.envs as envs
 from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import (Backend, PrefixStore,
+                                                _get_default_timeout,
+                                                is_nccl_available)
+from torch.distributed.rendezvous import rendezvous
 from vllm.config import ParallelConfig
-from vllm.distributed.utils import \
-    stateless_init_torch_distributed_process_group
-
-from vllm_ascend.utils import NullHandle, is_310p
 
 
 def ascend_destroy_model_parallel():
@@ -48,6 +49,112 @@ def ascend_destroy_model_parallel():
     destory_ascend_model_parallel()
 
 
+def stateless_init_torch_distributed_process_group(
+        host: str, port: int, rank: int, world_size: int,
+        backend: str) -> ProcessGroup:
+    """
+    A replacement for `torch.distributed.init_process_group` that does not
+    pollute the global state. The created ProcessGroup object can be used for
+    some operations such as `allreduce`, because it does not depend on the
+    global rank. However, some operations such as `broadcast` cannot be used
+    because they depend on the global rank.
+
+    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.
+
+    This function is useful when we are not sure about the total number of
+    processes in the process group. For example, we may have processes
+    1, 2, ..., 8 that want to communicate, and process 9 might be the same
+    process as process 1, or it might be a different process; process 10
+    might be the same process as process 5, or it might be a different process.
+    In this case, how can we reliably form a communication channel between
+    processes 9 and 10 without affecting the communication channel among
+    processes 1, 2, ..., 8?
+
+    One possible solution is to figure out beforehand whether processes 9 and
+    10 are the same as processes 1 and 5, and then form a communication
+    channel based on that information, adjusting the ranks and world_size
+    accordingly. However, that information is not always easy to obtain, and
+    obtaining it interferes with the main communication channel.
+
+    Our solution is to always form a communication channel among processes
+    1, 2, ..., 8, and then use this function to form another channel between
+    processes 9 and 10. This way, regardless of whether processes 9 and 10
+    are the same as processes 1 and 5, the main communication channel is
+    always formed among processes 1, 2, ..., 8, and the additional channel is
+    formed between processes 9 and 10.
+    """
+    init_method = f"tcp://{host}:{port}"
+    backend = Backend(backend)  # it is basically a string
+    timeout = _get_default_timeout(backend)
+
+    store, rank, world_size = next(
+        rendezvous(init_method, rank, world_size, timeout=timeout))
+    store.set_timeout(timeout)
+
+    group_rank = rank
+    group_size = world_size
+
+    # Use a PrefixStore to avoid accidental overrides of keys used by
+    # different systems (e.g. RPC) in case the store is multi-tenant.
+    prefix_store = PrefixStore(init_method, store)
+
+    # TODO(Yizhou): The reason we need to set options while vllm does not
+    # seems to be the PyTorch version: the latest release no longer needs
+    # options, but older releases (2.5.1 specifically) do.
+    options = ProcessGroup.Options(backend=backend)
+    pg: ProcessGroup = ProcessGroup(
+        prefix_store,
+        group_rank,
+        group_size,
+        options,
+    )
+    if backend == "gloo":
+        from torch.distributed.distributed_c10d import ProcessGroupGloo
+        backend_class = ProcessGroupGloo(prefix_store,
+                                         group_rank,
+                                         group_size,
+                                         timeout=timeout)
+        backend_type = ProcessGroup.BackendType.GLOO
+        device = torch.device("cpu")
+    elif backend == "nccl":
+        assert is_nccl_available()
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+    elif backend == "hccl":
+        from torch.distributed import is_hccl_available
+        assert is_hccl_available()
+        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
+        backend_options = ProcessGroupHCCL.Options()
+        backend_options._timeout = timeout
+        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        device = torch.device("npu")
+        backend_class._set_sequence_number_for_group()
+        backend_type = ProcessGroup.BackendType.CUSTOM
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
+    else:
+        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+
+    # TODO(Yizhou): As mentioned above, _set_default_backend is not
+    # implemented in PyTorch 2.5.1; we need to set it once we move to a
+    # newer release.
+    # pg._set_default_backend(backend_type)
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+
+    return pg
+
+
 def parallel_config_get_dp_port(self) -> int:
     """
     We might need to initialize process groups in multiple
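
To make the intent of the new helper concrete, here is a minimal usage sketch (not part of the patch). Host, port and the `gloo` backend are placeholder choices; on Ascend hardware the `hccl` branch above (which requires `torch_npu`) would be taken instead.

```python
# Illustrative sketch only, not part of the patch. Each of the two "extra"
# processes runs this with its own rank; the resulting ProcessGroup is
# independent of torch.distributed's default group, so it can coexist with an
# already-initialized main communication channel.
import torch
import torch.distributed as dist

# The upstream helper removed from the imports above has the same signature;
# using it keeps this sketch runnable without the patched module. The local
# copy added by this patch additionally handles the "hccl" backend.
from vllm.distributed.utils import \
    stateless_init_torch_distributed_process_group


def join_side_channel(rank: int, host: str = "127.0.0.1", port: int = 29600):
    # Placeholder host/port; "gloo" keeps the example CPU-only.
    pg = stateless_init_torch_distributed_process_group(
        host, port, rank, world_size=2, backend="gloo")
    # allreduce is safe on this group because it does not depend on global
    # ranks (broadcast would not be, as the docstring above notes).
    flag = torch.ones(1, dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM, group=pg)
    return pg, int(flag.item())  # 2 once both side-channel ranks have joined
```
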
@@ -65,7 +172,7 @@ def parallel_config_get_dp_port(self) -> int:
     return port
 
 
-def stateless_init_dp_group(self) -> "ProcessGroup":
+def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
     # TODO(Yizhou): Currently we have to set the backend to gloo
     # because in vllm.config.ParallelConfig.has_unfinished_dp the
     # device is set to cpu. We need to fix this in the future.
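
As a sketch of why a CPU-only `gloo` group is sufficient here: the collective vllm runs on this group (in `ParallelConfig.has_unfinished_dp`) reduces a one-element CPU tensor, roughly as below. The helper name and body are simplified for illustration, not the actual vllm implementation.

```python
# Rough sketch of the pattern behind ParallelConfig.has_unfinished_dp
# (simplified, illustrative only): a MAX reduction over 0/1 flags on CPU
# implements a logical OR across data-parallel ranks, so the gloo-backed
# group created above never has to touch NPU memory.
import torch
import torch.distributed as dist


def any_dp_rank_unfinished(dp_group, has_unfinished: bool) -> bool:
    flag = torch.tensor([int(has_unfinished)], dtype=torch.int32, device="cpu")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=dp_group)
    return bool(flag.item())
```
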
@@ -83,71 +190,4 @@ def stateless_init_dp_group(self) -> "ProcessGroup":
 
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
-ParallelConfig.stateless_init_dp_group = stateless_init_dp_group
-
-
-def communication_adaptation_310p():
-
-    def broadcast310p(tensor, src, group=None, async_op=False):
-        rank = torch.distributed.get_rank(group)
-        world_size = torch.distributed.get_world_size(group)
-        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
-        tensor_list[rank] = tensor
-        torch.distributed.all_gather(tensor_list, tensor, group=group)
-        tensor[...] = tensor_list[src]
-        if async_op:
-            return NullHandle()
-        else:
-            return None
-
-    torch.distributed.broadcast = broadcast310p
-    torch.distributed.distributed_c10d.broadcast = broadcast310p
-
-    def all_reduce_wrapper_310p(fn):
-
-        def all_reduce(
-            tensor,
-            op=torch.distributed.ReduceOp.SUM,
-            group=None,
-            async_op=False,
-        ):
-            if tensor.dtype != torch.int64:
-                return fn(tensor, op, group, async_op)
-            rank = torch.distributed.get_rank(group)
-            world_size = torch.distributed.get_world_size(group)
-            tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
-            tensor_list[rank] = tensor
-            torch.distributed.all_gather(tensor_list, tensor, group=group)
-            if op == torch.distributed.ReduceOp.SUM:
-                return torch.stack(tensor_list).sum(0)
-            elif op == torch.distributed.ReduceOp.MAX:
-                return torch.tensor(
-                    torch.stack(tensor_list).cpu().numpy().max(0),
-                    device=tensor.device,
-                )
-            else:
-                raise RuntimeError(f"not implement op {op}")
-
-        return all_reduce
-
-    torch.distributed.all_reduce = all_reduce_wrapper_310p(
-        torch.distributed.all_reduce)
-    torch.distributed.distributed_c10d.all_reduce = all_reduce_wrapper_310p(
-        torch.distributed.distributed_c10d.all_reduce)
-
-    def reduce_scatter_310p(output_tensor, input_tensor, group=None):
-        rank = torch.distributed.get_rank(group)
-        world_size = torch.distributed.get_world_size(group)
-        torch.distributed.all_reduce(input_tensor,
-                                     torch.distributed.ReduceOp.SUM,
-                                     group,
-                                     async_op=False)
-        interval = input_tensor.shape[0] // world_size
-        output_tensor[:] = input_tensor[rank * interval:(rank + 1) * interval]
-
-    torch.distributed._reduce_scatter_base = reduce_scatter_310p
-    torch.distributed.distributed_c10d._reduce_scatter_base = reduce_scatter_310p
-
-
-if is_310p():
-    communication_adaptation_310p()
+ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
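
A quick way to see the effect of the assignments above, assuming the patch module has already been imported (the vllm-ascend plugin machinery is responsible for that and is not shown here):

```python
# Illustrative check only: once the patch module is imported, vllm resolves
# these entry points to the Ascend-aware implementations defined above.
import vllm.distributed.parallel_state as parallel_state
from vllm.config import ParallelConfig

assert parallel_state.destroy_model_parallel.__name__ == \
    "ascend_destroy_model_parallel"
assert ParallelConfig.get_next_dp_init_port.__name__ == \
    "parallel_config_get_dp_port"
assert ParallelConfig.stateless_init_dp_group.__name__ == \
    "ascend_stateless_init_dp_group"
```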