
Commit fe3f00c

Author: Vincent Moens (committed)

[Feature] LazyStackStorage

ghstack-source-id: e9c0314
Pull Request resolved: #2723

1 parent: 280297a

File tree

6 files changed: +113 -1 lines changed

  docs/source/reference/data.rst
  test/test_rb.py
  torchrl/data/__init__.py
  torchrl/data/replay_buffers/__init__.py
  torchrl/data/replay_buffers/samplers.py
  torchrl/data/replay_buffers/storages.py

docs/source/reference/data.rst

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ using the following components:
     LazyMemmapStorage
     LazyTensorStorage
     ListStorage
+    LazyStackStorage
     ListStorageCheckpointer
     NestedStorageCheckpointer
     PrioritizedSampler

test/test_rb.py

Lines changed: 26 additions & 0 deletions
@@ -81,6 +81,7 @@

 from torchrl.data.replay_buffers.storages import (
     LazyMemmapStorage,
+    LazyStackStorage,
     LazyTensorStorage,
     ListStorage,
     StorageEnsemble,
@@ -1116,6 +1117,31 @@ def test_storage_inplace_writing_ndim(self, storage_type):
         assert (rb[:, 10:20] == 0).all()
         assert len(rb) == 100

+    @pytest.mark.parametrize("max_size", [1000, None])
+    @pytest.mark.parametrize("stack_dim", [-1, 0])
+    def test_lazy_stack_storage(self, max_size, stack_dim):
+        # Create an instance of LazyStackStorage with the given parameters
+        storage = LazyStackStorage(max_size=max_size, stack_dim=stack_dim)
+        # Create a ReplayBuffer using the created storage
+        rb = ReplayBuffer(storage=storage)
+        # Generate some random data to add to the buffer
+        torch.manual_seed(0)
+        data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!")
+        data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!")
+        # Add the data to the buffer
+        rb.add(data0)
+        rb.add(data1)
+        # Sample from the buffer
+        sample = rb.sample(10)
+        # Check that the sampled data has the correct shape and type
+        assert isinstance(sample, LazyStackedTensorDict)
+        assert sample["b"].shape[0] == 10
+        assert all(isinstance(item, str) for item in sample["c"])
+        # Densify the sample into nested (jagged) tensors and check the result
+        sample = sample.densify(layout=torch.jagged)
+        assert isinstance(sample["a"], torch.Tensor)
+        assert sample["a"].shape[0] == 10
+

     @pytest.mark.parametrize("max_size", [1000])
     @pytest.mark.parametrize("shape", [[3, 4]])

torchrl/data/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
     H5StorageCheckpointer,
     ImmutableDatasetWriter,
     LazyMemmapStorage,
+    LazyStackStorage,
     LazyTensorStorage,
     ListStorage,
     ListStorageCheckpointer,

torchrl/data/replay_buffers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@
 )
 from .storages import (
     LazyMemmapStorage,
+    LazyStackStorage,
     LazyTensorStorage,
     ListStorage,
     Storage,

torchrl/data/replay_buffers/samplers.py

Lines changed: 5 additions & 1 deletion
@@ -1049,6 +1049,10 @@ def __init__(
             self._cache["stop-and-length"] = vals

         else:
+            if traj_key is not None:
+                self._fetch_traj = True
+            elif end_key is not None:
+                self._fetch_traj = False
             if end_key is None:
                 end_key = ("next", "done")
             if traj_key is None:

@@ -1331,7 +1335,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
         if start_idx.shape[1] != storage.ndim:
             raise RuntimeError(
                 f"Expected the end-of-trajectory signal to be "
-                f"{storage.ndim}-dimensional. Got a {start_idx.shape[1]} tensor "
+                f"{storage.ndim}-dimensional. Got a tensor with shape[1]={start_idx.shape[1]} "
                 "instead."
             )
         seq_length, num_slices = self._adjusted_batch_size(batch_size)
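
The first hunk makes SliceSampler honour whichever signal the caller explicitly passes: an explicit `traj_key` turns `_fetch_traj` on, while an explicit `end_key` alone turns it off. A minimal sketch of the two constructions follows; the "episode" and ("next", "done") keys below are the usual defaults, shown only for illustration and not part of this commit:

    from torchrl.data.replay_buffers.samplers import SliceSampler

    # Trajectories identified by an explicit trajectory-id column: the sampler
    # will fetch `traj_key` to locate slice boundaries (_fetch_traj=True).
    sampler_by_traj = SliceSampler(num_slices=4, traj_key="episode")

    # Trajectories identified only by an end-of-trajectory flag: the sampler
    # will fetch `end_key` instead (_fetch_traj=False).
    sampler_by_end = SliceSampler(num_slices=4, end_key=("next", "done"))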

torchrl/data/replay_buffers/storages.py

Lines changed: 79 additions & 0 deletions
@@ -297,7 +297,15 @@ def set(
     def get(self, index: Union[int, Sequence[int], slice]) -> Any:
         if isinstance(index, (INT_CLASSES, slice)):
             return self._storage[index]
+        elif isinstance(index, tuple):
+            if len(index) > 1:
+                raise RuntimeError(
+                    f"{type(self).__name__} can only be indexed with one-length tuples."
+                )
+            return self.get(index[0])
         else:
+            if isinstance(index, torch.Tensor) and index.device.type != "cpu":
+                index = index.cpu().tolist()
             return [self._storage[i] for i in index]

     def __len__(self):
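
The branches added above make two indexing patterns explicit for ListStorage. A short, hedged sketch of both, calling the storage directly rather than through a replay buffer:

    import torch
    from torchrl.data.replay_buffers.storages import ListStorage

    storage = ListStorage(max_size=10)
    for i in range(3):
        storage.set(i, {"x": i})

    # One-length tuples are unwrapped and re-dispatched to the wrapped index.
    assert storage.get((slice(0, 2),)) == [{"x": 0}, {"x": 1}]
    # Tuples longer than one raise a RuntimeError, since the list has a single dim.

    # Tensor indices living on an accelerator are moved to CPU and turned into a
    # plain list before the items are gathered one by one.
    idx = torch.tensor([0, 2])  # the same call works if `idx` sits on a CUDA device
    assert storage.get(idx) == [{"x": 0}, {"x": 2}]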
@@ -353,6 +361,77 @@ def contains(self, item):
         raise NotImplementedError(f"type {type(item)} is not supported yet.")


+class LazyStackStorage(ListStorage):
+    """A ListStorage that returns LazyStackedTensorDict instances.
+
+    This storage allows heterogeneous structures to be indexed as a single `TensorDict` representation.
+    It uses :class:`~tensordict.LazyStackedTensorDict`, which operates on non-contiguous lists of tensordicts,
+    lazily stacking items when queried.
+    This means that this storage is fast to sample from, but data access may be slow (as it requires a stack to be built).
+    Tensors of heterogeneous shapes can also be stored within the storage and stacked together.
+    Because the storage is represented as a list, the number of tensors held in memory grows linearly with
+    the size of the buffer.
+
+    If possible, nested tensors can also be created via :meth:`~tensordict.LazyStackedTensorDict.densify`
+    (see :mod:`~torch.nested`).
+
+    Args:
+        max_size (int, optional): the maximum number of elements stored in the storage.
+            If not provided, an unlimited storage is created.
+
+    Keyword Args:
+        compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
+            the cost of not being executable in multiprocessed settings.
+        stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `-1`.
+
+    Examples:
+        >>> import torch
+        >>> from torchrl.data import ReplayBuffer, LazyStackStorage
+        >>> from tensordict import TensorDict
+        >>> _ = torch.manual_seed(0)
+        >>> rb = ReplayBuffer(storage=LazyStackStorage(max_size=1000, stack_dim=-1))
+        >>> data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!")
+        >>> data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!")
+        >>> _ = rb.add(data0)
+        >>> _ = rb.add(data1)
+        >>> rb.sample(10)
+        LazyStackedTensorDict(
+            fields={
+                a: Tensor(shape=torch.Size([10, -1]), device=cpu, dtype=torch.float32, is_shared=False),
+                b: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                c: NonTensorStack(
+                    ['another string!', 'another string!', 'another st...,
+                    batch_size=torch.Size([10]),
+                    device=None)},
+            exclusive_fields={
+            },
+            batch_size=torch.Size([10]),
+            device=None,
+            is_shared=False,
+            stack_dim=0)
+    """
+
+    def __init__(
+        self,
+        max_size: int | None = None,
+        *,
+        compilable: bool = False,
+        stack_dim: int = -1,
+    ):
+        super().__init__(max_size=max_size, compilable=compilable)
+        self.stack_dim = stack_dim
+
+    def get(self, index: Union[int, Sequence[int], slice]) -> Any:
+        out = super().get(index=index)
+        if isinstance(out, list):
+            stack_dim = self.stack_dim
+            if stack_dim < 0:
+                stack_dim = out[0].ndim + 1 + stack_dim
+            out = LazyStackedTensorDict(*out, stack_dim=stack_dim)
+            return out
+        return out
+
+
 class TensorStorage(Storage):
     """A storage for tensors and tensordicts.
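
As a follow-up to the docstring example above, here is a hedged sketch of how `stack_dim` is resolved in `LazyStackStorage.get` when the stored items carry batch dimensions; the batch sizes below are illustrative and not part of this commit:

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyStackStorage

    storage = LazyStackStorage(stack_dim=-1)
    storage.set(0, TensorDict(x=torch.zeros(5, 3), batch_size=[5]))
    storage.set(1, TensorDict(x=torch.ones(5, 3), batch_size=[5]))

    # With items of batch_size [5], stack_dim=-1 resolves to
    # out[0].ndim + 1 + (-1) = 1, so the lazy stack is built along dim 1.
    out = storage.get([0, 1, 0, 1])
    print(out.batch_size)  # torch.Size([5, 4])

    # A LazyStackStorage(stack_dim=0) would stack along dim 0 instead,
    # yielding batch_size torch.Size([4, 5]).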
