Commit 49af5da

Authored by adityagoel4512 (Adi Goel), with co-authors vmoens and sosmond
[Feature] Benchmark storage types (#633)
* Distributed replay buffer prototype
* Fixes comment issue
* Makes ReplayBufferNode subclass TensorDictReplayBuffer
* aha
* amend
* bf
* Fixes print statements and removes redundant Collector arg
* Fixes print statements and removes redundant Collector arg
* amend
* amend
* Timing larger tensordict transfers over torch rpc using multiple storage types
* init
* amend
* amend
* update_ and make sure content is read
* amend
* amend
* Fixes list storage arg
* Moves benchmark to new top-level directory and adds note in documentation about speed up using MemmapTensor
* Removes analysis.ipynb
* Removes accidental edit to tensordict.py
* Updates data.rst text
* Removes redundant variable
* Removes hack to get list read to work
* replace assert_allclose with assert_close (#644)
* Adds small note illustrating example usage

Co-authored-by: Adi Goel <adityagoel@fb.com>
Co-authored-by: vmoens <vincentmoens@gmail.com>
Co-authored-by: sosmond <35877775+sosmond@users.noreply.github.com>
1 parent: ada0fcd · commit: 49af5da

File tree (3 files changed: +202 -2 lines)

- benchmarks/storage/benchmark_sample_latency_over_rpc.py
- docs/source/reference/data.rst
- torchrl/data/replay_buffers/rb_prototype.py
benchmarks/storage/benchmark_sample_latency_over_rpc.py (new file)

Lines changed: 187 additions & 0 deletions

@@ -0,0 +1,187 @@
"""
Sample latency benchmarking (using RPC)
=======================================

A rough benchmark of sample latency for different storage types over the network, using `torch.rpc`.
Run this script with the --rank=0 and --rank=1 flags set in two separate processes; these ranks correspond to the trainer worker and the buffer worker respectively, and both need to be initialised.
E.g., to benchmark LazyMemmapStorage, run the following commands in either two separate shells or via multiprocessing:

- python3 benchmark_sample_latency_over_rpc.py --rank=0 --storage=LazyMemmapStorage
- python3 benchmark_sample_latency_over_rpc.py --rank=1 --storage=LazyMemmapStorage

This code is based on examples/distributed/distributed_replay_buffer.py.
"""
import argparse
import os
import pickle
import sys
import time
import timeit
from datetime import datetime

import torch
import torch.distributed.rpc as rpc
from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import (
    LazyMemmapStorage,
    LazyTensorStorage,
    ListStorage,
)
from torchrl.data.replay_buffers.writers import RoundRobinWriter
from torchrl.data.tensordict import TensorDict

RETRY_LIMIT = 2
RETRY_DELAY_SECS = 3
REPLAY_BUFFER_NODE = "ReplayBuffer"
TRAINER_NODE = "Trainer"
TENSOR_SIZE = 3 * 86 * 86
BUFFER_SIZE = 1001
BATCH_SIZE = 256
REPEATS = 1000

storage_options = {
    "LazyMemmapStorage": LazyMemmapStorage,
    "LazyTensorStorage": LazyTensorStorage,
    "ListStorage": ListStorage,
}

storage_arg_options = {
    "LazyMemmapStorage": dict(scratch_dir="/tmp/", device=torch.device("cpu")),
    "LazyTensorStorage": dict(),
    "ListStorage": dict(),
}

parser = argparse.ArgumentParser(
    description="RPC Replay Buffer Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--rank",
    type=int,
    default=-1,
    help="Node Rank [0 = Dummy Trainer, 1 = Replay Buffer]",
)

parser.add_argument(
    "--storage",
    type=str,
    default="LazyMemmapStorage",
    help="Storage type [LazyMemmapStorage, LazyTensorStorage, ListStorage]",
)


class DummyTrainerNode:
    def __init__(self) -> None:
        self.id = rpc.get_worker_info().id
        self.replay_buffer = self._create_replay_buffer()
        self._ret = None

    def train(self, batch_size: int) -> float:
        # Time a single remote sample() round trip.
        start_time = timeit.default_timer()
        ret = rpc.rpc_sync(
            self.replay_buffer.owner(),
            ReplayBufferNode.sample,
            args=(self.replay_buffer, batch_size),
        )
        if storage_type == "ListStorage":
            self._ret = ret[0]
        else:
            if self._ret is None:
                self._ret = ret
            else:
                self._ret[0].update_(ret[0])
        # make sure the content is actually read, not just lazily referenced
        self._ret[0]["observation"] + 1
        self._ret[0]["next_observation"] + 1
        return timeit.default_timer() - start_time

    def _create_replay_buffer(self) -> rpc.RRef:
        # Retry until the buffer worker is up and reachable.
        while True:
            try:
                replay_buffer_info = rpc.get_worker_info(REPLAY_BUFFER_NODE)
                buffer_rref = rpc.remote(
                    replay_buffer_info, ReplayBufferNode, args=(1000000,)
                )
                print(f"Connected to replay buffer {replay_buffer_info}")
                return buffer_rref
            except Exception:
                print("Failed to connect to replay buffer")
                time.sleep(RETRY_DELAY_SECS)


class ReplayBufferNode(RemoteTensorDictReplayBuffer):
    def __init__(self, capacity: int):
        # `storage_type` is a module-level global set in __main__.
        super().__init__(
            storage=storage_options[storage_type](
                max_size=capacity, **storage_arg_options[storage_type]
            ),
            sampler=RandomSampler(),
            writer=RoundRobinWriter(),
            collate_fn=lambda x: x,
        )
        # Pre-fill the buffer with random data so sampling can start immediately.
        tds = TensorDict(
            {
                "observation": torch.randn(
                    BUFFER_SIZE,
                    TENSOR_SIZE,
                ),
                "next_observation": torch.randn(
                    BUFFER_SIZE,
                    TENSOR_SIZE,
                ),
            },
            batch_size=[BUFFER_SIZE],
        )
        self.extend(tds)


if __name__ == "__main__":
    args = parser.parse_args()
    rank = args.rank
    storage_type = args.storage

    print(f"Rank: {rank}; Storage: {storage_type}")

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16, init_method="tcp://localhost:10002", rpc_timeout=120
    )
    if rank == 0:
        # rank 0 is the trainer
        rpc.init_rpc(
            TRAINER_NODE,
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
        )
        trainer = DummyTrainerNode()
        results = []
        for i in range(REPEATS):
            result = trainer.train(batch_size=BATCH_SIZE)
            if i == 0:
                # discard the first iteration as warm-up
                continue
            results.append(result)
            print(i, results[-1])

        with open(
            f'./benchmark_{datetime.now().strftime("%d-%m-%Y%H:%M:%S")};batch_size={BATCH_SIZE};tensor_size={TENSOR_SIZE};repeat={REPEATS};storage={storage_type}.pkl',
            "wb+",
        ) as f:
            pickle.dump(results, f)

        tensor_results = torch.tensor(results)
        print(f"Mean: {torch.mean(tensor_results)}")
        breakpoint()  # keeps the process alive for interactive inspection
    elif rank == 1:
        # rank 1 is the replay buffer
        # replay buffer waits passively for construction instructions from trainer node
        rpc.init_rpc(
            REPLAY_BUFFER_NODE,
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
        )
        breakpoint()  # keeps the process alive for interactive inspection
    else:
        sys.exit(1)
    rpc.shutdown()
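For convenience, here is a small launcher sketch (not part of the commit) that starts both ranks from a single entry point, as the docstring's multiprocessing option suggests. The script path is an assumption based on the benchmarks/storage directory referenced in the docs below; note that the breakpoint() calls in the benchmark script will pause each process, so remove them for unattended runs.

    # Hypothetical launcher for the benchmark above: spawns the trainer (rank 0)
    # and the buffer worker (rank 1) as two subprocesses.
    import subprocess
    import sys

    SCRIPT = "benchmarks/storage/benchmark_sample_latency_over_rpc.py"  # assumed path

    def launch(storage: str = "LazyMemmapStorage") -> None:
        procs = [
            subprocess.Popen(
                [sys.executable, SCRIPT, f"--rank={rank}", f"--storage={storage}"]
            )
            for rank in (0, 1)
        ]
        for proc in procs:
            proc.wait()

    if __name__ == "__main__":
        launch()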

docs/source/reference/data.rst

Lines changed: 14 additions & 0 deletions
@@ -43,6 +43,20 @@ We also provide a prototyped composable replay buffer.
     torchrl.data.replay_buffers.writers.Writer
     torchrl.data.replay_buffers.writers.RoundRobinWriter
 
+Storage choice strongly affects replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes.
+:class:`LazyMemmapStorage` is highly advised in distributed settings with shared storage, due to the lower serialisation cost of MemmapTensors and the ability to specify file storage locations, which improves recovery from node failures.
+The following mean sampling-latency speed-ups over :class:`ListStorage` were measured in the rough benchmark at https://github.com/pytorch/rl/tree/main/benchmarks/storage.
+
++-------------------------------+-----------+
+| Storage Type                  | Speed up  |
+|                               |           |
++===============================+===========+
+| :class:`ListStorage`          | 1x        |
++-------------------------------+-----------+
+| :class:`LazyTensorStorage`    | 1.83x     |
++-------------------------------+-----------+
+| :class:`LazyMemmapStorage`    | 3.44x     |
++-------------------------------+-----------+
 
 
 TensorDict
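To make the documentation note concrete, here is a minimal usage sketch (not part of the diff). It reuses the prototype classes imported by the benchmark script above; the constructor signature and the tuple returned by sample() are assumptions inferred from that script:

    # Sketch: a replay buffer backed by LazyMemmapStorage, as the docs advise
    # for distributed settings. Treat the exact API as an assumption about the
    # prototype module.
    import torch
    from torchrl.data.replay_buffers.rb_prototype import TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import RandomSampler
    from torchrl.data.replay_buffers.storages import LazyMemmapStorage
    from torchrl.data.replay_buffers.writers import RoundRobinWriter
    from torchrl.data.tensordict import TensorDict

    buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(
            max_size=10_000,
            scratch_dir="/tmp/",  # where the memmap files are written
            device=torch.device("cpu"),
        ),
        sampler=RandomSampler(),
        writer=RoundRobinWriter(),
        collate_fn=lambda x: x,
    )
    data = TensorDict(
        {"observation": torch.randn(100, 3 * 86 * 86)},
        batch_size=[100],
    )
    buffer.extend(data)
    batch = buffer.sample(32)[0]  # sample() returns a tuple; first item is the data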

torchrl/data/replay_buffers/rb_prototype.py

Lines changed: 1 addition & 2 deletions
@@ -175,7 +175,6 @@ def collate_fn(x):
                 return stack_td(x, 0, contiguous=True)
 
             kw["collate_fn"] = collate_fn
-
         super().__init__(**kw)
         self.priority_key = priority_key
 
@@ -232,7 +231,7 @@ def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
             torch.tensor(index, dtype=torch.int, device=stacked_td.device),
             inplace=True,
         )
-        self.update_tensordict_priority(tensordicts)
+        self.update_tensordict_priority(stacked_td)
         return index
 
     def update_tensordict_priority(self, data: TensorDictBase) -> None:
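The one-line change matters because update_tensordict_priority needs the "index" key that extend has just written into stacked_td; the raw tensordicts argument never received it. A toy illustration of the pattern (plain Python, not torchrl code):

    # Toy illustration (not the torchrl implementation) of the extend() fix:
    # priorities must be updated from the object that carries the freshly
    # written "index", i.e. the stacked batch, not the raw input elements.
    import torch

    class ToyBuffer:
        def __init__(self) -> None:
            self.priorities = {}

        def _write(self, batch: dict) -> list:
            # Pretend the writer returns the storage slots it used.
            return list(range(batch["observation"].shape[0]))

        def update_priority(self, batch: dict) -> None:
            for idx in batch["index"].tolist():
                self.priorities[idx] = 1.0  # default priority

        def extend(self, elements: list) -> list:
            # Stack incoming elements into one batch, as rb_prototype does.
            batch = {"observation": torch.stack([e["observation"] for e in elements])}
            index = self._write(batch)
            batch["index"] = torch.tensor(index, dtype=torch.int)
            # The fix: read the index from the annotated batch (stacked_td),
            # not from `elements`, which has no "index" key.
            self.update_priority(batch)
            return index

    buffer = ToyBuffer()
    buffer.extend([{"observation": torch.randn(4)} for _ in range(3)])
    print(buffer.priorities)  # {0: 1.0, 1: 1.0, 2: 1.0}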
