Commit 873779a

[Feature] RandomCropTensorDict transform (#908)
1 parent 705f70f commit 873779a

File tree: 9 files changed, +253 -16 lines changed

docs/source/reference/data.rst
Lines changed: 16 additions & 0 deletions

@@ -55,6 +55,22 @@ The following mean sampling latency improvements over using ListStorage were found
 | :class:`LazyMemmapStorage`    | 3.44x     |
 +-------------------------------+-----------+

+Storing trajectories
+~~~~~~~~~~~~~~~~~~~~
+
+It is not too difficult to store trajectories in the replay buffer.
+One element to pay attention to is that the size of the replay buffer is always
+the size of the leading dimension of the storage: in other words, creating a
+replay buffer with a storage of size 1M when storing multidimensional data
+does not mean storing 1M frames but 1M trajectories.
+
+When sampling trajectories, it may be desirable to sample sub-trajectories
+to diversify learning or make the sampling more efficient.
+To do this, we provide a custom :class:`torchrl.envs.Transform` class named
+:class:`torchrl.envs.RandomCropTensorDict`. Here is an example of how this class
+can be used:
+
+  >>>

 TensorSpec
 ----------
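The doctest above is left as a bare ">>>" placeholder in this commit. A minimal usage sketch, inferred from the tests added to test/test_transforms.py further down rather than from the documentation itself (the argument order, crop length first and then the batch dimension to crop along, and the key names are assumptions):

import torch
from tensordict import TensorDict
from torchrl.envs import RandomCropTensorDict

# a batch of 3 trajectories, each 20 steps long (assumed layout)
data = TensorDict(
    {"obs": torch.randn(3, 20, 5), "action": torch.randn(3, 20, 2)},
    batch_size=[3, 20],
)
# crop random sub-trajectories of length 11 along the last batch dimension,
# mirroring RandomCropTensorDict(11, -1) in test_crop_dim1
crop = RandomCropTensorDict(11, -1)
cropped = crop(data)
print(cropped.shape)  # expected: torch.Size([3, 11])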

docs/source/reference/envs.rst
Lines changed: 1 addition & 0 deletions

@@ -270,6 +270,7 @@ to be able to create this other composition:
     ObservationTransform
     PinMemoryTransform
     R3MTransform
+    RandomCropTensorDict
     Resize
     RewardClipping
     RewardScaling

test/test_rb.py
Lines changed: 14 additions & 11 deletions

@@ -382,7 +382,8 @@ def test_prototype_prb(priority_key, contiguous, device):


 @pytest.mark.parametrize("stack", [False, True])
-def test_replay_buffer_trajectories(stack):
+@pytest.mark.parametrize("reduction", ["min", "max", "median", "mean"])
+def test_replay_buffer_trajectories(stack, reduction):
     traj_td = TensorDict(
         {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)},
         batch_size=[3, 4],
@@ -395,22 +396,23 @@ def test_replay_buffer_trajectories(stack):
             5,
             alpha=0.7,
             beta=0.9,
+            reduction=reduction,
         ),
         priority_key="td_error",
     )
     rb.extend(traj_td)
     sampled_td = rb.sample(3)
-    sampled_td.set("td_error", torch.rand(3))
+    sampled_td.set("td_error", torch.rand(sampled_td.shape))
     rb.update_tensordict_priority(sampled_td)
     sampled_td = rb.sample(3, include_info=True)
     assert (sampled_td.get("_weight") > 0).all()
-    assert sampled_td.batch_size == torch.Size([3])
+    assert sampled_td.batch_size == torch.Size([3, 4])

-    # set back the trajectory length
-    sampled_td_filtered = sampled_td.to_tensordict().exclude(
-        "_weight", "index", "td_error"
-    )
-    sampled_td_filtered.batch_size = [3, 4]
+    # # set back the trajectory length
+    # sampled_td_filtered = sampled_td.to_tensordict().exclude(
+    #     "_weight", "index", "td_error"
+    # )
+    # sampled_td_filtered.batch_size = [3, 4]


 @pytest.mark.parametrize(
@@ -660,7 +662,8 @@ def test_prb(priority_key, contiguous, device):


 @pytest.mark.parametrize("stack", [False, True])
-def test_rb_trajectories(stack):
+@pytest.mark.parametrize("reduction", ["min", "max", "mean", "median"])
+def test_rb_trajectories(stack, reduction):
     traj_td = TensorDict(
         {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)},
         batch_size=[3, 4],
@@ -676,11 +679,11 @@ def test_rb_trajectories(stack):
     )
     rb.extend(traj_td)
     sampled_td = rb.sample(3)
-    sampled_td.set("td_error", torch.rand(3))
+    sampled_td.set("td_error", torch.rand(3, 4))
     rb.update_tensordict_priority(sampled_td)
     sampled_td = rb.sample(3, include_info=True)
     assert (sampled_td.get("_weight") > 0).all()
-    assert sampled_td.batch_size == torch.Size([3])
+    assert sampled_td.batch_size == torch.Size([3, 4])

     # set back the trajectory length
     sampled_td_filtered = sampled_td.to_tensordict().exclude(

test/test_transforms.py
Lines changed: 62 additions & 0 deletions

@@ -56,6 +56,7 @@
     ParallelEnv,
     PinMemoryTransform,
     R3MTransform,
+    RandomCropTensorDict,
     Resize,
     RewardClipping,
     RewardScaling,
@@ -6168,6 +6169,67 @@ def test_clone_parent_compose(transform):
     assert env.transform[1].parent.base_env is base_env1


+class TestCroSeq:
+    def test_crop_dim1(self):
+        tensordict = TensorDict(
+            {
+                "a": torch.arange(20).view(1, 1, 1, 20).expand(3, 4, 2, 20),
+                "b": TensorDict(
+                    {"c": torch.arange(20).view(1, 1, 1, 20, 1).expand(3, 4, 2, 20, 1)},
+                    [3, 4, 2, 20, 1],
+                ),
+            },
+            [3, 4, 2, 20],
+        )
+        t = RandomCropTensorDict(11, -1)
+        tensordict_crop = t(tensordict)
+        assert tensordict_crop.shape == torch.Size([3, 4, 2, 11])
+        assert tensordict_crop["b"].shape == torch.Size([3, 4, 2, 11, 1])
+        assert (
+            tensordict_crop["a"][:, :, :, :-1] + 1 == tensordict_crop["a"][:, :, :, 1:]
+        ).all()
+
+    def test_crop_dim2(self):
+        tensordict = TensorDict(
+            {"a": torch.arange(20).view(1, 1, 20, 1).expand(3, 4, 20, 2)},
+            [3, 4, 20, 2],
+        )
+        t = RandomCropTensorDict(11, -2)
+        tensordict_crop = t(tensordict)
+        assert tensordict_crop.shape == torch.Size([3, 4, 11, 2])
+        assert (
+            tensordict_crop["a"][:, :, :-1] + 1 == tensordict_crop["a"][:, :, 1:]
+        ).all()
+
+    def test_crop_error(self):
+        tensordict = TensorDict(
+            {"a": torch.arange(20).view(1, 1, 20, 1).expand(3, 4, 20, 2)},
+            [3, 4, 20, 2],
+        )
+        t = RandomCropTensorDict(21, -2)
+        with pytest.raises(RuntimeError, match="Cannot sample trajectories of length"):
+            _ = t(tensordict)
+
+    @pytest.mark.parametrize("mask_key", ("mask", ("collector", "mask")))
+    def test_crop_mask(self, mask_key):
+        a = torch.arange(20).view(1, 1, 20, 1).expand(3, 4, 20, 2).clone()
+        mask = a < 21
+        mask[0] = a[0] < 15
+        mask[1] = a[1] < 16
+        mask[2] = a[2] < 14
+        tensordict = TensorDict(
+            {"a": a, mask_key: mask},
+            [3, 4, 20, 2],
+        )
+        t = RandomCropTensorDict(15, -2, mask_key=mask_key)
+        with pytest.raises(RuntimeError, match="Cannot sample trajectories of length"):
+            _ = t(tensordict)
+        t = RandomCropTensorDict(13, -2, mask_key=mask_key)
+        tensordict_crop = t(tensordict)
+        assert tensordict_crop.shape == torch.Size([3, 4, 13, 2])
+        assert tensordict_crop[mask_key].all()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
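A short note on the numbers in test_crop_mask, since they are easy to misread: the mask marks which steps of each trajectory are valid, so the three rows above allow sub-trajectories of at most 15, 16 and 14 steps. A crop can only succeed if it fits the shortest row, which is why length 15 is expected to raise while 13 passes. A hedged sketch of that arithmetic (a reading of the test, not a statement about the transform's internals):

# valid steps per trajectory row in test_crop_mask, after the mask assignments
valid_lengths = [15, 16, 14]
max_crop = min(valid_lengths)  # 14: the longest crop that fits every row
assert 15 > max_crop   # RandomCropTensorDict(15, ...) is expected to raise
assert 13 <= max_crop  # RandomCropTensorDict(13, ...) is expected to succeed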

torchrl/data/replay_buffers/replay_buffers.py
Lines changed: 54 additions & 3 deletions

@@ -10,6 +10,7 @@

 import torch
 from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
+from tensordict.utils import expand_right

 from torchrl.data.utils import DEVICE_TYPING

@@ -351,7 +352,11 @@ def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
             tensordict = tensordict.clone(recurse=False)
             tensordict.batch_size = []
         try:
-            priority = tensordict.get(self.priority_key).item()
+            priority = tensordict.get(self.priority_key)
+            if priority.numel() > 1:
+                priority = _reduce(priority, self._sampler.reduction)
+            else:
+                priority = priority.item()
         except ValueError:
             raise ValueError(
                 f"Found a priority key of size"
@@ -378,7 +383,16 @@ def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
             tensordicts = tensordicts.clone(recurse=False)
         else:
             tensordicts = tensordicts.contiguous()
+        # we keep track of the batch size to reinstantiate it when sampling
+        if "_batch_size" in tensordicts.keys():
+            raise KeyError(
+                "conflicting key '_batch_size'. Consider removing from data."
+            )
+        shape = torch.tensor(tensordicts.batch_size[1:]).expand(
+            tensordicts.batch_size[0], tensordicts.batch_dims - 1
+        )
         tensordicts.batch_size = tensordicts.batch_size[:1]
+        tensordicts.set("_batch_size", shape)
         tensordicts.set(
             "index",
             torch.zeros(
@@ -406,7 +420,13 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
             dtype=torch.float,
             device=data.device,
         )
-        self.update_priority(data.get("index"), priority)
+        # if the index shape does not match the priority shape, we have expanded it.
+        # we just take the first value
+        index = data.get("index")
+        while index.shape != priority.shape:
+            # reduce index
+            index = index[..., 0]
+        self.update_priority(index, priority)

     def sample(
         self, batch_size: int, include_info: bool = False, return_info: bool = False
@@ -429,6 +449,18 @@ def sample(
         if include_info:
             for k, v in info.items():
                 data.set(k, torch.tensor(v, device=data.device), inplace=True)
+        if "_batch_size" in data.keys():
+            # we need to reset the batch-size
+            shape = data.pop("_batch_size")
+            shape = shape[0]
+            shape = torch.Size([data.shape[0], *shape])
+            # we may need to update some values in the data
+            for key, value in data.items():
+                if value.ndim >= len(shape):
+                    continue
+                value = expand_right(value, shape)
+                data.set(key, value)
+            data.batch_size = shape
         if return_info:
             return data, info
         return data
@@ -462,6 +494,9 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
             using multithreading.
         transform (Transform, optional): Transform to be executed when sample() is called.
             To chain transforms use the :obj:`Compose` class.
+        reduction (str, optional): the reduction method for multidimensional
+            tensordicts (i.e. stored trajectories). Can be one of "max", "min",
+            "median" or "mean".
     """

     def __init__(
@@ -475,10 +510,13 @@ def __init__(
         pin_memory: bool = False,
         prefetch: Optional[int] = None,
         transform: Optional["Transform"] = None,  # noqa-F821
+        reduction: Optional[str] = "max",
     ) -> None:
         if storage is None:
             storage = ListStorage(max_size=1_000)
-        sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps)
+        sampler = PrioritizedSampler(
+            storage.max_size, alpha, beta, eps, reduction=reduction
+        )
         super(TensorDictPrioritizedReplayBuffer, self).__init__(
             priority_key=priority_key,
             storage=storage,
@@ -539,3 +577,16 @@ def __call__(self, list_of_tds):
         else:
             torch.stack(list_of_tds, 0, out=self.out)
         return self.out
+
+
+def _reduce(tensor: torch.Tensor, reduction: str):
+    """Reduces a tensor given the reduction method."""
+    if reduction == "max":
+        return tensor.max().item()
+    elif reduction == "min":
+        return tensor.min().item()
+    elif reduction == "mean":
+        return tensor.mean().item()
+    elif reduction == "median":
+        return tensor.median().item()
+    raise NotImplementedError(f"Unknown reduction method {reduction}")
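To make the new reduction argument concrete, here is a hedged sketch modeled on the test_rb_trajectories changes above. The calls and arguments visible in this diff (alpha, beta, priority_key, reduction, extend, sample, update_tensordict_priority, the ListStorage fallback) are taken from the commit; everything else, including keyword-only construction, is an assumption:

import torch
from tensordict import TensorDict
from torchrl.data import TensorDictPrioritizedReplayBuffer

# per-trajectory priority = mean of the per-step "td_error" values
rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.9,
    priority_key="td_error",
    reduction="mean",
)
# three trajectories of four steps each: the buffer holds 3 items, not 12 frames
traj = TensorDict(
    {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)},
    batch_size=[3, 4],
)
rb.extend(traj)
sampled = rb.sample(3)  # the stored batch size [3, 4] should be restored here
sampled.set("td_error", torch.rand(sampled.shape))
rb.update_tensordict_priority(sampled)  # multi-element priorities reduced via "mean"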

torchrl/data/replay_buffers/samplers.py
Lines changed: 7 additions & 2 deletions

@@ -132,8 +132,11 @@ class PrioritizedSampler(Sampler):
         alpha (float): exponent α determines how much prioritization is used,
             with α = 0 corresponding to the uniform case.
         beta (float): importance sampling negative exponent.
-        eps (float): delta added to the priorities to ensure that the buffer
-            does not contain null priorities.
+        eps (float, optional): delta added to the priorities to ensure that the buffer
+            does not contain null priorities. Defaults to 1e-8.
+        reduction (str, optional): the reduction method for multidimensional
+            tensordicts (i.e. stored trajectories). Can be one of "max", "min",
+            "median" or "mean".

     """

@@ -144,6 +147,7 @@ def __init__(
         beta: float,
         eps: float = 1e-8,
         dtype: torch.dtype = torch.float,
+        reduction: str = "max",
     ) -> None:
         if alpha <= 0:
             raise ValueError(
@@ -156,6 +160,7 @@ def __init__(
         self._alpha = alpha
         self._beta = beta
         self._eps = eps
+        self.reduction = reduction
         if dtype in (torch.float, torch.FloatType, torch.float32):
             self._sum_tree = SumSegmentTreeFp32(self._max_capacity)
             self._min_tree = MinSegmentTreeFp32(self._max_capacity)

torchrl/envs/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@
     ObservationTransform,
     PinMemoryTransform,
     R3MTransform,
+    RandomCropTensorDict,
     Resize,
     RewardClipping,
     RewardScaling,

torchrl/envs/transforms/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@
     ObservationNorm,
     ObservationTransform,
     PinMemoryTransform,
+    RandomCropTensorDict,
     Resize,
     RewardClipping,
     RewardScaling,
