Commit d796e1d

v0 param server (using collectives not object store)

ghstack-source-id: 70da726
Pull Request resolved: #2865

1 parent: 04d70c1
4 files changed (+298 −7)

torchrl/collectors/collectors.py (+19 −4)
@@ -76,6 +76,15 @@ def cudagraph_mark_step_begin():
     """Placeholder for missing cudagraph_mark_step_begin method."""
     raise NotImplementedError("cudagraph_mark_step_begin not implemented.")
 
+try:
+    import ray
+    from ray.actor import ActorHandle
+
+    _has_ray = True
+except ImportError as err:
+    _has_ray = False
+    RAY_ERR = err
+
 
 _TIMEOUT = 1.0
 INSTANTIATE_TIMEOUT = 20
@@ -174,9 +183,12 @@ def remote_weight_updater(self) -> RemoteWeightUpdaterBase:
     @remote_weight_updater.setter
     def remote_weight_updater(self, value: RemoteWeightUpdaterBase | None):
         if value is not None:
-            value.register_collector(self)
-            if value.collector is not self:
-                raise RuntimeError("Failed to register collector.")
+            if _has_ray and isinstance(value, ray.actor.ActorHandle):
+                value.register_collector.remote(self)
+            else:
+                value.register_collector(self)
+                if value.collector is not self:
+                    raise RuntimeError("Failed to register collector.")
         self._remote_weight_updater = value
 
     def _get_policy_and_device(
@@ -308,7 +320,10 @@ def update_policy_weights_(
         if self.local_weight_updater is not None:
             self.local_weight_updater(policy_weights, **kwargs)
         if self.remote_weight_updater is not None:
-            self.remote_weight_updater(policy_weights, worker_ids=worker_ids, **kwargs)
+            if _has_ray and isinstance(self.remote_weight_updater, ray.actor.ActorHandle):
+                ray.get(self.remote_weight_updater.__call__.remote(policy_weights, worker_ids=worker_ids, **kwargs))
+            else:
+                self.remote_weight_updater(policy_weights, worker_ids=worker_ids, **kwargs)
         elif worker_ids is not None:
             raise TypeError("worker_ids was passed but remote_weight_updater was None.")
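For orientation, a minimal sketch of how the two dispatch paths above can be exercised. The `ToyUpdater` stubs and the pre-existing `collector` instance are illustrative assumptions, not part of this commit; only the `.remote()`/`ray.get()` dispatch mirrors the diff:

    import ray

    from torchrl.collectors.weight_update import RemoteWeightUpdaterBase

    class ToyUpdater(RemoteWeightUpdaterBase):
        # no-op stubs, just enough to instantiate the base class
        def _get_server_weights(self):
            return None

        def _maybe_map_weights(self, server_weights):
            return server_weights

        def _skip_update(self, worker_id):
            return False

        def _sync_weights_with_worker(self, worker_id, server_weights):
            pass

        def all_worker_ids(self):
            return [0]

    # Path 1: a plain object; the setter calls value.register_collector(self)
    collector.remote_weight_updater = ToyUpdater()

    # Path 2: a Ray actor handle; the setter calls
    # value.register_collector.remote(self), and update_policy_weights_
    # wraps the updater call in ray.get(...)
    collector.remote_weight_updater = ray.remote(ToyUpdater).remote()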

torchrl/collectors/distributed/ray.py (+1 −1)
@@ -759,7 +759,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
             yield out_td
 
             if self.update_after_each_batch or self.max_weight_update_interval > -1:
-                self.update_policy_weights_(worker_ids=collector_index + 1)
+                self.update_policy_weights_(worker_ids=collector_index)
 
             # Schedule a new collection task
             future = collector.next.remote()
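This one-line fix matters because `worker_ids` ends up indexing `self.collector._remote_collectors` directly in the remote updater (see `all_worker_ids` in the new file below), so ids are 0-based. A small illustration of the previous off-by-one, with a stand-in list:

    _remote_collectors = ["collector_0", "collector_1", "collector_2"]

    for collector_index in range(len(_remote_collectors)):
        # the old code passed collector_index + 1: worker 0 was never
        # updated and the last index would overflow the list; 0-based
        # ids line up one-to-one with the collectors
        _ = _remote_collectors[collector_index]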
New file (+265)

@@ -0,0 +1,265 @@
import threading

import torch

from torchrl.collectors.weight_update import (
    LocalWeightUpdaterBase,
    RemoteWeightUpdaterBase,
)

VLLM_ERR = None
try:
    import vllm
    from vllm.worker.worker import Worker

    _has_vllm = True
except ImportError as err:
    _has_vllm = False
    VLLM_ERR = err


# These utilities are copied from vLLM's example code.
def stateless_init_process_group(
    master_address: str,
    master_port: int,
    rank: int,
    world_size: int,
    device: torch.device,
):
    """vLLM provides ``StatelessProcessGroup`` to create a process group
    without considering the global process group in torch.distributed.

    It is recommended to create a ``StatelessProcessGroup`` and then
    initialize the data-plane communication (NCCL) between the external
    (train) processes and the vLLM workers.
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    pg = StatelessProcessGroup.create(
        host=master_address, port=master_port, rank=rank, world_size=world_size
    )
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl
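# Rank layout assumed by the callers below (an editorial note, not part of the
# diff): with vLLM tensor_parallel_size = tp, the weight-sync world size is
# tp + 1, where
#   rank 0      -> the trainer / parameter server (broadcast source),
#   ranks 1..tp -> the vLLM workers (rank_offset starts at 1).
# e.g. trainer side: stateless_init_process_group(addr, port, 0, tp + 1, dev)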
if _has_vllm:
    # We should use the `worker_extension_cls` argument instead of inheriting
    # from Worker, but that is only available on main and not in vLLM 0.7.3.
    class WorkerExtension(Worker):
        """The class for vLLM's worker to inherit from.

        By defining an extension class, the code works no matter what the
        underlying worker class is, so it stays compatible with both vLLM V0
        and V1.
        NOTE: we define this class in a separate module, and the main module
        should pass its fully qualified name as the `worker_extension_cls`
        argument.
        """

        def init_weight_update_group(
            self, master_address, master_port, rank_offset, world_size
        ):
            from vllm.distributed.parallel_state import get_world_group

            # rank = get_world_group().rank + rank_offset
            rank = rank_offset
            self.model_update_group = stateless_init_process_group(
                master_address,
                master_port,
                rank,
                world_size,
                self.device,
            )
            self.version = torch.tensor([0], device="cuda")

        def update_weight(self, name, dtype, shape):
            # Allocate a receive buffer, take part in the broadcast from the
            # server (rank 0), then hand the tensor to vLLM's weight loader.
            weight = torch.empty(shape, dtype=dtype, device="cuda")
            self.model_update_group.broadcast(
                weight, src=0, stream=torch.cuda.current_stream()
            )
            self.model_runner.model.load_weights(weights=[(name, weight)])
            del weight

        def update_policy_version(self):
            self.model_update_group.broadcast(
                self.version, src=0, stream=torch.cuda.current_stream()
            )
            torch.cuda.synchronize()
            self.policy_version = self.version

        def check_weights_changed(self):
            """Check if the weights are updated to 0 (debugging helper)."""
            weights_updated = True
            for name, p in self.model_runner.model.named_parameters():
                weights_updated = weights_updated and torch.allclose(
                    p, torch.zeros_like(p)
                )
            return weights_updated

else:

    class WorkerExtension:
        pass
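# Editorial note on the sync protocol sketched by WorkerExtension: both sides
# must issue matching broadcasts in the same parameter order for the NCCL
# collectives to pair up, e.g.
#   server (rank 0):  group.broadcast(state_dict[name], src=0, stream=...)
#   worker (rank>0):  update_weight(name, dtype, shape)  # buffer, recv, load
# followed by one extra broadcast of the version tensor on each side.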
class vLLMHFLocalWeightUpdater(LocalWeightUpdaterBase):
    def __init__(self, master_address, master_port, model_metadata):
        print(f"{master_address=}, {master_port=}")
        self.master_address = master_address
        self.master_port = master_port
        self.model_metadata = model_metadata
        self.initialized_group = None

    def _get_server_weights(self):
        return None

    def _get_local_weights(self):
        # We don't implement this because we let vLLM's update_weights API
        # handle everything for now.
        return None

    def _maybe_map_weights(self, server_weights, local_weights):
        # vLLM's update_weights function handles the mapping from Hugging
        # Face, so we don't implement this for now.
        return None

    def _update_local_weights(self, local_weights, mapped_weights):
        llm = self.collector.policy["generate"].module
        if self.initialized_group is None:
            weight_sync_world_size = (
                llm.llm_engine.parallel_config.tensor_parallel_size + 1
            )
            llm.collective_rpc(
                "init_weight_update_group",
                args=(self.master_address, self.master_port, 1, weight_sync_world_size),
            )
            self.initialized_group = True

        for k, (dtype, shape) in self.model_metadata.items():
            llm.collective_rpc("update_weight", args=(k, dtype, shape))

        llm.collective_rpc("update_policy_version")
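# Editorial wiring sketch (address and port values are illustrative): each
# remote collector is built with its own local updater, sharing the
# address/port of that worker's NCCL group with the server side:
#   local_updater = vLLMHFLocalWeightUpdater(
#       master_address="localhost", master_port=29500, model_metadata=metadata)
# The group itself is created lazily on the first _update_local_weights call.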
class ReadWriteLock:
    """A lock object that allows many simultaneous "read locks", but only
    one "write lock"."""

    def __init__(self):
        self._read_ready = threading.Condition(threading.Lock())
        self._readers = 0

    def acquire_read(self):
        """Acquire a read lock. Blocks only if a thread has acquired the
        write lock."""
        self._read_ready.acquire()
        try:
            self._readers += 1
        finally:
            self._read_ready.release()

    def release_read(self):
        """Release a read lock."""
        self._read_ready.acquire()
        try:
            self._readers -= 1
            if not self._readers:
                self._read_ready.notify_all()
        finally:
            self._read_ready.release()

    def acquire_write(self):
        """Acquire a write lock. Blocks until there are no acquired read or
        write locks."""
        self._read_ready.acquire()
        while self._readers > 0:
            self._read_ready.wait()

    def release_write(self):
        """Release a write lock."""
        self._read_ready.release()
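# How the lock is used below: every per-worker weight broadcast holds the read
# side (so several syncs may run concurrently), while publishing a new weight
# version holds the write side:
#   lock.acquire_read();  broadcast(...);           lock.release_read()
#   lock.acquire_write(); mutate state_dict/version; lock.release_write()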
class vLLMRemoteWeightUpdaterBase(RemoteWeightUpdaterBase):
    def __init__(self, vllm_master_addresses, vllm_master_ports):
        super().__init__()
        self.vllm_master_addresses = vllm_master_addresses
        self.vllm_master_ports = vllm_master_ports
        self.vllm_comm_groups = dict()
        self.vllm_weight_versions = dict()

    def register_model_metadata(self, model_metadata):
        self.model_metadata = model_metadata
        self.state_dict = dict()
        for k, (dtype, shape) in model_metadata.items():
            self.state_dict[k] = torch.zeros(shape, dtype=dtype, device="cuda")
        self.state_dict_lock = ReadWriteLock()
        self.version = 0
        self.version_tensor = torch.tensor([0], device="cuda")

    def acquire_state_dict_lock(self):
        self.state_dict_lock.acquire_write()

    def release_state_dict_lock(self):
        self.version += 1
        self.version_tensor += 1
        torch.cuda.synchronize()
        self.state_dict_lock.release_write()

    def all_worker_ids(self):
        return list(range(len(self.collector._remote_collectors)))

    def _get_server_weights(self):
        return self.state_dict

    def _maybe_map_weights(self, server_weights):
        return server_weights

    def _skip_update(self, worker_id):
        if self.version == 0:
            # no weights have been published yet
            return True
        if worker_id not in self.vllm_weight_versions:
            return False
        if self.vllm_weight_versions[worker_id] == self.version:
            print(
                f"skipping update for {worker_id=}, {self.version=}, "
                f"{self.vllm_weight_versions[worker_id]=}"
            )
            return True
        return False

    def _init_model_update_group(self, worker_id):
        # Ideally we would grab the tensor-parallel size from the vLLM worker
        # (llm.llm_engine.parallel_config.tensor_parallel_size); it is
        # hard-coded to 1 for now.
        vllm_tp_size = 1
        weight_sync_world_size = vllm_tp_size + 1
        model_update_group = stateless_init_process_group(
            self.vllm_master_addresses[worker_id],
            self.vllm_master_ports[worker_id],
            0,
            weight_sync_world_size,
            torch.device("cuda:0"),
        )
        self.vllm_comm_groups[worker_id] = model_update_group

    def _sync_weights_with_worker(self, worker_id: int, server_weights):
        self.collector._remote_collectors[worker_id].update_policy_weights_.remote()
        if worker_id not in self.vllm_comm_groups:
            self._init_model_update_group(worker_id)
        self.state_dict_lock.acquire_read()
        for k in server_weights.keys():
            self.vllm_comm_groups[worker_id].broadcast(
                server_weights[k], src=0, stream=torch.cuda.current_stream()
            )
        self.vllm_comm_groups[worker_id].broadcast(
            self.version_tensor, src=0, stream=torch.cuda.current_stream()
        )
        torch.cuda.synchronize()
        self.vllm_weight_versions[worker_id] = self.version
        self.state_dict_lock.release_read()
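To make the server's lifecycle concrete, a hedged usage sketch: the Hugging Face model name and the manual `copy_` publish step are assumptions for illustration; the commit itself only defines the updater:

    from transformers import AutoModelForCausalLM

    train_model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
    model_metadata = {
        k: (v.dtype, v.shape) for k, v in train_model.state_dict().items()
    }

    updater = vLLMRemoteWeightUpdaterBase(
        vllm_master_addresses=["localhost"], vllm_master_ports=[29500]
    )
    updater.register_model_metadata(model_metadata)

    # after each optimizer step: publish new weights under the write lock
    updater.acquire_state_dict_lock()
    for k, v in train_model.state_dict().items():
        updater.state_dict[k].copy_(v)
    # releasing bumps the version, so the next _sync_weights_with_worker
    # call broadcasts the fresh weights instead of being skipped
    updater.release_state_dict_lock()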

torchrl/modules/llm/vllm_policy.py (+13 −2)
@@ -277,7 +277,7 @@ def tokenize(td):
     module_dict["generate"] = Mod(
         model,
         method="generate",
-        method_kwargs={"sampling_params": sampling_params},
+        method_kwargs={"sampling_params": sampling_params, "use_tqdm": False},
         in_keys=in_keys,
         out_keys=["tokens_out"],
         out_to_in_map=True,
@@ -426,6 +426,15 @@ def move_input(td):
         out_to_in_map=True,
         strict=True,
     )
+
+    def add_policy_version(td):
+        if hasattr(model.llm_engine.model_executor.driver_worker.worker, "policy_version"):
+            td["policy_version"] = NonTensorData(model.llm_engine.model_executor.driver_worker.worker.policy_version.item())
+        else:
+            td["policy_version"] = NonTensorData(0)
+        return td
+
+    module_dict["add_policy_version"] = add_policy_version
 
     def get_output_tokens_and_log_probs(td, padding_value=padding_value):
         td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"])
@@ -446,7 +455,7 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value):
         padded_values = tokens_response_td["tokens_response"] == padding_value
         if padded_values.any():
             lps = tokens_response_td["log_probs"]
-            lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
+            lps = torch.where(expand_as_right(~padded_values, lps), lps, 1.0)
             tokens_response_td["log_probs"] = lps
         td.update(tokens_response_td)
         return td
@@ -462,13 +471,15 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value):
         ("tokens_in", "input_ids"),
         ("tokens_in", "attention_mask"),
         "text_response",
+        "policy_version",
     ]
     out_keys = [
         "log_probs",
         "tokens_response",
         token_key,
         attention_mask_key,
         "text_response",
+        "policy_version",
     ]
 
     def format_td(td):
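With `policy_version` threaded through the in/out keys above, every generated batch carries the version stamp of the weights that produced it. A hedged sketch of how a trainer might consume that stamp (`batch` and the staleness policy are illustrative, not part of this commit):

    def filter_stale(batch, current_version: int):
        # "policy_version" is the value written by add_policy_version above
        batch_version = batch["policy_version"]
        if batch_version != current_version:
            return None  # e.g. drop or down-weight off-policy samples
        return batch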
