Commit da7904e

Author: Vincent Moens
[Feature] PyTrees in replay buffers (#1831)
1 parent 24d14ad commit da7904e

File tree

11 files changed: +985 −324 lines

docs/source/reference/data.rst

Lines changed: 110 additions & 4 deletions
@@ -22,7 +22,108 @@ widely used replay buffers:
 Composable Replay Buffers
 -------------------------
 
-We also give users the ability to compose a replay buffer using the following components:
+We also give users the ability to compose a replay buffer.
+We provide a wide panel of solutions for replay buffer usage, including support for
+almost any data type; storage in memory, on device or on physical memory;
+several sampling strategies; usage of transforms, etc.
+
+Supported data types and choosing a storage
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In theory, replay buffers support any data type, but we can't guarantee that each
+component will support any data type. The crudest replay buffer implementation
+is made of a :class:`~torchrl.data.replay_buffers.ReplayBuffer` base with a
+:class:`~torchrl.data.replay_buffers.ListStorage` storage. This is very inefficient,
+but it will allow you to store complex data structures with non-tensor data.
+Storages in contiguous memory include :class:`~torchrl.data.replay_buffers.TensorStorage`,
+:class:`~torchrl.data.replay_buffers.LazyTensorStorage` and
+:class:`~torchrl.data.replay_buffers.LazyMemmapStorage`.
+These classes support :class:`~tensordict.TensorDict` data as first-class citizens, but also
+any PyTree data structure (e.g., tuples, lists, dictionaries and nested versions
+of these). The :class:`~torchrl.data.replay_buffers.TensorStorage` storage requires
+you to provide the storage at construction time, whereas :class:`~torchrl.data.replay_buffers.LazyTensorStorage`
+(RAM, CUDA) and :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` (physical memory)
+will preallocate the storage for you after they've been extended the first time.
+
+Here are a few examples, starting with the generic :class:`~torchrl.data.replay_buffers.ListStorage`:
+
+>>> from torchrl.data.replay_buffers import ReplayBuffer, ListStorage
+>>> rb = ReplayBuffer(storage=ListStorage(10))
+>>> rb.add("a string!")  # first element will be a string
+>>> rb.extend([30, None])  # element [1] is an int, [2] is None
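
A quick check of what this leaves in the buffer (our illustration, not part of the commit): ``add`` stored the string as a single entry, while ``extend`` walked the list, so three elements are now stored:

>>> assert len(rb) == 3  # "a string!", 30 and None each count as one entry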
+
+Using a :class:`~torchrl.data.replay_buffers.TensorStorage`, we tell our RB that
+we want the storage to be contiguous, which is by far more efficient but also
+more restrictive:
+
+>>> import torch
+>>> from torchrl.data.replay_buffers import ReplayBuffer, TensorStorage
+>>> container = torch.empty(10, 3, 64, 64, dtype=torch.uint8)
+>>> rb = ReplayBuffer(storage=TensorStorage(container))
+>>> img = torch.randint(255, (3, 64, 64), dtype=torch.uint8)
+>>> rb.add(img)
+
+Next we can avoid creating the container and ask the storage to do it automatically.
+This is very useful when using PyTrees and tensordicts! For PyTrees as for other data
+structures, :meth:`~torchrl.data.replay_buffers.ReplayBuffer.add` considers the sample
+passed to it as a single instance of the type. :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend`,
+on the other hand, will consider the data to be an iterable. For tensors, tensordicts
+and lists (see below), the iterable is looked for at the root level. For PyTrees,
+we assume that the leading dimensions of all the leaves (tensors) in the tree
+match. If they don't, ``extend`` will throw an exception.
+
+>>> import torch
+>>> from tensordict import TensorDict
+>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyMemmapStorage
+>>> rb_td = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=1)  # max 10 elements stored
+>>> rb_td.add(TensorDict({"img": torch.randint(255, (3, 64, 64), dtype=torch.uint8),
+...     "labels": torch.randint(100, ())}, batch_size=[]))
+>>> rb_pytree = ReplayBuffer(storage=LazyMemmapStorage(10))  # max 10 elements stored
+>>> # extend with a PyTree where all tensors have the same leading dim (3)
+>>> rb_pytree.extend({"a": {"b": torch.randn(3), "c": [torch.zeros(3, 2), (torch.ones(3, 10),)]}})
+>>> assert len(rb_pytree) == 3  # the replay buffer has 3 elements!
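
For illustration (our sketch, assuming the ``rb_pytree`` buffer above), sampling should give back the same nesting, with the sampled batch dimension leading each leaf:

>>> sample = rb_pytree.sample(2)
>>> assert isinstance(sample, dict)  # same structure as what was stored
>>> assert sample["a"]["b"].shape == torch.Size([2])  # each stored "b" leaf is a scalar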
+
+.. note:: :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` can have an
+  ambiguous signature when dealing with lists of values, which should be interpreted
+  either as a PyTree (in which case all elements in the list will be put in a slice
+  in the stored PyTree in the storage) or as a list of values to add one at a time.
+  To solve this, TorchRL makes a clear-cut distinction between list and tuple:
+  a tuple will be viewed as a PyTree; a list (at the root level) will be interpreted
+  as a stack of values to add one at a time to the buffer.
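
A minimal sketch of that distinction (our illustration, not part of the commit):

>>> import torch
>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyMemmapStorage
>>> rb_list = ReplayBuffer(storage=LazyMemmapStorage(10))
>>> rb_list.extend([torch.zeros(4, 3), torch.ones(4, 3)])  # list: two values, added one at a time
>>> assert len(rb_list) == 2  # each entry has shape (4, 3)
>>> rb_tuple = ReplayBuffer(storage=LazyMemmapStorage(10))
>>> rb_tuple.extend((torch.zeros(4, 3), torch.ones(4, 3)))  # tuple: one PyTree, sliced along its leading dim
>>> assert len(rb_tuple) == 4  # each entry is a tuple of two (3,)-shaped tensors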
+
+Sampling and indexing
+~~~~~~~~~~~~~~~~~~~~~
+
+Replay buffers can be indexed and sampled.
+Indexing and sampling collect data at given indices in the storage and then process it
+through a series of transforms and a ``collate_fn`` that can be passed to the ``__init__``
+method of the replay buffer. ``collate_fn`` comes with default values that should
+match user expectations in the majority of cases, such that you should not have
+to worry about it most of the time. Transforms are usually instances of :class:`~torchrl.envs.transforms.Transform`,
+even though regular functions will work too (in the latter case, the :meth:`~torchrl.envs.transforms.Transform.inv`
+method will obviously be ignored, whereas in the first case it can be used to
+preprocess the data before it is passed to the buffer).
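
For instance (a hedged sketch of ours, not from the diff; the lambda stands in for any callable), a plain function passed as ``transform`` post-processes data on its way out of the buffer:

>>> import torch
>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
>>> rb = ReplayBuffer(
...     storage=LazyTensorStorage(10),
...     transform=lambda data: data.float() / 255,  # applied on sampling; inv is skipped for plain functions
... )
>>> rb.extend(torch.randint(255, (4, 3, 64, 64), dtype=torch.uint8))
>>> assert rb.sample(2).dtype == torch.float32  # uint8 images come out as floats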
+
+Finally, sampling can be achieved using multithreading by passing the number of threads
+to the constructor through the ``prefetch`` keyword argument. We advise users to
+benchmark this technique in real-life settings before adopting it, as there is
+no guarantee that it will lead to faster throughput in practice, depending on
+the machine and setting where it is used.
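
Enabling it is a one-liner (our sketch; a thread count of 3 is arbitrary):

>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
>>> rb = ReplayBuffer(storage=LazyTensorStorage(1000), prefetch=3, batch_size=256)  # pre-fetches 3 batches in background threads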
+
+When sampling, the ``batch_size`` can be passed either during construction
+(e.g., if it's constant throughout training) or
+to the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method.
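
Both options can be sketched as follows (our illustration, not part of the commit):

>>> import torch
>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
>>> rb_fixed = ReplayBuffer(storage=LazyTensorStorage(10), batch_size=32)
>>> rb_fixed.extend(torch.randn(10, 4))
>>> batch = rb_fixed.sample()  # uses the constructor's batch_size (32)
>>> rb_free = ReplayBuffer(storage=LazyTensorStorage(10))
>>> rb_free.extend(torch.randn(10, 4))
>>> batch = rb_free.sample(32)  # batch_size supplied at sampling time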
+
+To further refine the sampling strategy, we advise you to look into our samplers!
+
+Here are a couple of examples of how to get data out of a replay buffer:
+
+>>> first_elt = rb_td[0]
+>>> storage = rb_td[:]  # returns all valid elements from the buffer
+>>> sample = rb_td.sample(128)
+>>> for data in rb_td:  # iterate over the buffer using the sampler -- batch_size was set in the constructor to 1
+...     print(data)
+
+A replay buffer can be composed using the following components:
 
 .. currentmodule:: torchrl.data.replay_buffers
 
@@ -48,9 +149,14 @@ We also give users the ability to compose a replay buffer using the following components:
 TensorDictRoundRobinWriter
 TensorDictMaxValueWriter
 
-Storage choice is very influential on replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes.
-:class:`LazyMemmapStorage` is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations for improved node failure recovery.
-The following mean sampling latency improvements over using ListStorage were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/benchmarks/storage.
+Storage choice is very influential on replay buffer sampling latency, especially
+in distributed reinforcement learning settings with larger data volumes.
+:class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` is highly
+advised in distributed settings with shared storage due to the lower serialisation
+cost of MemoryMappedTensors as well as the ability to specify file storage locations
+for improved node failure recovery.
+The following mean sampling latency improvements over using :class:`~torchrl.data.replay_buffers.ListStorage`
+were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/benchmarks/storage.
 
 +-------------------------------+-----------+
 | Storage Type                  | Speed up  |
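
As an aside, the file-location ability mentioned above can be sketched as follows (our example, not from the diff; the directory path is a placeholder):

>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyMemmapStorage
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(1_000_000, scratch_dir="/tmp/rb"))  # memory-mapped files are written under this path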

test/test_cost.py

Lines changed: 6 additions & 10 deletions
@@ -8,7 +8,6 @@
 import functools
 import itertools
 import operator
-import re
 import warnings
 from copy import deepcopy
 from dataclasses import asdict, dataclass
@@ -270,14 +269,11 @@ def forward(self, td):
             loss_module.set_vmap_randomness(vmap_randomness)
         # Fail case
         elif vmap_randomness == "error" and dropout > 0.0:
-            with pytest.raises(RuntimeError) as exc_info:
+            with pytest.raises(
+                RuntimeError,
+                match="vmap: called random operation while in randomness error mode",
+            ):
                 loss_module(td)["loss"]
-
-            # Accessing cause of the caught exception
-            cause = exc_info.value.__cause__
-            assert re.match(
-                r"vmap: called random operation while in randomness error mode", str(cause)
-            )
             return
         loss_module(td)["loss"]

@@ -1238,7 +1234,7 @@ def test_mixer_keys(
 
         # Without setting the keys
         if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"):
-            with pytest.raises(RuntimeError):
+            with pytest.raises(KeyError):
                 loss(td)
         elif unravel_key(mixer_global_chosen_action_value_key) != "chosen_action_value":
             with pytest.raises(
@@ -1253,7 +1249,7 @@ def test_mixer_keys(
                 loss.set_keys(global_value=mixer_global_chosen_action_value_key)
         if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"):
             with pytest.raises(
-                RuntimeError
+                KeyError
             ):  # The mixer in key still does not match the actor out_key
                 loss(td)
         else:
