From c285ef852ce76ea982bb73299b7f74760e0c6e13 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Fri, 21 Mar 2025 08:29:11 -0700 Subject: [PATCH 1/5] v0 param server (using collectives not object store) [ghstack-poisoned] --- param_server_weight_updater.py | 403 ++++++++++++++++++++++++++ torchrl/collectors/collectors.py | 12 +- torchrl/collectors/distributed/ray.py | 4 +- torchrl/collectors/weight_update.py | 14 +- torchrl/modules/llm/vllm_policy.py | 7 + 5 files changed, 434 insertions(+), 6 deletions(-) create mode 100644 param_server_weight_updater.py diff --git a/param_server_weight_updater.py b/param_server_weight_updater.py new file mode 100644 index 00000000000..08823703bd8 --- /dev/null +++ b/param_server_weight_updater.py @@ -0,0 +1,403 @@ +import ray + +from argparse import ArgumentParser +from functools import partial + +import torch +from datasets import load_dataset +from tensordict import TensorDict +from torch.utils.data import DataLoader +from torchrl.collectors.weight_update import RayRemoteWeightUpdater +from transformers import AutoTokenizer, AutoModel +from vllm import LLM + +from vllm.utils import get_ip, get_open_port + +from vllm.worker.worker import Worker + +from torchrl.collectors.distributed import RayCollector +from torchrl.envs import LLMEnv +from torchrl.modules import from_vllm + +parser = ArgumentParser() +parser.add_argument("--dataset", type=str, default="gsm8k") +parser.add_argument("--batch_size", type=int, default=4) +parser.add_argument("--epochs", type=int, default=10) +parser.add_argument("--repeats", type=int, default=10) +parser.add_argument("--steps_per_batch", type=int, default=16) +parser.add_argument("--optim_batch_size", type=int, default=4) + +def stateless_init_process_group( + master_address: str, + master_port: int, + rank: int, + world_size: int, + device: torch.device, +): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. + """ + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + pg = StatelessProcessGroup.create( + host=master_address, port=master_port, rank=rank, world_size=world_size + ) + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl + + +# I should use worker_extension_cls arg and not inherit from worker, +# but that is only available on main and not 0.7.3 +class WorkerExtension(Worker): + """ + The class for vLLM's worker to inherit from. + By defining an extension class, the code can work no matter what is + the underlying worker class. This way, the code can be compatible + with both vLLM V0 and V1. + NOTE: we define this class in a separate module, and the main module + should pass the full qualified name as `worker_extension_cls` argument. + """ + + def init_weight_update_group(self, master_address, master_port, + rank_offset, world_size): + from vllm.distributed.parallel_state import get_world_group + rank = get_world_group().rank + rank_offset + self.model_update_group = stateless_init_process_group( + master_address, + master_port, + rank, + world_size, + self.device, + ) + + def update_weight(self, name, dtype, shape): + weight = torch.empty(shape, dtype=dtype, device="cuda") + self.model_update_group.broadcast(weight, + src=0, + stream=torch.cuda.current_stream()) + + self.model_runner.model.load_weights(weights=[(name, weight)]) + + del weight + + def check_weights_changed(self): + """ + Check if the weights are updated to 0. + """ + weights_updated = True + for name, p in self.model_runner.model.named_parameters(): + weights_updated = weights_updated and torch.allclose( + p, torch.zeros_like(p)) + return weights_updated + + +def make_policy(): + inference_model = LLM( + "facebook/opt-125m", + enforce_eager=True, + # change to worker_extension_cls when available in stable release + worker_cls=WorkerExtension, + ) + + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + policy = from_vllm( + inference_model, tokenizer=tokenizer, from_text=False, generate=True, return_log_probs=True, generate_kwargs={"temperature": 0.0}) + return policy + + +def make_env(dataset, batch_size): + dataset = load_dataset(dataset, "main") + train_dataset = dataset["train"] + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + # Env + dataloader = DataLoader( # noqa: TOR401 + train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn + ) + env = LLMEnv.from_dataloader( + dataloader=dataloader, + tokenizer=tokenizer, + str2str=True, + batch_size=(args.batch_size * args.repeats,), + repeats=args.repeats, ) + return env + + +def collate_fn(batch): + batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch]) + batch.rename_key_("question", "text") + return batch + +@ray.remote(num_cpus=1, num_gpus=1) +class TrainerActor: + def __init__(self, env_vars): + import os + import torch + import torch.distributed + from torch.distributed._composable.fsdp import fully_shard + + torch.cuda.set_device(torch.device('cuda', 0)) + + print(os.environ["CUDA_VISIBLE_DEVICES"]) + + for var in env_vars: + os.environ[var] = str(env_vars[var]) + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="nccl", device_id=torch.device('cuda:0')) + print("initialized process group") + + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + print(world_size, rank) + self.rank = int(os.environ["RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + + # self.param_server_comm_group = None + # if self.rank == 0: + # self.param_server_comm_group = torch.distributed.new_group(ranks=[0, self.world_size - 1], use_local_synchronization=True) + + # hold back one rank for the parameter server + self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1))) + self.comm_group = torch.distributed.new_group(ranks=[0, 2]) + self.device_mesh = torch.distributed.device_mesh.DeviceMesh.from_group(self.fsdp_group, device_type="cuda") + + self.model = AutoModel.from_pretrained("facebook/opt-125m").cuda() + + fully_shard(self.model, mesh=self.device_mesh) + + def register_parameter_server(self, param_server): + # assert self.rank == 0 + self.param_server = param_server + + def send_weights_to_param_server(self): + # assert(hasattr(self, "param_server")) + for k, v in self.model.state_dict().items(): + replicated_v = v.full_tensor() + # dst is global rank, can switch to group_dst arg if not 2.5.1 + if self.rank == 0: + # print(f"sending {k}, {replicated_v.nbytes}") + handle = self.param_server.receive_from_train.remote(k) + torch.distributed.send(replicated_v, dst=2) + # ray.get(handle) + + def zero_(self): + sd = self.model.state_dict() + for k, v in sd.items(): + sd[k] = v.data.zero_() + + def train(self): + import time + for _ in range(1): + # actually run train loop + # ... + self.zero_() + torch.distributed.barrier(group=self.fsdp_group) + print("done barrier!") + # if self.rank == 0: + # print("starting send weights") + self.send_weights_to_param_server() + torch.distributed.barrier(group=self.fsdp_group) + + +from torchrl.collectors.weight_update import RemoteWeightUpdaterBase + +@ray.remote(num_cpus=1, num_gpus=1) +class vLLMParameterServer(RemoteWeightUpdaterBase): + def __init__(self, env_vars): + import os + import torch + import torch.distributed + + torch.cuda.set_device(torch.device('cuda', 0)) + + for var in env_vars: + os.environ[var] = str(env_vars[var]) + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="nccl", device_id=torch.device('cuda:0')) + print("initialized process group") + + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + print(world_size, rank) + self.rank = int(os.environ["RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + assert self.rank == self.world_size - 1 + + self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1))) + self.comm_group = torch.distributed.new_group(ranks=[0, 2]) + + # self.param_server_trainer_comm_group = torch.distributed.new_group(ranks=[0, self.world_size - 1], use_local_synchronization=True) + + self.param_server_vllm_comm_groups = dict() + + # Having the state_dict fit on one GPU will not scale + self.state_dict = AutoModel.from_pretrained("facebook/opt-125m").cuda().eval().state_dict() + + self.lock = torch.multiprocessing.Lock() + self.version = 0 + + print(self.state_dict.keys()) + + def receive_from_train(self, k): + # with self.lock: + # src is global rank, an change to group_src once not 2.5.1 + # print(f"receiving {k}") + torch.distributed.recv(self.state_dict[k], src=0) + # self.version += 1 + # print(f"received {k} {self.state_dict[k].flatten()[0]}") + + def _init_model_update_group(self, worker_id): + print("in init model update group", worker_id) + master_address, master_port = get_ip(), get_open_port() + print(master_address, master_port) + # FIXME!!!! This needs to be grabbed from each remote collector + vllm_tp_size = 1 + weight_sync_world_size = vllm_tp_size + 1 + print("calling collective_rpc") + self.collector._remote_collectors[worker_id].call_policy_method.remote( + "collective_rpc", + ("init_weight_update_group",), + {'args': (master_address, master_port, 1, weight_sync_world_size)} + ) + print("done collective_rpc") + model_update_group = stateless_init_process_group( + master_address, + master_port, + 0, + weight_sync_world_size, + torch.device("cuda:0"), + ) + print("done stateless init process group") + self.param_server_vllm_comm_groups[worker_id] = model_update_group + + + def _sync_weights_with_worker( + self, worker_id: int, server_weights + ): + if worker_id not in self.param_server_vllm_comm_groups: + self._init_model_update_group(worker_id) + handles = [] + for i, (k, v) in enumerate(server_weights.items()): + handle = self.collector._remote_collectors[worker_id].call_policy_method.remote( + "collective_rpc", + ("update_weight",), + {'args': (k, v.dtype, v.shape)} + ) + handles.append(handle) + # self.collector._remote_collectors[worker_id].collective_rpc.remote("update_weight", args=(k, v.dtype, v.shape)) + self.param_server_vllm_comm_groups[worker_id].broadcast(server_weights[k], src=0, stream=torch.cuda.current_stream()) + handle = self.collector._remote_collectors[worker_id].call_policy_method.remote( + "collective_rpc", + ("check_weights_changed",), + {}, + ) + + print(f"weights changed {ray.get(handle)}") + # probably no need barrier because subsequent gpu work should be serialized + # self._batches_since_weight_update[worker_id] = 0 + + def _get_server_weights(self): + print("in _get_server_weights") + with self.lock: + return self.state_dict + + def _maybe_map_weights(self, server_weights): + # This is making a design choice that weight mapping always happens on the parameter servver + # I don't think we should make this design choice so early. + return server_weights + + def all_worker_ids(self): + return [0] + + def _skip_update(self, worker_id: int) -> bool: + pass + + def check_weights_changed(self): + """ + Check if the weights are updated to 0. + """ + weights_updated = True + for name, p in self.state_dict.items(): + weights_updated = weights_updated and torch.allclose( + p, torch.zeros_like(p)) + return weights_updated + + +def _create_trainer_group(worker_cls, param_server_cls, world_size: int): + addr, port = get_ip(), get_open_port() + trainer_workers = [] + fsdp_world_size = world_size - 1 + for i in range(fsdp_world_size): + env_vars = { + "RANK": str(i), + "WORLD_SIZE": world_size, + "MASTER_ADDR": str(addr), + "MASTER_PORT": str(port), + } + worker = worker_cls.remote(env_vars) + trainer_workers.append(worker) + + env_vars = { + "RANK": str(world_size - 1), + "WORLD_SIZE": world_size, + "MASTER_ADDR": str(addr), + "MASTER_PORT": str(port), + } + parameter_server = param_server_cls.remote(env_vars) + trainer_workers[0].register_parameter_server.remote(parameter_server) + trainer_workers[1].register_parameter_server.remote(parameter_server) + return trainer_workers, parameter_server + + +if __name__ == "__main__": + args = parser.parse_args() + + remote_configs = { + "num_cpus": 1, + "num_gpus": 1, + "memory": 2 * 1024**3, + } + + ray.init(num_cpus=4, num_gpus=4) + + trainer_workers, parameter_server = _create_trainer_group(TrainerActor, vLLMParameterServer, 3) + + handles = [] + for trainer_worker in trainer_workers: + handles.append(trainer_worker.train.remote()) + + + print(f"param server weights updated {ray.get(parameter_server.check_weights_changed.remote())}") + + make_env_parsed = partial(make_env, batch_size=args.batch_size, dataset=args.dataset) + collector = RayCollector( + [make_env_parsed], + policy_factory=make_policy, + frames_per_batch=40, + total_frames=200, + remote_configs=remote_configs, + remote_weights_updater=parameter_server, + update_after_each_batch=True, + ) + print("done collector init") + + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + + for i, data in enumerate(collector): + print(tokenizer.decode(data["tokens"][0].squeeze())) + print(tokenizer.decode(data["tokens_response"][0].squeeze())) + if i == 1: + break + collector.shutdown() \ No newline at end of file diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 38e88779af3..63fb78f96bb 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -306,9 +306,10 @@ def update_policy_weights_( """ if self.local_weight_updater is not None: - self.local_weight_updater(policy_weights, **kwargs) + self.local_weights_updater(policy_weights, **kwargs) if self.remote_weight_updater is not None: - self.remote_weight_updater(policy_weights, worker_ids=worker_ids, **kwargs) + import ray + ray.get(self.remote_weight_updater.__call__.remote(policy_weights, worker_ids=worker_ids, **kwargs)) elif worker_ids is not None: raise TypeError("worker_ids was passed but remote_weight_updater was None.") @@ -861,6 +862,13 @@ def __init__( self.local_weight_updater = local_weight_updater self.remote_weight_updater = remote_weight_updater + def call_policy_method(self, method: str, args, kwargs): + # I want world where I don't have to do this to call a method on the + # vllm policy that is owned by the remote collector + + result = getattr(self.policy['generate'].module, method)(*args, **kwargs) + return result + @property def _traj_pool(self): pool = getattr(self, "_traj_pool_val", None) diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index f7517b13143..25958e31e7f 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -646,6 +646,7 @@ def stop_remote_collectors(self): ) # This will interrupt any running tasks on the actor, causing them to fail immediately def iterator(self): + print(f"{self._sync=}") def proc(data): if self.split_trajs: data = split_trajectories(data) @@ -759,8 +760,9 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: yield out_td if self.update_after_each_batch or self.max_weight_update_interval > -1: - self.update_policy_weights_(worker_ids=collector_index + 1) + self.update_policy_weights_(worker_ids=collector_index) + print("done updating policy weights") # Schedule a new collection task future = collector.next.remote() pending_tasks[future] = collector_index diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py index 46732057ad6..b6232f20a5b 100644 --- a/torchrl/collectors/weight_update.py +++ b/torchrl/collectors/weight_update.py @@ -13,6 +13,8 @@ from tensordict import TensorDictBase from tensordict.nn import TensorDictModuleBase +# from torchrl.collectors import DataCollectorBase + Policy = TypeVar("Policy", bound=TensorDictModuleBase) @@ -47,7 +49,7 @@ class LocalWeightUpdaterBase(metaclass=abc.ABCMeta): _collector_wr: Any = None - def register_collector(self, collector: DataCollectorBase): # noqa + def register_collector(self, collector): # noqa """Register a collector in the updater. Once registered, the updater will not accept another collector. @@ -61,7 +63,7 @@ def register_collector(self, collector: DataCollectorBase): # noqa self._collector_wr = weakref.ref(collector) @property - def collector(self) -> torchrl.collectors.DataCollectorBase: # noqa + def collector(self): # noqa return self._collector_wr() if self._collector_wr is not None else None @abstractmethod @@ -134,7 +136,7 @@ class RemoteWeightUpdaterBase(metaclass=abc.ABCMeta): _collector_wr: Any = None - def register_collector(self, collector: DataCollectorBase): # noqa + def register_collector(self, collector): # noqa """Register a collector in the updater. Once registered, the updater will not accept another collector. @@ -148,7 +150,11 @@ def register_collector(self, collector: DataCollectorBase): # noqa self._collector_wr = weakref.ref(collector) @property +<<<<<<< HEAD def collector(self) -> torch.collector.DataCollectorBase: # noqa +======= + def collector(self): +>>>>>>> 3917c0752 (v0 param server (using collectives not object store)) return self._collector_wr() if self._collector_wr is not None else None @abstractmethod @@ -184,12 +190,14 @@ def update_weights( weights: TensorDictBase | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None, ): + print(f"in update_weights {worker_ids}") if weights is None: # Get the weights on server (local) server_weights = self._get_server_weights() else: server_weights = weights + self._maybe_map_weights(server_weights) # Get the remote weights (inference workers) diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index a810fe98c1e..688ec7855c6 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -415,6 +415,13 @@ def move_input(td): generate_kwargs.setdefault("logprobs", return_log_probs) sampling_params = SamplingParams(**generate_kwargs) + # def print_weights(td): + # for i, (name, param) in enumerate(model.llm_engine.model_executor.driver_worker.worker.model_runner.model.named_parameters()): + # if i == 0: + # print(f"Model parameters: {name} {param[0]}") + # return td + + # module_dict["print_weights"] = print_weights module_dict["generate"] = Mod( model, method="generate", From 839cf0a9ba9192ac71c1554f640c1a64ac6c6a96 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Fri, 21 Mar 2025 22:50:08 -0700 Subject: [PATCH 2/5] Update on "v0 param server (using collectives not object store)" [ghstack-poisoned] --- param_server_weight_updater.py | 232 +++++------------------ torchrl/collectors/collectors.py | 20 +- torchrl/collectors/vllm_weight_update.py | 179 +++++++++++++++++ torchrl/collectors/weight_update.py | 4 - 4 files changed, 241 insertions(+), 194 deletions(-) create mode 100644 torchrl/collectors/vllm_weight_update.py diff --git a/param_server_weight_updater.py b/param_server_weight_updater.py index 08823703bd8..a627a88f4cd 100644 --- a/param_server_weight_updater.py +++ b/param_server_weight_updater.py @@ -13,12 +13,12 @@ from vllm.utils import get_ip, get_open_port -from vllm.worker.worker import Worker - from torchrl.collectors.distributed import RayCollector from torchrl.envs import LLMEnv from torchrl.modules import from_vllm +from torchrl.collectors.vllm_weight_update import vLLMHFLocalWeightUpdater, vLLMRemoteWeightUpdaterBase, WorkerExtension + parser = ArgumentParser() parser.add_argument("--dataset", type=str, default="gsm8k") parser.add_argument("--batch_size", type=int, default=4) @@ -27,74 +27,6 @@ parser.add_argument("--steps_per_batch", type=int, default=16) parser.add_argument("--optim_batch_size", type=int, default=4) -def stateless_init_process_group( - master_address: str, - master_port: int, - rank: int, - world_size: int, - device: torch.device, -): - """ - vLLM provides `StatelessProcessGroup` to create a process group - without considering the global process group in torch.distributed. - It is recommended to create `StatelessProcessGroup`, and then initialize - the data-plane communication (NCCL) between external (train processes) - and vLLM workers. - """ - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - from vllm.distributed.utils import StatelessProcessGroup - - pg = StatelessProcessGroup.create( - host=master_address, port=master_port, rank=rank, world_size=world_size - ) - pynccl = PyNcclCommunicator(pg, device=device) - return pynccl - - -# I should use worker_extension_cls arg and not inherit from worker, -# but that is only available on main and not 0.7.3 -class WorkerExtension(Worker): - """ - The class for vLLM's worker to inherit from. - By defining an extension class, the code can work no matter what is - the underlying worker class. This way, the code can be compatible - with both vLLM V0 and V1. - NOTE: we define this class in a separate module, and the main module - should pass the full qualified name as `worker_extension_cls` argument. - """ - - def init_weight_update_group(self, master_address, master_port, - rank_offset, world_size): - from vllm.distributed.parallel_state import get_world_group - rank = get_world_group().rank + rank_offset - self.model_update_group = stateless_init_process_group( - master_address, - master_port, - rank, - world_size, - self.device, - ) - - def update_weight(self, name, dtype, shape): - weight = torch.empty(shape, dtype=dtype, device="cuda") - self.model_update_group.broadcast(weight, - src=0, - stream=torch.cuda.current_stream()) - - self.model_runner.model.load_weights(weights=[(name, weight)]) - - del weight - - def check_weights_changed(self): - """ - Check if the weights are updated to 0. - """ - weights_updated = True - for name, p in self.model_runner.model.named_parameters(): - weights_updated = weights_updated and torch.allclose( - p, torch.zeros_like(p)) - return weights_updated - def make_policy(): inference_model = LLM( @@ -140,15 +72,13 @@ def collate_fn(batch): @ray.remote(num_cpus=1, num_gpus=1) class TrainerActor: - def __init__(self, env_vars): + def __init__(self, model, env_vars): import os import torch import torch.distributed from torch.distributed._composable.fsdp import fully_shard torch.cuda.set_device(torch.device('cuda', 0)) - - print(os.environ["CUDA_VISIBLE_DEVICES"]) for var in env_vars: os.environ[var] = str(env_vars[var]) @@ -163,33 +93,30 @@ def __init__(self, env_vars): self.rank = int(os.environ["RANK"]) self.world_size = int(os.environ["WORLD_SIZE"]) - # self.param_server_comm_group = None - # if self.rank == 0: - # self.param_server_comm_group = torch.distributed.new_group(ranks=[0, self.world_size - 1], use_local_synchronization=True) # hold back one rank for the parameter server self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1))) - self.comm_group = torch.distributed.new_group(ranks=[0, 2]) self.device_mesh = torch.distributed.device_mesh.DeviceMesh.from_group(self.fsdp_group, device_type="cuda") - self.model = AutoModel.from_pretrained("facebook/opt-125m").cuda() + self.model = AutoModel.from_pretrained(model).cuda() fully_shard(self.model, mesh=self.device_mesh) def register_parameter_server(self, param_server): - # assert self.rank == 0 + assert self.rank == 0 self.param_server = param_server def send_weights_to_param_server(self): - # assert(hasattr(self, "param_server")) + if self.rank == 0: + ray.get(self.param_server.acquire_state_dict_lock.remote()) + self.param_server.receive_from_trainer.remote() for k, v in self.model.state_dict().items(): replicated_v = v.full_tensor() - # dst is global rank, can switch to group_dst arg if not 2.5.1 if self.rank == 0: - # print(f"sending {k}, {replicated_v.nbytes}") - handle = self.param_server.receive_from_train.remote(k) + # dst is global rank, can switch to group_dst arg if not 2.5.1 torch.distributed.send(replicated_v, dst=2) - # ray.get(handle) + if self.rank == 0: + ray.get(self.param_server.release_state_dict_lock.remote()) def zero_(self): sd = self.model.state_dict() @@ -203,18 +130,14 @@ def train(self): # ... self.zero_() torch.distributed.barrier(group=self.fsdp_group) - print("done barrier!") - # if self.rank == 0: - # print("starting send weights") self.send_weights_to_param_server() torch.distributed.barrier(group=self.fsdp_group) -from torchrl.collectors.weight_update import RemoteWeightUpdaterBase - @ray.remote(num_cpus=1, num_gpus=1) -class vLLMParameterServer(RemoteWeightUpdaterBase): - def __init__(self, env_vars): +class vLLMParameterServer(vLLMRemoteWeightUpdaterBase): + def __init__(self, model, vllm_master_address, vllm_master_port, env_vars): + super().__init__(model, vllm_master_address, vllm_master_port) import os import torch import torch.distributed @@ -226,100 +149,16 @@ def __init__(self, env_vars): if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl", device_id=torch.device('cuda:0')) - print("initialized process group") - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - print(world_size, rank) self.rank = int(os.environ["RANK"]) self.world_size = int(os.environ["WORLD_SIZE"]) assert self.rank == self.world_size - 1 self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1))) - self.comm_group = torch.distributed.new_group(ranks=[0, 2]) - - # self.param_server_trainer_comm_group = torch.distributed.new_group(ranks=[0, self.world_size - 1], use_local_synchronization=True) - - self.param_server_vllm_comm_groups = dict() - - # Having the state_dict fit on one GPU will not scale - self.state_dict = AutoModel.from_pretrained("facebook/opt-125m").cuda().eval().state_dict() - - self.lock = torch.multiprocessing.Lock() - self.version = 0 - - print(self.state_dict.keys()) - - def receive_from_train(self, k): - # with self.lock: - # src is global rank, an change to group_src once not 2.5.1 - # print(f"receiving {k}") - torch.distributed.recv(self.state_dict[k], src=0) - # self.version += 1 - # print(f"received {k} {self.state_dict[k].flatten()[0]}") - - def _init_model_update_group(self, worker_id): - print("in init model update group", worker_id) - master_address, master_port = get_ip(), get_open_port() - print(master_address, master_port) - # FIXME!!!! This needs to be grabbed from each remote collector - vllm_tp_size = 1 - weight_sync_world_size = vllm_tp_size + 1 - print("calling collective_rpc") - self.collector._remote_collectors[worker_id].call_policy_method.remote( - "collective_rpc", - ("init_weight_update_group",), - {'args': (master_address, master_port, 1, weight_sync_world_size)} - ) - print("done collective_rpc") - model_update_group = stateless_init_process_group( - master_address, - master_port, - 0, - weight_sync_world_size, - torch.device("cuda:0"), - ) - print("done stateless init process group") - self.param_server_vllm_comm_groups[worker_id] = model_update_group - - - def _sync_weights_with_worker( - self, worker_id: int, server_weights - ): - if worker_id not in self.param_server_vllm_comm_groups: - self._init_model_update_group(worker_id) - handles = [] - for i, (k, v) in enumerate(server_weights.items()): - handle = self.collector._remote_collectors[worker_id].call_policy_method.remote( - "collective_rpc", - ("update_weight",), - {'args': (k, v.dtype, v.shape)} - ) - handles.append(handle) - # self.collector._remote_collectors[worker_id].collective_rpc.remote("update_weight", args=(k, v.dtype, v.shape)) - self.param_server_vllm_comm_groups[worker_id].broadcast(server_weights[k], src=0, stream=torch.cuda.current_stream()) - handle = self.collector._remote_collectors[worker_id].call_policy_method.remote( - "collective_rpc", - ("check_weights_changed",), - {}, - ) - - print(f"weights changed {ray.get(handle)}") - # probably no need barrier because subsequent gpu work should be serialized - # self._batches_since_weight_update[worker_id] = 0 - def _get_server_weights(self): - print("in _get_server_weights") - with self.lock: - return self.state_dict - - def _maybe_map_weights(self, server_weights): - # This is making a design choice that weight mapping always happens on the parameter servver - # I don't think we should make this design choice so early. - return server_weights - - def all_worker_ids(self): - return [0] + def receive_from_trainer(self): + for k, v in self.state_dict.items(): + torch.distributed.recv(v, src=0) def _skip_update(self, worker_id: int) -> bool: pass @@ -335,7 +174,15 @@ def check_weights_changed(self): return weights_updated -def _create_trainer_group(worker_cls, param_server_cls, world_size: int): + +def _create_trainer_group( + worker_cls, + param_server_cls, + world_size: int, + vllm_master_address, + vllm_master_port, + model, +): addr, port = get_ip(), get_open_port() trainer_workers = [] fsdp_world_size = world_size - 1 @@ -346,7 +193,7 @@ def _create_trainer_group(worker_cls, param_server_cls, world_size: int): "MASTER_ADDR": str(addr), "MASTER_PORT": str(port), } - worker = worker_cls.remote(env_vars) + worker = worker_cls.remote(model, env_vars) trainer_workers.append(worker) env_vars = { @@ -355,9 +202,8 @@ def _create_trainer_group(worker_cls, param_server_cls, world_size: int): "MASTER_ADDR": str(addr), "MASTER_PORT": str(port), } - parameter_server = param_server_cls.remote(env_vars) + parameter_server = param_server_cls.remote(model, vllm_master_address, vllm_master_port, env_vars) trainer_workers[0].register_parameter_server.remote(parameter_server) - trainer_workers[1].register_parameter_server.remote(parameter_server) return trainer_workers, parameter_server @@ -370,16 +216,27 @@ def _create_trainer_group(worker_cls, param_server_cls, world_size: int): "memory": 2 * 1024**3, } + model = "facebook/opt-125m" + ray.init(num_cpus=4, num_gpus=4) - trainer_workers, parameter_server = _create_trainer_group(TrainerActor, vLLMParameterServer, 3) + vllm_master_address, vllm_update_port = get_ip(), get_open_port() + + trainer_workers, parameter_server = _create_trainer_group( + TrainerActor, + vLLMParameterServer, + 3, + vllm_master_address, + vllm_update_port, + model, + ) handles = [] for trainer_worker in trainer_workers: handles.append(trainer_worker.train.remote()) - - print(f"param server weights updated {ray.get(parameter_server.check_weights_changed.remote())}") + model_metadata = ray.get(parameter_server.get_model_metadata.remote()) + local_weight_updater = vLLMHFLocalWeightUpdater(vllm_master_address, vllm_update_port, model_metadata) make_env_parsed = partial(make_env, batch_size=args.batch_size, dataset=args.dataset) collector = RayCollector( @@ -388,7 +245,10 @@ def _create_trainer_group(worker_cls, param_server_cls, world_size: int): frames_per_batch=40, total_frames=200, remote_configs=remote_configs, - remote_weights_updater=parameter_server, + remote_weight_updater=parameter_server, + collector_kwargs={ + "local_weight_updater": local_weight_updater, + }, update_after_each_batch=True, ) print("done collector init") diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 63fb78f96bb..e1e8da39337 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -76,6 +76,15 @@ def cudagraph_mark_step_begin(): """Placeholder for missing cudagraph_mark_step_begin method.""" raise NotImplementedError("cudagraph_mark_step_begin not implemented.") +try: + import ray + from ray.actor import ActorHandle + + _has_ray = True +except ImportError as err: + _has_ray = False + RAY_ERR = err + _TIMEOUT = 1.0 INSTANTIATE_TIMEOUT = 20 @@ -174,9 +183,12 @@ def remote_weight_updater(self) -> RemoteWeightUpdaterBase: @remote_weight_updater.setter def remote_weight_updater(self, value: RemoteWeightUpdaterBase | None): if value is not None: - value.register_collector(self) - if value.collector is not self: - raise RuntimeError("Failed to register collector.") + if _has_ray and isinstance(value, ray.actor.ActorHandle): + value.register_collector.remote(self) + else: + value.register_collector(self) + if value.collector is not self: + raise RuntimeError("Failed to register collector.") self._remote_weight_updater = value def _get_policy_and_device( @@ -306,7 +318,7 @@ def update_policy_weights_( """ if self.local_weight_updater is not None: - self.local_weights_updater(policy_weights, **kwargs) + self.local_weight_updater(policy_weights, **kwargs) if self.remote_weight_updater is not None: import ray ray.get(self.remote_weight_updater.__call__.remote(policy_weights, worker_ids=worker_ids, **kwargs)) diff --git a/torchrl/collectors/vllm_weight_update.py b/torchrl/collectors/vllm_weight_update.py new file mode 100644 index 00000000000..a6ff63acbf4 --- /dev/null +++ b/torchrl/collectors/vllm_weight_update.py @@ -0,0 +1,179 @@ +import torch +import threading + +from torchrl.collectors.weight_update import RemoteWeightUpdaterBase +from torchrl.collectors.weight_update import LocalWeightUpdaterBase + + +VLLM_ERR = None +try: + import vllm + from vllm.worker.worker import Worker + + _has_vllm = True +except ImportError as err: + _has_vllm = False + VLLM_ERR = err + +# These utilities are copied from vLLM's example code. +def stateless_init_process_group( + master_address: str, + master_port: int, + rank: int, + world_size: int, + device: torch.device, +): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. + """ + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + pg = StatelessProcessGroup.create( + host=master_address, port=master_port, rank=rank, world_size=world_size + ) + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl + + +if _has_vllm: + # I should use worker_extension_cls arg and not inherit from worker, + # but that is only available on main and not vLLM 0.7.3 + class WorkerExtension(Worker): + """ + The class for vLLM's worker to inherit from. + By defining an extension class, the code can work no matter what is + the underlying worker class. This way, the code can be compatible + with both vLLM V0 and V1. + NOTE: we define this class in a separate module, and the main module + should pass the full qualified name as `worker_extension_cls` argument. + """ + + def init_weight_update_group(self, master_address, master_port, + rank_offset, world_size): + from vllm.distributed.parallel_state import get_world_group + rank = get_world_group().rank + rank_offset + self.model_update_group = stateless_init_process_group( + master_address, + master_port, + rank, + world_size, + self.device, + ) + + def update_weight(self, name, dtype, shape): + weight = torch.empty(shape, dtype=dtype, device="cuda") + self.model_update_group.broadcast(weight, + src=0, + stream=torch.cuda.current_stream()) + + self.model_runner.model.load_weights(weights=[(name, weight)]) + + del weight + + def check_weights_changed(self): + """ + Check if the weights are updated to 0. + """ + weights_updated = True + for name, p in self.model_runner.model.named_parameters(): + weights_updated = weights_updated and torch.allclose( + p, torch.zeros_like(p)) + return weights_updated +else: + class WorkerExtension: + pass + + +class vLLMHFLocalWeightUpdater(LocalWeightUpdaterBase): + def __init__(self, master_address, master_port, model_metadata): + self.master_address = master_address + self.master_port = master_port + self.model_metadata = model_metadata + self.model_update_group = None + + def _get_server_weights(self): + return None + + def _get_local_weights(self): + # We don't implement this because we let vLLM's update_weights API handle everything + return None + + def _maybe_map_weights(self, server_weights, local_weights): + # vLLM update_weights function handles the mapping from huggingface + # so we don't implement this + return None + + def _update_local_weights(self, local_weights, mapped_weights): + llm = self.collector.policy["generate"].module + if self.model_update_group is None: + # FIXME: hardcoded + weight_sync_world_size = 2 + llm.collective_rpc( + "init_weight_update_group", + args=(self.master_address, self.master_port, 1, weight_sync_world_size) + ) + + for k, (dtype, shape) in self.model_metadata.items(): + llm.collective_rpc( + "update_weight", + args=(k, dtype, shape) + ) + +class vLLMRemoteWeightUpdaterBase(RemoteWeightUpdaterBase): + def __init__(self, model, vllm_master_address, vllm_master_port): + super().__init__() + from transformers import AutoModel + self.vllm_master_address = vllm_master_address + self.vllm_master_port = vllm_master_port + self.state_dict = AutoModel.from_pretrained(model).cuda().eval().state_dict() + self.state_dict_lock = threading.Lock() + self.vllm_comm_groups = dict() + # versioning nyi + self.version = 0 + + def acquire_state_dict_lock(self): + self.state_dict_lock.acquire() + + def release_state_dict_lock(self): + self.state_dict_lock.release() + + def get_model_metadata(self): + return {k: (v.dtype, v.shape) for k, v in self.state_dict.items()} + + def all_worker_ids(self): + return [0] + + def _get_server_weights(self): + return self.state_dict + + def _maybe_map_weights(self, server_weights): + return server_weights + + def _init_model_update_group(self, worker_id): + # here again, I want to grab the tp size from the vLLM worker... :( + # llm.llm_engine.parallel_config.tensor_parallel_size + vllm_tp_size = 1 + weight_sync_world_size = vllm_tp_size + 1 + model_update_group = stateless_init_process_group( + self.vllm_master_address, + self.vllm_master_port, + 0, + weight_sync_world_size, + torch.device("cuda:0"), + ) + self.vllm_comm_groups[worker_id] = model_update_group + + def _sync_weights_with_worker( + self, worker_id: int, server_weights + ): + self.collector._remote_collectors[worker_id].update_policy_weights_.remote() + if worker_id not in self.vllm_comm_groups: + self._init_model_update_group(worker_id) + with self.state_dict_lock: + for i, k in enumerate(server_weights.keys()): + self.vllm_comm_groups[worker_id].broadcast(server_weights[k], src=0, stream=torch.cuda.current_stream()) \ No newline at end of file diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py index b6232f20a5b..3ea406e1c79 100644 --- a/torchrl/collectors/weight_update.py +++ b/torchrl/collectors/weight_update.py @@ -150,11 +150,7 @@ def register_collector(self, collector): # noqa self._collector_wr = weakref.ref(collector) @property -<<<<<<< HEAD def collector(self) -> torch.collector.DataCollectorBase: # noqa -======= - def collector(self): ->>>>>>> 3917c0752 (v0 param server (using collectives not object store)) return self._collector_wr() if self._collector_wr is not None else None @abstractmethod From 5c6b015a1f1c283a28caa09166f922a166b02745 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Fri, 21 Mar 2025 22:53:01 -0700 Subject: [PATCH 3/5] Update on "v0 param server (using collectives not object store)" [ghstack-poisoned] --- torchrl/collectors/collectors.py | 13 ++++--------- torchrl/collectors/weight_update.py | 10 +++------- torchrl/modules/llm/vllm_policy.py | 7 ------- 3 files changed, 7 insertions(+), 23 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index e1e8da39337..9b3cff35e1f 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -320,8 +320,10 @@ def update_policy_weights_( if self.local_weight_updater is not None: self.local_weight_updater(policy_weights, **kwargs) if self.remote_weight_updater is not None: - import ray - ray.get(self.remote_weight_updater.__call__.remote(policy_weights, worker_ids=worker_ids, **kwargs)) + if _has_ray and isinstance(self.remote_weight_updater, ray.actor.ActorHandle): + ray.get(self.remote_weight_updater.__call__.remote(policy_weights, worker_ids=worker_ids, **kwargs)) + else: + self.remote_weight_updater(policy_weights, worker_ids=worker_ids, **kwargs) elif worker_ids is not None: raise TypeError("worker_ids was passed but remote_weight_updater was None.") @@ -873,13 +875,6 @@ def __init__( self.local_weight_updater = local_weight_updater self.remote_weight_updater = remote_weight_updater - - def call_policy_method(self, method: str, args, kwargs): - # I want world where I don't have to do this to call a method on the - # vllm policy that is owned by the remote collector - - result = getattr(self.policy['generate'].module, method)(*args, **kwargs) - return result @property def _traj_pool(self): diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py index 3ea406e1c79..46732057ad6 100644 --- a/torchrl/collectors/weight_update.py +++ b/torchrl/collectors/weight_update.py @@ -13,8 +13,6 @@ from tensordict import TensorDictBase from tensordict.nn import TensorDictModuleBase -# from torchrl.collectors import DataCollectorBase - Policy = TypeVar("Policy", bound=TensorDictModuleBase) @@ -49,7 +47,7 @@ class LocalWeightUpdaterBase(metaclass=abc.ABCMeta): _collector_wr: Any = None - def register_collector(self, collector): # noqa + def register_collector(self, collector: DataCollectorBase): # noqa """Register a collector in the updater. Once registered, the updater will not accept another collector. @@ -63,7 +61,7 @@ def register_collector(self, collector): # noqa self._collector_wr = weakref.ref(collector) @property - def collector(self): # noqa + def collector(self) -> torchrl.collectors.DataCollectorBase: # noqa return self._collector_wr() if self._collector_wr is not None else None @abstractmethod @@ -136,7 +134,7 @@ class RemoteWeightUpdaterBase(metaclass=abc.ABCMeta): _collector_wr: Any = None - def register_collector(self, collector): # noqa + def register_collector(self, collector: DataCollectorBase): # noqa """Register a collector in the updater. Once registered, the updater will not accept another collector. @@ -186,14 +184,12 @@ def update_weights( weights: TensorDictBase | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None, ): - print(f"in update_weights {worker_ids}") if weights is None: # Get the weights on server (local) server_weights = self._get_server_weights() else: server_weights = weights - self._maybe_map_weights(server_weights) # Get the remote weights (inference workers) diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index 688ec7855c6..a810fe98c1e 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -415,13 +415,6 @@ def move_input(td): generate_kwargs.setdefault("logprobs", return_log_probs) sampling_params = SamplingParams(**generate_kwargs) - # def print_weights(td): - # for i, (name, param) in enumerate(model.llm_engine.model_executor.driver_worker.worker.model_runner.model.named_parameters()): - # if i == 0: - # print(f"Model parameters: {name} {param[0]}") - # return td - - # module_dict["print_weights"] = print_weights module_dict["generate"] = Mod( model, method="generate", From 7bd553d6e9909c34cb5866ba421aeb8c9e621e15 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Fri, 21 Mar 2025 23:05:12 -0700 Subject: [PATCH 4/5] Update on "v0 param server (using collectives not object store)" [ghstack-poisoned] --- torchrl/collectors/collectors.py | 2 +- torchrl/collectors/distributed/ray.py | 2 -- torchrl/collectors/vllm_weight_update.py | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 9b3cff35e1f..f0a3d198a43 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -875,7 +875,7 @@ def __init__( self.local_weight_updater = local_weight_updater self.remote_weight_updater = remote_weight_updater - + @property def _traj_pool(self): pool = getattr(self, "_traj_pool_val", None) diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 25958e31e7f..8dcef99ac5e 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -646,7 +646,6 @@ def stop_remote_collectors(self): ) # This will interrupt any running tasks on the actor, causing them to fail immediately def iterator(self): - print(f"{self._sync=}") def proc(data): if self.split_trajs: data = split_trajectories(data) @@ -762,7 +761,6 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: if self.update_after_each_batch or self.max_weight_update_interval > -1: self.update_policy_weights_(worker_ids=collector_index) - print("done updating policy weights") # Schedule a new collection task future = collector.next.remote() pending_tasks[future] = collector_index diff --git a/torchrl/collectors/vllm_weight_update.py b/torchrl/collectors/vllm_weight_update.py index a6ff63acbf4..a1305f65ccf 100644 --- a/torchrl/collectors/vllm_weight_update.py +++ b/torchrl/collectors/vllm_weight_update.py @@ -112,7 +112,7 @@ def _update_local_weights(self, local_weights, mapped_weights): llm = self.collector.policy["generate"].module if self.model_update_group is None: # FIXME: hardcoded - weight_sync_world_size = 2 + weight_sync_world_size = llm.llm_engine.parallel_config.tensor_parallel_size + 1 llm.collective_rpc( "init_weight_update_group", args=(self.master_address, self.master_port, 1, weight_sync_world_size) @@ -146,7 +146,7 @@ def get_model_metadata(self): return {k: (v.dtype, v.shape) for k, v in self.state_dict.items()} def all_worker_ids(self): - return [0] + return [i for i in range(len(self.collector._remote_collectors))] def _get_server_weights(self): return self.state_dict From 0a265a9e3ed09473bca63b7715d8cc46fc4077e9 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Fri, 21 Mar 2025 23:48:38 -0700 Subject: [PATCH 5/5] Update on "v0 param server (using collectives not object store)" [ghstack-poisoned] --- param_server_weight_updater.py | 31 ++++++++++++++++-------- torchrl/collectors/vllm_weight_update.py | 20 ++++++++------- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/param_server_weight_updater.py b/param_server_weight_updater.py index a627a88f4cd..655bbdb8d8b 100644 --- a/param_server_weight_updater.py +++ b/param_server_weight_updater.py @@ -218,16 +218,18 @@ def _create_trainer_group( model = "facebook/opt-125m" - ray.init(num_cpus=4, num_gpus=4) + ray.init(num_cpus=5, num_gpus=5) - vllm_master_address, vllm_update_port = get_ip(), get_open_port() + vllm_addresses = [get_ip()] * 2 + vllm_ports = [get_open_port() for i in range(2)] + print(vllm_ports) trainer_workers, parameter_server = _create_trainer_group( TrainerActor, vLLMParameterServer, 3, - vllm_master_address, - vllm_update_port, + vllm_addresses, + vllm_ports, model, ) @@ -236,19 +238,28 @@ def _create_trainer_group( handles.append(trainer_worker.train.remote()) model_metadata = ray.get(parameter_server.get_model_metadata.remote()) - local_weight_updater = vLLMHFLocalWeightUpdater(vllm_master_address, vllm_update_port, model_metadata) + local_weight_updaters = [ + vLLMHFLocalWeightUpdater(vllm_master_address, vllm_update_port, model_metadata) for + vllm_master_address, vllm_update_port in zip(vllm_addresses, vllm_ports) + ] make_env_parsed = partial(make_env, batch_size=args.batch_size, dataset=args.dataset) collector = RayCollector( - [make_env_parsed], + [make_env_parsed, make_env_parsed], policy_factory=make_policy, frames_per_batch=40, total_frames=200, remote_configs=remote_configs, remote_weight_updater=parameter_server, - collector_kwargs={ - "local_weight_updater": local_weight_updater, - }, + num_collectors=2, + collector_kwargs=[ + { + "local_weight_updater": local_weight_updaters[0], + }, + { + "local_weight_updater": local_weight_updaters[1], + } + ], update_after_each_batch=True, ) print("done collector init") @@ -258,6 +269,6 @@ def _create_trainer_group( for i, data in enumerate(collector): print(tokenizer.decode(data["tokens"][0].squeeze())) print(tokenizer.decode(data["tokens_response"][0].squeeze())) - if i == 1: + if i == 3: break collector.shutdown() \ No newline at end of file diff --git a/torchrl/collectors/vllm_weight_update.py b/torchrl/collectors/vllm_weight_update.py index a1305f65ccf..b80f208a840 100644 --- a/torchrl/collectors/vllm_weight_update.py +++ b/torchrl/collectors/vllm_weight_update.py @@ -56,7 +56,8 @@ class WorkerExtension(Worker): def init_weight_update_group(self, master_address, master_port, rank_offset, world_size): from vllm.distributed.parallel_state import get_world_group - rank = get_world_group().rank + rank_offset + # rank = get_world_group().rank + rank_offset + rank = rank_offset self.model_update_group = stateless_init_process_group( master_address, master_port, @@ -91,10 +92,11 @@ class WorkerExtension: class vLLMHFLocalWeightUpdater(LocalWeightUpdaterBase): def __init__(self, master_address, master_port, model_metadata): + print(f"{master_address=}, {master_port=}") self.master_address = master_address self.master_port = master_port self.model_metadata = model_metadata - self.model_update_group = None + self.initialized_group = None def _get_server_weights(self): return None @@ -110,13 +112,13 @@ def _maybe_map_weights(self, server_weights, local_weights): def _update_local_weights(self, local_weights, mapped_weights): llm = self.collector.policy["generate"].module - if self.model_update_group is None: - # FIXME: hardcoded + if self.initialized_group is None: weight_sync_world_size = llm.llm_engine.parallel_config.tensor_parallel_size + 1 llm.collective_rpc( "init_weight_update_group", args=(self.master_address, self.master_port, 1, weight_sync_world_size) ) + self.initialized_group = True for k, (dtype, shape) in self.model_metadata.items(): llm.collective_rpc( @@ -125,11 +127,11 @@ def _update_local_weights(self, local_weights, mapped_weights): ) class vLLMRemoteWeightUpdaterBase(RemoteWeightUpdaterBase): - def __init__(self, model, vllm_master_address, vllm_master_port): + def __init__(self, model, vllm_master_addresses, vllm_master_ports): super().__init__() from transformers import AutoModel - self.vllm_master_address = vllm_master_address - self.vllm_master_port = vllm_master_port + self.vllm_master_addresses = vllm_master_addresses + self.vllm_master_ports = vllm_master_ports self.state_dict = AutoModel.from_pretrained(model).cuda().eval().state_dict() self.state_dict_lock = threading.Lock() self.vllm_comm_groups = dict() @@ -160,8 +162,8 @@ def _init_model_update_group(self, worker_id): vllm_tp_size = 1 weight_sync_world_size = vllm_tp_size + 1 model_update_group = stateless_init_process_group( - self.vllm_master_address, - self.vllm_master_port, + self.vllm_master_addresses[worker_id], + self.vllm_master_ports[worker_id], 0, weight_sync_world_size, torch.device("cuda:0"),