import torch
import threading

from torchrl.collectors.weight_update import RemoteWeightUpdaterBase
from torchrl.collectors.weight_update import LocalWeightUpdaterBase


VLLM_ERR = None
try:
    import vllm
    from vllm.worker.worker import Worker

    _has_vllm = True
except ImportError as err:
    _has_vllm = False
    VLLM_ERR = err

# These utilities are copied from vLLM's example code.
def stateless_init_process_group(
    master_address: str,
    master_port: int,
    rank: int,
    world_size: int,
    device: torch.device,
):
    """
    vLLM provides `StatelessProcessGroup` to create a process group
    without considering the global process group in torch.distributed.
    It is recommended to create `StatelessProcessGroup`, and then initialize
    the data-plane communication (NCCL) between external (train processes)
    and vLLM workers.
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    pg = StatelessProcessGroup.create(
        host=master_address, port=master_port, rank=rank, world_size=world_size
    )
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl

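
# Minimal usage sketch (not part of the original example): both sides must call
# stateless_init_process_group() with the same address, port, and world size.
# Addresses, ports, and tensor names below are illustrative assumptions.
#
#   # trainer process (rank 0):
#   comm = stateless_init_process_group(
#       "127.0.0.1", 29500, rank=0, world_size=2, device=torch.device("cuda:0")
#   )
#   comm.broadcast(tensor, src=0, stream=torch.cuda.current_stream())
#
#   # vLLM worker (rank 1), typically reached via collective_rpc:
#   comm = stateless_init_process_group(
#       "127.0.0.1", 29500, rank=1, world_size=2, device=worker_device
#   )
#   comm.broadcast(recv_tensor, src=0, stream=torch.cuda.current_stream())
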
if _has_vllm:
    # I should use the `worker_extension_cls` arg and not inherit from Worker,
    # but that is only available on main and not in vLLM 0.7.3.
    class WorkerExtension(Worker):
        """
        The class for vLLM's worker to inherit from.
        By defining an extension class, the code can work regardless of the
        underlying worker class, so it stays compatible with both vLLM V0 and V1.
        NOTE: we define this class in a separate module, and the main module
        should pass its fully qualified name as the `worker_extension_cls` argument.
        """

        def init_weight_update_group(self, master_address, master_port,
                                     rank_offset, world_size):
            from vllm.distributed.parallel_state import get_world_group
            # rank = get_world_group().rank + rank_offset
            rank = rank_offset
            self.model_update_group = stateless_init_process_group(
                master_address,
                master_port,
                rank,
                world_size,
                self.device,
            )
            self.version = torch.tensor([0], device="cuda")

        def update_weight(self, name, dtype, shape):
            # Receive one parameter broadcast from the trainer (rank 0) and load
            # it into the model held by this worker's model runner.
            weight = torch.empty(shape, dtype=dtype, device="cuda")
            self.model_update_group.broadcast(weight,
                                              src=0,
                                              stream=torch.cuda.current_stream())

            self.model_runner.model.load_weights(weights=[(name, weight)])

            del weight

        def update_policy_version(self):
            # Receive the trainer's version counter so the worker knows which
            # snapshot of the policy weights it is now serving.
            self.model_update_group.broadcast(self.version,
                                              src=0,
                                              stream=torch.cuda.current_stream())
            torch.cuda.synchronize()
            # print(f"{self=} {self.model_runner.model=}")
            self.policy_version = self.version

        def check_weights_changed(self):
            """
            Check whether all weights have been updated to zero.
            """
            weights_updated = True
            for name, p in self.model_runner.model.named_parameters():
                weights_updated = weights_updated and torch.allclose(
                    p, torch.zeros_like(p))
            return weights_updated
else:
    class WorkerExtension:
        pass
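
# How this class is meant to be wired into vLLM (a sketch, not from the original
# file): with the inheritance approach above, the extended worker is passed by its
# fully qualified name; on newer vLLM versions, the `worker_extension_cls` route
# mentioned in the docstring would be used instead. The module path and model name
# below are assumptions.
#
#   from vllm import LLM
#
#   llm = LLM(
#       model="Qwen/Qwen2.5-3B",                          # illustrative model
#       worker_cls="vllm_weight_update.WorkerExtension",  # this module's class
#   )
#   # or, on vLLM versions that support it:
#   # llm = LLM(model=..., worker_extension_cls="vllm_weight_update.WorkerExtension")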


class vLLMHFLocalWeightUpdater(LocalWeightUpdaterBase):
    def __init__(self, master_address, master_port, model_metadata):
        print(f"{master_address=}, {master_port=}")
        self.master_address = master_address
        self.master_port = master_port
        self.model_metadata = model_metadata
        self.initialized_group = None

    def _get_server_weights(self):
        return None

    def _get_local_weights(self):
        # We don't implement this because we let vLLM's update_weights API handle everything for now
        return None

    def _maybe_map_weights(self, server_weights, local_weights):
        # vLLM update_weights function handles the mapping from huggingface
        # so we don't implement this for now
        return None

    def _update_local_weights(self, local_weights, mapped_weights):
        llm = self.collector.policy["generate"].module
        if self.initialized_group is None:
            weight_sync_world_size = (
                llm.llm_engine.parallel_config.tensor_parallel_size + 1
            )
            llm.collective_rpc(
                "init_weight_update_group",
                args=(self.master_address, self.master_port, 1, weight_sync_world_size)
            )
            self.initialized_group = True

        for k, (dtype, shape) in self.model_metadata.items():
            llm.collective_rpc(
                "update_weight",
                args=(k, dtype, shape)
            )

        llm.collective_rpc("update_policy_version")
        print("done local update_weight")

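# Note on the handshake between the two updaters (a summary of the code above and
# below, not additional behavior): for every key in model_metadata, the local
# updater issues collective_rpc("update_weight", ...), which makes each vLLM worker
# post a receive, while vLLMRemoteWeightUpdaterBase._sync_weights_with_worker
# broadcasts the matching tensor from rank 0. The final
# collective_rpc("update_policy_version") pairs with the broadcast of
# `version_tensor`, so both sides must iterate the metadata in the same order.
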

class ReadWriteLock:
    """A lock object that allows many simultaneous "read locks", but
    only one "write lock"."""

    def __init__(self):
        self._read_ready = threading.Condition(threading.Lock())
        self._readers = 0

    def acquire_read(self):
        """Acquire a read lock. Blocks only if a thread has
        acquired the write lock."""
        self._read_ready.acquire()
        try:
            self._readers += 1
        finally:
            self._read_ready.release()

    def release_read(self):
        """Release a read lock."""
        self._read_ready.acquire()
        try:
            self._readers -= 1
            if not self._readers:
                self._read_ready.notify_all()
        finally:
            self._read_ready.release()

    def acquire_write(self):
        """Acquire a write lock. Blocks until there are no
        acquired read or write locks."""
        self._read_ready.acquire()
        while self._readers > 0:
            self._read_ready.wait()

    def release_write(self):
        """Release a write lock."""
        self._read_ready.release()

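# Usage in this file (see vLLMRemoteWeightUpdaterBase below): the trainer takes the
# write lock while it copies fresh parameters into `state_dict`, and each sync to a
# vLLM worker holds a read lock while broadcasting, so syncs to different workers
# need not exclude each other, only the trainer's writes. A minimal sketch with
# illustrative names:
#
#   lock = ReadWriteLock()
#   lock.acquire_write()
#   try:
#       shared_state.update(new_weights)   # exclusive access
#   finally:
#       lock.release_write()
#
#   lock.acquire_read()
#   try:
#       read_from(shared_state)            # may run concurrently with other readers
#   finally:
#       lock.release_read()
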

class vLLMRemoteWeightUpdaterBase(RemoteWeightUpdaterBase):
    def __init__(self, vllm_master_addresses, vllm_master_ports):
        super().__init__()
        from transformers import AutoModel
        self.vllm_master_addresses = vllm_master_addresses
        self.vllm_master_ports = vllm_master_ports
        # state_dict = dict()
        # for k, (dtype, shape) in model_metadata.items():
        #     self.state_dict[k] = torch.zeros(shape, dtype=dtype, device="cuda")
        # self.state_dict = state_dict()
        # self.state_dict_lock = ReadWriteLock()
        self.vllm_comm_groups = dict()
        self.vllm_weight_versions = dict()
        # self.version = -1

    def register_model_metadata(self, model_metadata):
        # Allocate one CUDA buffer per parameter; the trainer copies fresh weights
        # into these buffers (under the write lock) before each sync.
        self.model_metadata = model_metadata
        self.state_dict = dict()
        for k, (dtype, shape) in model_metadata.items():
            self.state_dict[k] = torch.zeros(shape, dtype=dtype, device="cuda")
        self.state_dict_lock = ReadWriteLock()
        self.version = 0
        self.version_tensor = torch.tensor([0], device="cuda")

    def acquire_state_dict_lock(self):
        self.state_dict_lock.acquire_write()

    def release_state_dict_lock(self):
        self.version += 1
        self.version_tensor += 1
        torch.cuda.synchronize()
        self.state_dict_lock.release_write()

    def all_worker_ids(self):
        return [i for i in range(len(self.collector._remote_collectors))]

    def _get_server_weights(self):
        return self.state_dict

    def _maybe_map_weights(self, server_weights):
        return server_weights

    def _skip_update(self, worker_id):
        # Skip workers that already hold the current weight version.
        if self.version == 0:
            return True
        if worker_id not in self.vllm_weight_versions:
            return False
        if self.vllm_weight_versions[worker_id] == self.version:
            print(
                f"skipping update for {worker_id=}, {self.version=}, "
                f"{self.vllm_weight_versions[worker_id]=}"
            )
            return True
        return False

    def _init_model_update_group(self, worker_id):
        # here again, I want to grab the tp size from the vLLM worker... :(
        # llm.llm_engine.parallel_config.tensor_parallel_size
        vllm_tp_size = 1
        weight_sync_world_size = vllm_tp_size + 1
        print("before stateless_init_process_group")
        model_update_group = stateless_init_process_group(
            self.vllm_master_addresses[worker_id],
            self.vllm_master_ports[worker_id],
            0,
            weight_sync_world_size,
            torch.device("cuda:0"),
        )
        print("after stateless_init_process_group")
        self.vllm_comm_groups[worker_id] = model_update_group

    def _sync_weights_with_worker(self, worker_id: int, server_weights):
        print(f"in _sync_weights_with_worker {worker_id}", flush=True)
        self.collector._remote_collectors[worker_id].update_policy_weights_.remote()
        if worker_id not in self.vllm_comm_groups:
            print("init model update group")
            self._init_model_update_group(worker_id)
            print("done init model update group")
        self.state_dict_lock.acquire_read()
        for i, k in enumerate(server_weights.keys()):
            # if i == 0:
            #     print(f"{server_weights[k][0]=}")
            self.vllm_comm_groups[worker_id].broadcast(
                server_weights[k], src=0, stream=torch.cuda.current_stream()
            )
        self.vllm_comm_groups[worker_id].broadcast(
            self.version_tensor, src=0, stream=torch.cuda.current_stream()
        )
        torch.cuda.synchronize()
        print(f"_sync_weights_with_worker done broadcast {worker_id} {self.version=}")
        self.vllm_weight_versions[worker_id] = self.version
        self.state_dict_lock.release_read()
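
# End-to-end sketch of how the pieces are intended to fit together. This is not in
# the original file; `train_model`, `optimizer`, `compute_loss`, and the collector
# that has these updaters attached are assumed to exist, and the address/port are
# illustrative.
#
#   model_metadata = {
#       k: (v.dtype, v.shape) for k, v in train_model.state_dict().items()
#   }
#   remote_updater = vLLMRemoteWeightUpdaterBase(["localhost"], [29500])
#   remote_updater.register_model_metadata(model_metadata)
#
#   for batch in collector:                       # collector built with the updaters
#       loss = compute_loss(train_model, batch)
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()
#
#       # Publish the new weights: copy them into the shared buffers under the
#       # write lock; releasing the lock bumps the version so _skip_update lets
#       # the next sync go through.
#       remote_updater.acquire_state_dict_lock()
#       try:
#           for k, v in train_model.state_dict().items():
#               remote_updater.state_dict[k].copy_(v)
#       finally:
#           remote_updater.release_state_dict_lock()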