Commit 9ee1ae7

Author: Vincent Moens

[Feature] CatFrames.make_rb_transform_and_sampler

ghstack-source-id: 7ecf952
Pull Request resolved: #2643

1 parent 17983d4

File tree: 4 files changed (+209 −3 lines)
examples/replay-buffers/catframes-in-buffer.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import (
    CatFrames,
    Compose,
    DMControlEnv,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
    UnsqueezeTransform,
)

# Number of frames to stack together
frame_stack = 4
# Dimension along which the stack should occur
stack_dim = -4
# Max size of the buffer
max_size = 100_000
# Batch size of the replay buffer
training_batch_size = 32

seed = 123


def main():
    catframes = CatFrames(
        N=frame_stack,
        dim=stack_dim,
        in_keys=["pixels_trsf"],
        out_keys=["pixels_trsf"],
    )
    env = TransformedEnv(
        DMControlEnv(
            env_name="cartpole",
            task_name="balance",
            device="cpu",
            from_pixels=True,
            pixels_only=True,
        ),
        Compose(
            ToTensorImage(
                from_int=True,
                dtype=torch.float32,
                in_keys=["pixels"],
                out_keys=["pixels_trsf"],
                shape_tolerant=True,
            ),
            UnsqueezeTransform(
                dim=stack_dim, in_keys=["pixels_trsf"], out_keys=["pixels_trsf"]
            ),
            catframes,
            StepCounter(),
        ),
    )
    env.set_seed(seed)

    transform, sampler = catframes.make_rb_transform_and_sampler(
        batch_size=training_batch_size,
        traj_key=("collector", "traj_ids"),
        strict_length=True,
    )

    rb_transforms = Compose(
        ToTensorImage(
            from_int=True,
            dtype=torch.float32,
            in_keys=["pixels", ("next", "pixels")],
            out_keys=["pixels_trsf", ("next", "pixels_trsf")],
            shape_tolerant=True,
        ),  # C W' H' -> C W' H' (unchanged due to shape_tolerant)
        UnsqueezeTransform(
            dim=stack_dim,
            in_keys=["pixels_trsf", ("next", "pixels_trsf")],
            out_keys=["pixels_trsf", ("next", "pixels_trsf")],
        ),  # 1 C W' H'
        transform,
    )

    rb = ReplayBuffer(
        storage=LazyTensorStorage(max_size=max_size, device="cpu"),
        sampler=sampler,
        batch_size=training_batch_size,
        transform=rb_transforms,
    )

    data = env.rollout(1000, break_when_any_done=False)
    rb.extend(data)

    training_batch = rb.sample()
    print(training_batch)


if __name__ == "__main__":
    main()

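The point of the sampler/transform pair above is memory: storing ready-stacked frames multiplies the buffer footprint by frame_stack. A back-of-the-envelope sketch of the saving, assuming 3×84×84 float32 frames (the render size is an illustrative assumption, not fixed by the example):

# Rough buffer-memory comparison; the 3 * 84 * 84 frame shape is assumed.
max_size = 100_000  # buffer capacity, as in the example above
frame_stack = 4     # number of stacked frames, as in the example above
bytes_per_frame = 3 * 84 * 84 * 4  # C * H * W * sizeof(float32)
naive = max_size * frame_stack * bytes_per_frame  # storing ready-made stacks
lean = max_size * bytes_per_frame  # storing single frames, stacking at sampling
print(f"naive: {naive / 1e9:.2f} GB vs lean: {lean / 1e9:.2f} GB")  # ~4x saving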
test/test_transforms.py

Lines changed: 23 additions & 0 deletions
@@ -933,6 +933,29 @@ def test_transform_rb(self, dim, N, padding, rbclass):
         assert (tdsample["out_" + key1] == td["out_" + key1]).all()
         assert (tdsample["next", "out_" + key1] == td["next", "out_" + key1]).all()
 
+    def test_transform_rb_maker(self):
+        env = CountingEnv(max_steps=10)
+        catframes = CatFrames(
+            in_keys=["observation"], out_keys=["observation_stack"], dim=-1, N=4
+        )
+        env.append_transform(catframes)
+        policy = lambda td: td.update(env.full_action_spec.zeros() + 1)
+        rollout = env.rollout(150, policy, break_when_any_done=False)
+        transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32)
+        rb = ReplayBuffer(
+            sampler=sampler, storage=LazyTensorStorage(150), transform=transform
+        )
+        rb.extend(rollout)
+        sample = rb.sample(32)
+        assert "observation_stack" not in rb._storage._storage
+        assert sample.shape == (32,)
+        assert sample["observation_stack"].shape == (32, 4)
+        assert sample["next", "observation_stack"].shape == (32, 4)
+        assert (
+            sample["observation_stack"]
+            == sample["observation_stack"][:, :1] + torch.arange(4)
+        ).all()
+
     @pytest.mark.parametrize("dim", [-1])
     @pytest.mark.parametrize("N", [3, 4])
     @pytest.mark.parametrize("padding", ["same", "constant"])

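The closing assertion in the new test encodes the stacking invariant: CountingEnv increments its observation by one per step, so any length-4 window sampled from a single trajectory reads [x, x+1, x+2, x+3]. A minimal standalone check of the same arithmetic:

import torch

# Each stacked window from a counting trajectory equals its first
# element broadcast against arange(4), mirroring the test's assertion.
window = torch.tensor([5.0, 6.0, 7.0, 8.0])
assert (window == window[:1] + torch.arange(4)).all()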
torchrl/data/replay_buffers/samplers.py

Lines changed: 7 additions & 0 deletions
@@ -968,6 +968,9 @@ class SliceSampler(Sampler):
 
     """
 
+    # Used whenever we need to sample N times as many transitions as requested, keeping only a 1/N fraction
+    _batch_size_multiplier: int | None = 1
+
     def __init__(
         self,
         *,
@@ -1295,6 +1298,8 @@ def _adjusted_batch_size(self, batch_size):
         return seq_length, num_slices
 
     def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
+        if self._batch_size_multiplier is not None:
+            batch_size = batch_size * self._batch_size_multiplier
         # pick up as many trajs as we need
         start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
         # we have to make sure that the number of dims of the storage
@@ -1747,6 +1752,8 @@ def _storage_len(self, storage):
     def sample(
         self, storage: Storage, batch_size: int
     ) -> Tuple[Tuple[torch.Tensor, ...], dict]:
+        if self._batch_size_multiplier is not None:
+            batch_size = batch_size * self._batch_size_multiplier
         start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
         # we have to make sure that the number of dims of the storage
         # is the same as the stop/start signals since we will

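How _batch_size_multiplier is consumed downstream: with the multiplier set to N, the sampler draws batch_size * N transitions, i.e. batch_size contiguous slices of length N. The replay-buffer transform built by make_rb_transform_and_sampler (see the transforms.py diff below) then reshapes the flat batch and keeps one fully stacked sample per slice. A minimal sketch of that index arithmetic with plain tensors:

import torch

N = 4            # slice length, equal to the number of stacked frames
batch_size = 32  # the batch size the user actually asked for

# The sampler returns batch_size * N flat transition indices...
flat = torch.arange(batch_size * N)
# ...which the transform reshapes into one row per slice...
slices = flat.reshape(-1, N)  # shape: [batch_size, N]
# ...and, once CatFrames has stacked each row, only the last
# (fully stacked) element of every slice is kept:
kept = slices[:, -1]
assert kept.shape == (batch_size,)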
torchrl/envs/transforms/transforms.py

Lines changed: 80 additions & 3 deletions
@@ -2825,9 +2825,9 @@ def _reset(
 class CatFrames(ObservationTransform):
     """Concatenates successive observation frames into a single tensor.
 
-    This can, for instance, account for movement/velocity of the observed
-    feature. Proposed in "Playing Atari with Deep Reinforcement Learning" (
-    https://arxiv.org/pdf/1312.5602.pdf).
+    This transform is useful for creating a sense of movement or velocity in the observed features.
+    It can also be used with models that require access to past observations, such as transformers.
+    It was first proposed in "Playing Atari with Deep Reinforcement Learning" (https://arxiv.org/pdf/1312.5602.pdf).
 
     When used within a transformed environment,
     :class:`CatFrames` is a stateful class, and it can be reset to its native state by
@@ -2915,6 +2915,14 @@ class CatFrames(ObservationTransform):
         such as those found in MARL settings, are currently not supported.
         If this feature is needed, please raise an issue on TorchRL repo.
 
+    .. note:: Storing stacks of frames in the replay buffer can significantly increase memory consumption (by N times).
+        To mitigate this, you can store trajectories directly in the replay buffer and apply :class:`CatFrames` at sampling time.
+        This approach involves sampling slices of the stored trajectories and then applying the frame stacking transform.
+        For convenience, :class:`CatFrames` provides a :meth:`~.make_rb_transform_and_sampler` method that creates:
+
+        - A modified version of the transform suitable for use in replay buffers
+        - A corresponding :class:`SliceSampler` to use with the buffer
+
     """
 
     inplace = False
@@ -2964,6 +2972,75 @@ def __init__(
         self.reset_key = reset_key
         self.done_key = done_key
 
+    def make_rb_transform_and_sampler(
+        self, batch_size: int, **sampler_kwargs
+    ) -> Tuple[Transform, "torchrl.data.replay_buffers.SliceSampler"]:  # noqa: F821
+        """Creates a transform and sampler to be used with a replay buffer when storing frame-stacked data.
+
+        This method helps reduce redundancy in stored data by avoiding the need to
+        store the entire stack of frames in the buffer. Instead, it creates a
+        transform that stacks frames on-the-fly during sampling, and a sampler that
+        ensures the correct sequence length is maintained.
+
+        Args:
+            batch_size (int): The batch size to use for the sampler.
+            **sampler_kwargs: Additional keyword arguments to pass to the
+                :class:`~torchrl.data.replay_buffers.SliceSampler` constructor.
+
+        Returns:
+            A tuple containing:
+            - transform (Transform): A transform that stacks frames on-the-fly during sampling.
+            - sampler (SliceSampler): A sampler that ensures the correct sequence length is maintained.
+
+        Example:
+            >>> env = TransformedEnv(...)
+            >>> catframes = CatFrames(N=4, ...)
+            >>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32)
+            >>> rb = ReplayBuffer(..., sampler=sampler, transform=transform)
+
+        .. note:: When working with images, it's recommended to use distinct ``in_keys`` and ``out_keys`` in the preceding
+            :class:`~torchrl.envs.ToTensorImage` transform. This ensures that the tensors stored in the buffer are separate
+            from their processed counterparts, which we don't want to store.
+            For non-image data, consider inserting a :class:`~torchrl.envs.RenameTransform` before :class:`CatFrames` to create
+            a copy of the data that will be stored in the buffer.
+
+        .. note:: When adding the transform to the replay buffer, one should pay attention to also pass the transforms
+            that precede CatFrames, such as :class:`~torchrl.envs.ToTensorImage` or :class:`~torchrl.envs.UnsqueezeTransform`,
+            in such a way that the :class:`~torchrl.envs.CatFrames` transform sees data formatted as it was during data
+            collection.
+
+        .. note:: For a more complete example, refer to torchrl's github repo `examples` folder:
+            https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py
+
+        """
+        from torchrl.data.replay_buffers import SliceSampler
+
+        in_keys = self.in_keys
+        in_keys = in_keys + [unravel_key(("next", key)) for key in in_keys]
+        out_keys = self.out_keys
+        out_keys = out_keys + [unravel_key(("next", key)) for key in out_keys]
+        catframes = type(self)(
+            N=self.N,
+            in_keys=in_keys,
+            out_keys=out_keys,
+            dim=self.dim,
+            padding=self.padding,
+            padding_value=self.padding_value,
+            as_inverse=False,
+            reset_key=self.reset_key,
+            done_key=self.done_key,
+        )
+        sampler = SliceSampler(slice_len=self.N, **sampler_kwargs)
+        sampler._batch_size_multiplier = self.N
+        transform = Compose(
+            lambda td: td.reshape(-1, self.N),
+            catframes,
+            lambda td: td[:, -1],
+            # We only store "pixels" in the replay buffer to save memory
+            ExcludeTransform(*out_keys, inverse=True),
+        )
+        return transform, sampler
+
     @property
     def done_key(self):
         done_key = self.__dict__.get("_done_key", None)

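For the non-image case mentioned in the docstring note above, a sketch of the suggested RenameTransform trick: copy the raw key before CatFrames so the buffer keeps the un-stacked entry. Key names are illustrative, and the create_copy flag should be checked against your torchrl version:

from torchrl.envs import CatFrames, Compose, RenameTransform

# Keep "observation" intact for storage and stack a copy under a new key.
# Key names are illustrative, not mandated by the API.
copy_then_stack = Compose(
    RenameTransform(
        in_keys=["observation"],
        out_keys=["observation_stack"],
        create_copy=True,  # retain the original "observation" entry as well
    ),
    CatFrames(
        N=4, dim=-1, in_keys=["observation_stack"], out_keys=["observation_stack"]
    ),
)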