
Commit b4b27fe

Authored by adityagoel4512 (Adi Goel) and vmoens
[Example] Distributed Replay Buffer Prototype Example Implementation (#615)
* Distributed replay buffer prototype
* Fixes comment issue
* Makes ReplayBufferNode subclass TensorDictReplayBuffer
* aha
* amend
* bf
* Fixes print statements and removes redundant Collector arg
* Fixes print statements and removes redundant Collector arg
* amend
* amend
* Adds class decorator
* Adds RemoteTensorDictReplayBuffer to rb_prototype.py
* Adds RemoteTensorDictReplayBuffer to docs
* Adds docstring comments to distributed replay buffer example
* Adds docstring comments to distributed replay buffer example
* Adds RemoteTensorDictReplayBuffer to existing test fixture
* Adds distributed rb test suite
* Moves rpc init and shutdown outside scope of test function
* Remove stray print and add more descriptive error if unable to connect to buffer
* Remove stray print and add more descriptive error if unable to connect to buffer

Co-authored-by: Adi Goel <adityagoel@fb.com>
Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 49039d1 commit b4b27fe

File tree: 6 files changed, +409 -7 lines changed

docs/source/reference/data.rst

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ We also provide a prototyped composable replay buffer.
 
     torchrl.data.replay_buffers.rb_prototype.ReplayBuffer
     torchrl.data.replay_buffers.rb_prototype.TensorDictReplayBuffer
+    torchrl.data.replay_buffers.rb_prototype.RemoteTensorDictReplayBuffer
     torchrl.data.replay_buffers.samplers.Sampler
     torchrl.data.replay_buffers.samplers.RandomSampler
     torchrl.data.replay_buffers.samplers.PrioritizedSampler
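
For reference, the newly documented class is constructed like the other prototype buffers. A minimal construction sketch follows (the capacity and the choice of storage, sampler and writer are illustrative; they mirror the ReplayBufferNode defined in the example below):

# Minimal construction sketch; the component choices mirror ReplayBufferNode below.
import torch
from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data.replay_buffers.writers import RoundRobinWriter

buffer = RemoteTensorDictReplayBuffer(
    storage=LazyMemmapStorage(max_size=10000, device=torch.device("cpu")),
    sampler=RandomSampler(),
    writer=RoundRobinWriter(),
    collate_fn=lambda x: x,
)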
Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
"""
Example use of a distributed replay buffer
===========================================

This example illustrates how a skeleton reinforcement learning algorithm can be implemented in a distributed fashion, with communication between nodes/workers handled using `torch.rpc`.
It focuses on how to set up a replay buffer worker that accepts remote operation requests efficiently, and therefore omits any learning component, such as parameter updates, that may be required for a complete distributed reinforcement learning algorithm implementation.
In this model, one or more data collector workers are responsible for collecting experiences in an environment; the replay buffer worker receives all of these experiences and exposes them to a trainer that is responsible for making parameter updates to any required models.
"""

import argparse
import os
import random
import sys
import time

import torch
import torch.distributed.rpc as rpc
from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data.replay_buffers.utils import accept_remote_rref_invocation
from torchrl.data.replay_buffers.writers import RoundRobinWriter
from torchrl.data.tensordict import TensorDict

RETRY_LIMIT = 2
RETRY_DELAY_SECS = 3
REPLAY_BUFFER_NODE = "ReplayBuffer"
TRAINER_NODE = "Trainer"

parser = argparse.ArgumentParser(
    description="RPC Replay Buffer Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--rank",
    type=int,
    default=-1,
    help="Node Rank [0 = Dummy Trainer, 1 = Replay Buffer, 2+ = Dummy Data Collector]",
)


class DummyDataCollectorNode:
    """Data collector node responsible for collecting experiences used for learning.

    Args:
        replay_buffer (rpc.RRef): the RRef associated with the construction of the replay buffer
    """

    def __init__(self, replay_buffer: rpc.RRef) -> None:
        self.id = rpc.get_worker_info().id
        self.replay_buffer = replay_buffer
        print("Data Collector Node constructed")

    def _submit_random_item_async(self) -> rpc.RRef:
        td = TensorDict({"a": torch.randint(100, (1,))}, [])
        return rpc.remote(
            self.replay_buffer.owner(),
            ReplayBufferNode.add,
            args=(
                self.replay_buffer,
                td,
            ),
        )

    @accept_remote_rref_invocation
    def collect(self):
        """Method that begins experience collection (we just generate random TensorDicts in this example). `accept_remote_rref_invocation` enables this method to be invoked remotely when the class-instantiation `rpc.RRef` is passed in place of the object reference."""
        for elem in range(50):
            time.sleep(random.randint(1, 4))
            print(
                f"Collector [{self.id}] submission {elem}: {self._submit_random_item_async().to_here()}"
            )


class DummyTrainerNode:
    """Trainer node responsible for learning from experiences sampled from an experience replay buffer."""

    def __init__(self) -> None:
        print("DummyTrainerNode")
        self.id = rpc.get_worker_info().id
        self.replay_buffer = self._create_replay_buffer()
        self._create_and_launch_data_collectors()

    def train(self, iterations: int) -> None:
        for iteration in range(iterations):
            print(f"[{self.id}] Training Iteration: {iteration}")
            time.sleep(3)
            batch = rpc.rpc_sync(
                self.replay_buffer.owner(),
                ReplayBufferNode.sample,
                args=(self.replay_buffer, 16),
            )
            print(f"[{self.id}] Sample Obtained Iteration: {iteration}")
            print(f"{batch}")

    def _create_replay_buffer(self) -> rpc.RRef:
        # construct the buffer remotely on the replay buffer node and keep an RRef to it
        while True:
            try:
                replay_buffer_info = rpc.get_worker_info(REPLAY_BUFFER_NODE)
                buffer_rref = rpc.remote(
                    replay_buffer_info, ReplayBufferNode, args=(10000,)
                )
                print(f"Connected to replay buffer {replay_buffer_info}")
                return buffer_rref
            except Exception as e:
                print(f"Failed to connect to replay buffer: {e}")
                time.sleep(RETRY_DELAY_SECS)

    def _create_and_launch_data_collectors(self) -> None:
        # data collector ranks start at 2 (rank 0 is the trainer, rank 1 the replay buffer)
        data_collector_number = 2
        retries = 0
        data_collectors = []
        data_collector_infos = []
        # discover launched data collector nodes (with retry to allow collectors to dynamically join)
        while True:
            try:
                data_collector_info = rpc.get_worker_info(
                    f"DataCollector{data_collector_number}"
                )
                print(f"Data collector info: {data_collector_info}")
                dc_ref = rpc.remote(
                    data_collector_info,
                    DummyDataCollectorNode,
                    args=(self.replay_buffer,),
                )
                data_collectors.append(dc_ref)
                data_collector_infos.append(data_collector_info)
                data_collector_number += 1
                retries = 0
            except Exception:
                retries += 1
                print(
                    f"Failed to connect to DataCollector{data_collector_number} with {retries} retries"
                )
                if retries >= RETRY_LIMIT:
                    print(f"{len(data_collectors)} data collectors")
                    for data_collector_info, data_collector in zip(
                        data_collector_infos, data_collectors
                    ):
                        rpc.remote(
                            data_collector_info,
                            DummyDataCollectorNode.collect,
                            args=(data_collector,),
                        )
                    break
                else:
                    time.sleep(RETRY_DELAY_SECS)


class ReplayBufferNode(RemoteTensorDictReplayBuffer):
    """Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteTensorDictReplayBuffer` means all of its public methods are remotely invokable using `torch.rpc`.
    Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations, which can improve the ability to recover from node failures.

    Args:
        capacity (int): the maximum number of elements that can be stored in the replay buffer.
    """

    def __init__(self, capacity: int):
        super().__init__(
            storage=LazyMemmapStorage(
                max_size=capacity, scratch_dir="/tmp/", device=torch.device("cpu")
            ),
            sampler=RandomSampler(),
            writer=RoundRobinWriter(),
            collate_fn=lambda x: x,
        )


if __name__ == "__main__":
    args = parser.parse_args()
    rank = args.rank
    print(f"Rank: {rank}")

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    str_init_method = "tcp://localhost:10000"
    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16, init_method=str_init_method
    )
    if rank == 0:
        # rank 0 is the trainer
        rpc.init_rpc(
            TRAINER_NODE,
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
        )
        print(f"Initialised Trainer Node {rank}")
        trainer = DummyTrainerNode()
        trainer.train(100)
        breakpoint()
    elif rank == 1:
        # rank 1 is the replay buffer
        # replay buffer waits passively for construction instructions from trainer node
        print(REPLAY_BUFFER_NODE)
        rpc.init_rpc(
            REPLAY_BUFFER_NODE,
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
        )
        print(f"Initialised RB Node {rank}")
        breakpoint()
    elif rank >= 2:
        # rank 2+ is a new data collector node
        # data collectors also wait passively for construction instructions from trainer node
        rpc.init_rpc(
            f"DataCollector{rank}",
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
        )
        print(f"Initialised DC Node {rank}")
        breakpoint()
    else:
        sys.exit(1)
    rpc.shutdown()
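
Each node type in the example runs as a separate process distinguished only by its --rank argument. A hypothetical single-machine launcher sketch (not part of the commit; it assumes the file above is saved as distributed_replay_buffer.py in the working directory):

# Hypothetical launcher, assuming the example is saved as distributed_replay_buffer.py.
# Collectors (ranks 2-3) and the replay buffer (rank 1) are started before the trainer
# (rank 0) so the trainer's discovery/retry loops succeed quickly; each node ends in
# breakpoint(), so running each rank in its own terminal is more convenient in practice.
import subprocess
import sys

if __name__ == "__main__":
    processes = [
        subprocess.Popen(
            [sys.executable, "distributed_replay_buffer.py", "--rank", str(rank)]
        )
        for rank in (2, 3, 1, 0)
    ]
    for process in processes:
        process.wait()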

test/test_rb.py

Lines changed: 17 additions & 6 deletions
@@ -40,7 +40,12 @@
 
 
 @pytest.mark.parametrize(
-    "rb_type", [rb_prototype.ReplayBuffer, rb_prototype.TensorDictReplayBuffer]
+    "rb_type",
+    [
+        rb_prototype.ReplayBuffer,
+        rb_prototype.TensorDictReplayBuffer,
+        rb_prototype.RemoteTensorDictReplayBuffer,
+    ],
 )
 @pytest.mark.parametrize(
     "sampler", [samplers.RandomSampler, samplers.PrioritizedSampler]
@@ -69,16 +74,22 @@ def _get_rb(self, rb_type, size, sampler, writer, storage):
     def _get_datum(self, rb_type):
         if rb_type is rb_prototype.ReplayBuffer:
             data = torch.randint(100, (1,))
-        elif rb_type is rb_prototype.TensorDictReplayBuffer:
+        elif (
+            rb_type is rb_prototype.TensorDictReplayBuffer
+            or rb_type is rb_prototype.RemoteTensorDictReplayBuffer
+        ):
             data = TensorDict({"a": torch.randint(100, (1,))}, [])
         else:
             raise NotImplementedError(rb_type)
         return data
 
-    def _get_data(self, rbtype, size):
-        if rbtype is rb_prototype.ReplayBuffer:
+    def _get_data(self, rb_type, size):
+        if rb_type is rb_prototype.ReplayBuffer:
             data = torch.randint(100, (size, 1))
-        elif rbtype is rb_prototype.TensorDictReplayBuffer:
+        elif (
+            rb_type is rb_prototype.TensorDictReplayBuffer
+            or rb_type is rb_prototype.RemoteTensorDictReplayBuffer
+        ):
             data = TensorDict(
                 {
                     "a": torch.randint(100, (size,)),
@@ -87,7 +98,7 @@ def _get_data(self, rbtype, size):
                 [size],
             )
         else:
-            raise NotImplementedError(rbtype)
+            raise NotImplementedError(rb_type)
         return data
 
     def test_add(self, rb_type, sampler, writer, storage, size):
