@@ -2825,9 +2825,9 @@ def _reset(
2825
2825
class CatFrames (ObservationTransform ):
2826
2826
"""Concatenates successive observation frames into a single tensor.
2827
2827
2828
- This can, for instance, account for movement/ velocity of the observed
2829
- feature. Proposed in "Playing Atari with Deep Reinforcement Learning" (
2830
- https://arxiv.org/pdf/1312.5602.pdf).
2828
+ This transform is useful for creating a sense of movement or velocity in the observed features.
2829
+ It can also be used with models that require access to past observations such as transformers and the like.
2830
+ It was first proposed in "Playing Atari with Deep Reinforcement Learning" ( https://arxiv.org/pdf/1312.5602.pdf).
2831
2831
2832
2832
When used within a transformed environment,
2833
2833
:class:`CatFrames` is a stateful class, and it can be reset to its native state by
@@ -2915,6 +2915,14 @@ class CatFrames(ObservationTransform):
2915
2915
such as those found in MARL settings, are currently not supported.
2916
2916
If this feature is needed, please raise an issue on TorchRL repo.
2917
2917
2918
+ .. note:: Storing stacks of frames in the replay buffer can significantly increase memory consumption (by N times).
2919
+ To mitigate this, you can store trajectories directly in the replay buffer and apply :class:`CatFrames` at sampling time.
2920
+ This approach involves sampling slices of the stored trajectories and then applying the frame stacking transform.
2921
+ For convenience, :class:`CatFrames` provides a :meth:`~.make_rb_transform_and_sampler` method that creates:
2922
+
2923
+ - A modified version of the transform suitable for use in replay buffers
2924
+ - A corresponding :class:`SliceSampler` to use with the buffer
2925
+
2918
2926
"""
2919
2927
2920
2928
inplace = False
@@ -2964,6 +2972,75 @@ def __init__(
2964
2972
self .reset_key = reset_key
2965
2973
self .done_key = done_key
2966
2974
2975
+ def make_rb_transform_and_sampler (
2976
+ self , batch_size : int , ** sampler_kwargs
2977
+ ) -> Tuple [Transform , "torchrl.data.replay_buffers.SliceSampler" ]: # noqa: F821
2978
+ """Creates a transform and sampler to be used with a replay buffer when storing frame-stacked data.
2979
+
2980
+ This method helps reduce redundancy in stored data by avoiding the need to
2981
+ store the entire stack of frames in the buffer. Instead, it creates a
2982
+ transform that stacks frames on-the-fly during sampling, and a sampler that
2983
+ ensures the correct sequence length is maintained.
2984
+
2985
+ Args:
2986
+ batch_size (int): The batch size to use for the sampler.
2987
+ **sampler_kwargs: Additional keyword arguments to pass to the
2988
+ :class:`~torchrl.data.replay_buffers.SliceSampler` constructor.
2989
+
2990
+ Returns:
2991
+ A tuple containing:
2992
+ - transform (Transform): A transform that stacks frames on-the-fly during sampling.
2993
+ - sampler (SliceSampler): A sampler that ensures the correct sequence length is maintained.
2994
+
2995
+ Example:
2996
+ >>> env = TransformedEnv(...)
2997
+ >>> catframes = CatFrames(N=4, ...)
2998
+ >>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32)
2999
+ >>> rb = ReplayBuffer(..., sampler=sampler, transform=transform)
3000
+
3001
+ .. note:: When working with images, it's recommended to use distinct ``in_keys`` and ``out_keys`` in the preceding
3002
+ :class:`~torchrl.envs.ToTensorImage` transform. This ensures that the tensors stored in the buffer are separate
3003
+ from their processed counterparts, which we don't want to store.
3004
+ For non-image data, consider inserting a :class:`~torchrl.envs.RenameTransform` before :class:`CatFrames` to create
3005
+ a copy of the data that will be stored in the buffer.
3006
+
3007
+ .. note:: When adding the transform to the replay buffer, one should pay attention to also pass the transforms
3008
+ that precede CatFrames, such as :class:`~torchrl.envs.ToTensorImage` or :class:`~torchrl.envs.UnsqueezeTransform`
3009
+ in such a way that the :class:`~torchrl.envs.CatFrames` transforms sees data formatted as it was during data
3010
+ collection.
3011
+
3012
+ .. note:: For a more complete example, refer to torchrl's github repo `examples` folder:
3013
+ https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py
3014
+
3015
+ """
3016
+ from torchrl .data .replay_buffers import SliceSampler
3017
+
3018
+ in_keys = self .in_keys
3019
+ in_keys = in_keys + [unravel_key (("next" , key )) for key in in_keys ]
3020
+ out_keys = self .out_keys
3021
+ out_keys = out_keys + [unravel_key (("next" , key )) for key in out_keys ]
3022
+ catframes = type (self )(
3023
+ N = self .N ,
3024
+ in_keys = in_keys ,
3025
+ out_keys = out_keys ,
3026
+ dim = self .dim ,
3027
+ padding = self .padding ,
3028
+ padding_value = self .padding_value ,
3029
+ as_inverse = False ,
3030
+ reset_key = self .reset_key ,
3031
+ done_key = self .done_key ,
3032
+ )
3033
+ sampler = SliceSampler (slice_len = self .N , ** sampler_kwargs )
3034
+ sampler ._batch_size_multiplier = self .N
3035
+ transform = Compose (
3036
+ lambda td : td .reshape (- 1 , self .N ),
3037
+ catframes ,
3038
+ lambda td : td [:, - 1 ],
3039
+ # We only store "pixels" to the replay buffer to save memory
3040
+ ExcludeTransform (* out_keys , inverse = True ),
3041
+ )
3042
+ return transform , sampler
3043
+
2967
3044
@property
2968
3045
def done_key (self ):
2969
3046
done_key = self .__dict__ .get ("_done_key" , None )
0 commit comments