from torchrl._utils import _replace_last, logger
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
-from torchrl.data.replay_buffers.utils import _is_int, unravel_index
+from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index

try:
    from torchrl._torchrl import (
@@ -726,6 +726,10 @@ class SliceSampler(Sampler):
    This class samples sub-trajectories with replacement. For a version without
    replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.

+    .. note:: `SliceSampler` can be slow to retrieve the trajectory indices. To accelerate
+        its execution, prefer using `end_key` over `traj_key`, and consider the following
+        keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
+
    Keyword Args:
        num_slices (int): the number of slices to be sampled. The batch-size
            must be greater or equal to the ``num_slices`` argument. Exclusive
@@ -796,6 +800,10 @@ class SliceSampler(Sampler):
            that at least `slice_len - i` samples will be gathered for each sampled trajectory.
            Using tuples allows a fine grained control over the span on the left (beginning
            of the stored trajectory) and on the right (end of the stored trajectory).
+        use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
+            will be used to retrieve the indices of the trajectory starts. This can significantly
+            accelerate the sampling when the buffer content is large.
+            Defaults to ``False``.

    .. note:: To recover the trajectory splits in the storage,
        :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
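As a usage sketch (not part of this diff), the new flag slots into the existing sampler API; the buffer size, the `done` layout and the batch size below are invented for illustration:

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler

# Hypothetical rollout: 4 trajectories of 250 steps flattened into 1000
# transitions, with an end-of-trajectory flag on the last step of each.
done = torch.zeros(1000, 1, dtype=torch.bool)
done[249::250] = True
data = TensorDict(
    {"obs": torch.randn(1000, 4), ("next", "done"): done},
    batch_size=[1000],
)

# use_gpu=True asks the sampler to search for trajectory boundaries on an
# accelerator; a torch.device can be passed instead to pin a specific one.
sampler = SliceSampler(
    num_slices=8,
    end_key=("next", "done"),
    cache_values=True,
    use_gpu=torch.cuda.is_available(),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), sampler=sampler)
rb.extend(data)
batch = rb.sample(64)  # 8 slices of 64 // 8 = 8 contiguous steps each
```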
@@ -985,6 +993,7 @@ def __init__(
        strict_length: bool = True,
        compile: bool | dict = False,
        span: bool | int | Tuple[bool | int, bool | int] = False,
+        use_gpu: torch.device | bool = False,
    ):
        self.num_slices = num_slices
        self.slice_len = slice_len
@@ -995,6 +1004,14 @@ def __init__(
        self._fetch_traj = True
        self.strict_length = strict_length
        self._cache = {}
+        self.use_gpu = bool(use_gpu)
+        self._gpu_device = (
+            None
+            if not self.use_gpu
+            else torch.device(use_gpu)
+            if not isinstance(use_gpu, bool)
+            else _auto_device()
+        )

        if isinstance(span, (bool, int)):
            span = (span, span)
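The chained ternary added above resolves to three cases. A standalone sketch of the same logic, with a stand-in for `_auto_device()` (whose exact behavior is not shown in this diff and is assumed here to mean "pick an available accelerator"):

```python
import torch

def resolve_gpu_device(use_gpu: bool | torch.device) -> torch.device | None:
    # use_gpu=False   -> None (feature disabled)
    if not use_gpu:
        return None
    # explicit device -> normalize it through torch.device(...)
    if not isinstance(use_gpu, bool):
        return torch.device(use_gpu)
    # use_gpu=True    -> stand-in for torchrl's _auto_device()
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")

assert resolve_gpu_device(False) is None
assert resolve_gpu_device(torch.device("cpu")) == torch.device("cpu")
```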
@@ -1086,9 +1103,8 @@ def __repr__(self):
            f"strict_length={self.strict_length})"
        )

-    @classmethod
    def _find_start_stop_traj(
-        cls, *, trajectory=None, end=None, at_capacity: bool, cursor=None
+        self, *, trajectory=None, end=None, at_capacity: bool, cursor=None
    ):
        if trajectory is not None:
            # slower
@@ -1141,10 +1157,15 @@ def _find_start_stop_traj(
            raise RuntimeError(
                "Expected the end-of-trajectory signal to be at least 1-dimensional."
            )
-        return cls._end_to_start_stop(length=length, end=end)
-
-    @staticmethod
-    def _end_to_start_stop(end, length):
+        return self._end_to_start_stop(length=length, end=end)
+
+    def _end_to_start_stop(self, end, length):
+        device = None
+        if self.use_gpu:
+            gpu_device = self._gpu_device
+            if end.device != gpu_device:
+                device = end.device
+                end = end.to(self._gpu_device)
        # Using transpose ensures the start and stop are sorted the same way
        stop_idx = end.transpose(0, -1).nonzero()
        stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
@@ -1171,6 +1192,8 @@ def _end_to_start_stop(end, length):
            pass
        lengths = stop_idx[:, 0] - start_idx[:, 0] + 1
        lengths[lengths <= 0] = lengths[lengths <= 0] + length
+        if device is not None:
+            return start_idx.to(device), stop_idx.to(device), lengths.to(device)
        return start_idx, stop_idx, lengths

    def _start_to_end(self, st: torch.Tensor, length: int):
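The edit to `_end_to_start_stop` follows the usual offload pattern: remember the caller's device, run the boundary search on the accelerator, and move the (much smaller) index tensors back at the end. Distilled into a hypothetical helper:

```python
import torch

def nonzero_offloaded(end: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Run a nonzero() scan on `device`, returning results on end's device."""
    orig_device = end.device
    # The heavy scan over the (potentially huge) end-flag tensor runs on the
    # accelerator; only the sparse index tensor travels back.
    idx = end.to(device).transpose(0, -1).nonzero()
    return idx.to(orig_device)
```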
@@ -1547,6 +1570,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
    the sampler, and continuous sampling without replacement is currently not
    allowed.

+    .. note:: `SliceSamplerWithoutReplacement` can be slow to retrieve the trajectory indices. To accelerate
+        its execution, prefer using `end_key` over `traj_key`, and consider the following
+        keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
+
    Keyword Args:
        drop_last (bool, optional): if ``True``, the last incomplete sample (if any) will be dropped.
            If ``False``, this last sample will be kept.
@@ -1589,6 +1616,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
            the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
            Keyword arguments can also be passed to torch.compile with this arg.
            Defaults to ``False``.
+        use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
+            will be used to retrieve the indices of the trajectory starts. This can significantly
+            accelerate the sampling when the buffer content is large.
+            Defaults to ``False``.

    .. note:: To recover the trajectory splits in the storage,
        :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first
@@ -1693,7 +1724,6 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
        tensor([[0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 1.]])

-
    """

    def __init__(
@@ -1710,6 +1740,7 @@ def __init__(
        strict_length: bool = True,
        shuffle: bool = True,
        compile: bool | dict = False,
+        use_gpu: bool | torch.device = False,
    ):
        SliceSampler.__init__(
            self,
@@ -1723,6 +1754,7 @@ def __init__(
            ends=ends,
            trajectories=trajectories,
            compile=compile,
+            use_gpu=use_gpu,
        )
        SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle)
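To close, a hedged end-to-end sketch of the without-replacement variant with the new flag; the trajectory layout, keys and batch sizes are invented for illustration:

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement

# Hypothetical storage: 4 trajectories of 250 steps, tagged by episode id.
episode = torch.arange(4).repeat_interleave(250)
data = TensorDict(
    {"obs": torch.randn(1000, 4), "episode": episode},
    batch_size=[1000],
)

sampler = SliceSamplerWithoutReplacement(
    slice_len=10,
    traj_key="episode",
    shuffle=True,
    use_gpu=torch.cuda.is_available(),  # threaded through to SliceSampler.__init__
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), sampler=sampler)
rb.extend(data)
# 4 slices of 10 steps, each drawn from a distinct trajectory.
batch = rb.sample(40)
```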