
Commit 50af984

Author: Vincent Moens
[Feature] RayReplayBuffer
ghstack-source-id: 32eff06 Pull Request resolved: #2835
1 parent 9cd95d5 commit 50af984

11 files changed (+576, -92 lines)


docs/source/reference/data.rst

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ widely used replay buffers:
     PrioritizedReplayBuffer
     TensorDictReplayBuffer
     TensorDictPrioritizedReplayBuffer
+    RayReplayBuffer
+    RemoteTensorDictReplayBuffer

 Composable Replay Buffers
 -------------------------
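
For orientation, a minimal usage sketch of the new class, pieced together from the test_ray_replaybuffer test added in this commit (the storage factory, batch_size, extend and sample calls are the ones exercised there; a reachable Ray runtime is assumed):

# Minimal sketch based on the test added below; assumes Ray is available and
# that constructing the buffer spawns (or attaches to) a remote Ray actor.
from functools import partial

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, RayReplayBuffer

rb = RayReplayBuffer(storage=partial(LazyTensorStorage, 1000), batch_size=32)
data = TensorDict(a=torch.arange(100, 200), batch_size=[100])
index = rb.extend(data)   # writes are forwarded to the remote buffer actor
sample = rb.sample()      # draws a batch of 32 transitions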
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+"""
+Example use of an ever-running, fully async, distributed collector
+==================================================================
+
+This example demonstrates how to set up and use a distributed collector
+with Ray in a fully asynchronous manner. The collector continuously gathers
+data from a gym environment and stores it in a replay buffer, allowing for
+concurrent processing and data collection.
+
+Key Components:
+1. **Environment Factory**: A simple function that creates instances of the
+   `GymEnv` environment. In this example, we use the "Pendulum-v1" environment.
+2. **Policy Definition**: A `TensorDictModule` that defines the policy network.
+   Here, a simple linear layer is used to map observations to actions.
+3. **Replay Buffer**: A `RayReplayBuffer` that stores collected data for later
+   use, such as training a reinforcement learning model.
+4. **Distributed Collector**: A `RayCollector` that manages the distributed
+   collection of data. It is configured with remote resources and interacts
+   with the environment and policy to gather data.
+5. **Asynchronous Execution**: The collector runs in the background, allowing
+   the main program to perform other tasks concurrently. The example includes
+   a loop that waits for data to be available in the buffer and samples it.
+6. **Graceful Shutdown**: The collector is shut down asynchronously, ensuring
+   that all resources are properly released.
+
+This setup is useful for scenarios where you need to collect data from
+multiple environments in parallel, leveraging Ray's distributed computing
+capabilities to scale efficiently.
+
+"""
+import asyncio
+
+from tensordict.nn import TensorDictModule
+from torch import nn
+from torchrl.collectors.distributed.ray import RayCollector
+from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
+from torchrl.envs.libs.gym import GymEnv
+
+
+async def main():
+    # 1. Create environment factory
+    def env_maker():
+        return GymEnv("Pendulum-v1", device="cpu")
+
+    policy = TensorDictModule(
+        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
+    )
+
+    buffer = RayReplayBuffer()
+
+    # 2. Define distributed collector
+    remote_config = {
+        "num_cpus": 1,
+        "num_gpus": 0,
+        "memory": 5 * 1024**3,
+        "object_store_memory": 2 * 1024**3,
+    }
+    distributed_collector = RayCollector(
+        [env_maker],
+        policy,
+        total_frames=600,
+        frames_per_batch=200,
+        remote_configs=remote_config,
+        replay_buffer=buffer,
+    )
+
+    print("start")
+    distributed_collector.start()
+
+    while True:
+        while not len(buffer):
+            print("waiting")
+            await asyncio.sleep(1)  # Use asyncio.sleep instead of time.sleep
+        print("sample", buffer.sample(32))
+        # break at some point
+        break
+
+    await distributed_collector.async_shutdown()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())

examples/distributed/collectors/multi_nodes/ray_collect.py

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,7 @@ def env_maker():
 # 2. Define distributed collector
 remote_config = {
     "num_cpus": 1,
-    "num_gpus": 0.2,
+    "num_gpus": 0,
     "memory": 5 * 1024**3,
     "object_store_memory": 2 * 1024**3,
 }
@@ -36,6 +36,7 @@ def env_maker():
     policy,
     total_frames=10000,
     frames_per_batch=200,
+    remote_configs=remote_config,
 )

 # Sample batches until reaching total_frames

test/test_distributed.py

Lines changed: 41 additions & 0 deletions
@@ -11,10 +11,19 @@
 import os
 import sys
 import time
+from functools import partial

 import pytest
+from tensordict import TensorDict
 from tensordict.nn import TensorDictModuleBase
 from torchrl._utils import logger as torchrl_logger
+from torchrl.data import (
+    LazyTensorStorage,
+    RandomSampler,
+    RayReplayBuffer,
+    RoundRobinWriter,
+    SamplerWithoutReplacement,
+)

 try:
     import ray
@@ -435,6 +444,15 @@ class TestRayCollector(DistributedCollectorBase):
     to avoid potential deadlocks when combining Ray and multiprocessing.
     """

+    @pytest.fixture(autouse=True, scope="class")
+    def start_ray(self):
+        from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG
+
+        ray.init(**DEFAULT_RAY_INIT_CONFIG)
+
+        yield
+        ray.shutdown()
+
     @classmethod
     def distributed_class(cls) -> type:
         return RayCollector
@@ -552,6 +570,29 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync):
         collector.shutdown()
         assert total == total_frames

+    @pytest.mark.parametrize("storage", [None, partial(LazyTensorStorage, 1000)])
+    @pytest.mark.parametrize(
+        "sampler", [None, partial(RandomSampler), SamplerWithoutReplacement]
+    )
+    @pytest.mark.parametrize("writer", [None, partial(RoundRobinWriter)])
+    def test_ray_replaybuffer(self, storage, sampler, writer):
+        kwargs = self.distributed_kwargs()
+        kwargs["remote_config"] = kwargs.pop("remote_configs")
+        rb = RayReplayBuffer(
+            storage=storage,
+            sampler=sampler,
+            writer=writer,
+            batch_size=32,
+            **kwargs,
+        )
+        td = TensorDict(a=torch.arange(100, 200), batch_size=[100])
+        index = rb.extend(td)
+        assert (index == torch.arange(100)).all()
+        for _ in range(10):
+            sample = rb.sample()
+            if sampler is SamplerWithoutReplacement:
+                assert sample["a"].unique().numel() == sample.numel()
+

 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/collectors/collectors.py

Lines changed: 5 additions & 4 deletions
@@ -267,7 +267,8 @@ def next(self):
             self._iterator = iter(self)
             out = next(self._iterator)
             # if any, we don't want the device ref to be passed in distributed settings
-            out.clear_device_()
+            if out is not None:
+                out.clear_device_()
             return out
         except StopIteration:
             return None
@@ -432,7 +433,7 @@ class SyncDataCollector(DataCollectorBase):
         use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
             This isn't compatible with environments with dynamic specs. Defaults to ``True``
             for envs without dynamic specs, ``False`` for others.
-        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordict
+        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
            but populate the buffer instead. Defaults to ``None``.
         trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
             assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
@@ -1430,7 +1431,7 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
     def __repr__(self) -> str:
         env_str = indent(f"env={self.env}", 4 * " ")
         policy_str = indent(f"policy={self.policy}", 4 * " ")
-        td_out_str = indent(f"td_out={self._final_rollout}", 4 * " ")
+        td_out_str = indent(f"td_out={getattr(self, '_final_rollout', None)}", 4 * " ")
         string = (
             f"{self.__class__.__name__}("
             f"\n{env_str},"
@@ -1586,7 +1587,7 @@ class _MultiDataCollector(DataCollectorBase):
         use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
             This isn't compatible with environments with dynamic specs. Defaults to ``True``
             for envs without dynamic specs, ``False`` for others.
-        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordict
+        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
            but populate the buffer instead. Defaults to ``None``.
         trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
             assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
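
The replay_buffer docstring above describes the collector-side contract. A minimal sketch of that contract, assuming the iterator yields None placeholders once a buffer is attached (which is what the new guard in next() accounts for); the environment and sizes are illustrative:

# Sketch of the documented behavior: with a replay_buffer, the collector
# fills the buffer and iteration is not expected to return data batches.
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv

rb = ReplayBuffer(storage=LazyTensorStorage(1000))
collector = SyncDataCollector(
    lambda: GymEnv("Pendulum-v1"),
    policy=None,                 # None falls back to a random policy
    frames_per_batch=200,
    total_frames=600,
    replay_buffer=rb,            # data is written here instead of being yielded
)
for batch in collector:
    # with a replay_buffer attached, `batch` is expected to be None;
    # the collected frames land directly in `rb`
    print(len(rb))
collector.shutdown()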

torchrl/collectors/distributed/ray.py

Lines changed: 71 additions & 13 deletions
@@ -5,6 +5,7 @@

 from __future__ import annotations

+import asyncio
 import warnings
 from typing import Callable, Iterator, OrderedDict

@@ -21,6 +22,7 @@
     SyncDataCollector,
 )
 from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
+from torchrl.data import ReplayBuffer
 from torchrl.envs.common import EnvBase
 from torchrl.envs.env_creator import EnvCreator

@@ -256,6 +258,11 @@ class RayCollector(DataCollectorBase):
             parameters being updated for a certain time even if ``update_after_each_batch``
             is turned on.
             Defaults to -1 (no forced update).
+        replay_buffer (RayReplayBuffer, optional): if provided, the collector will not yield tensordicts
+            but populate the buffer instead. Defaults to ``None``.
+
+            .. note:: although it is not enfoced (to allow users to implement their own replay buffer class), a
+              :class:`~torchrl.data.RayReplayBuffer` instance should be used here.

     Examples:
         >>> from torch import nn
@@ -312,7 +319,9 @@ def __init__(
         num_collectors: int = None,
         update_after_each_batch=False,
         max_weight_update_interval=-1,
+        replay_buffer: ReplayBuffer = None,
     ):
+        self.frames_per_batch = frames_per_batch
         if remote_configs is None:
             remote_configs = DEFAULT_REMOTE_CLASS_CONFIG

@@ -321,6 +330,14 @@ def __init__(

         if collector_kwargs is None:
             collector_kwargs = {}
+        if replay_buffer is not None:
+            if isinstance(collector_kwargs, dict):
+                collector_kwargs.setdefault("replay_buffer", replay_buffer)
+            else:
+                collector_kwargs = [
+                    ck.setdefault("replay_buffer", replay_buffer)
+                    for ck in collector_kwargs
+                ]

         # Make sure input parameters are consistent
         def check_consistency_with_num_collectors(param, param_name, num_collectors):
@@ -386,7 +403,8 @@ def check_list_length_consistency(*lists):
             raise RuntimeError(
                 "ray library not found, unable to create a DistributedCollector. "
             ) from RAY_ERR
-        ray.init(**ray_init_config)
+        if not ray.is_initialized():
+            ray.init(**ray_init_config)
         if not ray.is_initialized():
             raise RuntimeError("Ray could not be initialized.")

@@ -400,6 +418,7 @@ def check_list_length_consistency(*lists):
         collector_class.as_remote = as_remote
         collector_class.print_remote_collector_info = print_remote_collector_info

+        self.replay_buffer = replay_buffer
         self._local_policy = policy
         if isinstance(self._local_policy, nn.Module):
             policy_weights = TensorDict.from_module(self._local_policy)
@@ -557,7 +576,7 @@ def add_collectors(
                 policy,
                 other_params,
             )
-            self._remote_collectors.extend([collector])
+            self._remote_collectors.append(collector)

     def local_policy(self):
         """Returns local collector."""
@@ -577,17 +596,33 @@ def stop_remote_collectors(self):
         )  # This will interrupt any running tasks on the actor, causing them to fail immediately

     def iterator(self):
+        def proc(data):
+            if self.split_trajs:
+                data = split_trajectories(data)
+            if self.postproc is not None:
+                data = self.postproc(data)
+            return data
+
         if self._sync:
-            data = self._sync_iterator()
+            meth = self._sync_iterator
         else:
-            data = self._async_iterator()
+            meth = self._async_iterator
+        yield from (proc(data) for data in meth())

-        if self.split_trajs:
-            data = split_trajectories(data)
-        if self.postproc is not None:
-            data = self.postproc(data)
+    async def _asyncio_iterator(self):
+        def proc(data):
+            if self.split_trajs:
+                data = split_trajectories(data)
+            if self.postproc is not None:
+                data = self.postproc(data)
+            return data

-        return data
+        if self._sync:
+            for d in self._sync_iterator():
+                yield proc(d)
+        else:
+            for d in self._async_iterator():
+                yield proc(d)

     def _sync_iterator(self) -> Iterator[TensorDictBase]:
         """Collects one data batch per remote collector in each iteration."""
@@ -634,7 +669,30 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]:
         ):
             self.update_policy_weights_(rank)

-        self.shutdown()
+        if self._task is None:
+            self.shutdown()
+
+    _task = None
+
+    def start(self):
+        """Starts the RayCollector."""
+        if self.replay_buffer is None:
+            raise RuntimeError("Replay buffer must be defined for asyncio execution.")
+        if self._task is None or self._task.done():
+            loop = asyncio.get_event_loop()
+            self._task = loop.create_task(self._run_iterator_silently())
+
+    async def _run_iterator_silently(self):
+        async for _ in self._asyncio_iterator():
+            # Process each item silently
+            continue
+
+    async def async_shutdown(self):
+        """Finishes processes started by ray.init() during async execution."""
+        if self._task is not None:
+            await self._task
+        self.stop_remote_collectors()
+        ray.shutdown()

     def _async_iterator(self) -> Iterator[TensorDictBase]:
         """Collects a data batch from a single remote collector in each iteration."""
@@ -658,7 +716,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
             ray.internal.free(
                 [future]
             )  # should not be necessary, deleted automatically when ref count is down to 0
-            self.collected_frames += out_td.numel()
+            self.collected_frames += self.frames_per_batch

             yield out_td

@@ -689,8 +747,8 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
             # object_ref=ref,
             # force=False,
             # )
-
-        self.shutdown()
+        if self._task is None:
+            self.shutdown()

     def update_policy_weights_(self, worker_rank=None) -> None:
         """Updates the weights of the worker nodes.

0 commit comments
