Skip to content

Commit 84c3ec3

Browse files
author
Vincent Moens
committed
[Performance] Accelerate slice sampler on GPU
ghstack-source-id: a4dc151 Pull Request resolved: #2672
1 parent 4fd54fe commit 84c3ec3

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

torchrl/data/replay_buffers/samplers.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from torchrl._utils import _replace_last, logger
2626
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
27-
from torchrl.data.replay_buffers.utils import _is_int, unravel_index
27+
from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index
2828

2929
try:
3030
from torchrl._torchrl import (
@@ -726,6 +726,10 @@ class SliceSampler(Sampler):
726726
This class samples sub-trajectories with replacement. For a version without
727727
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
728728
729+
.. note:: `SliceSampler` can be slow to retrieve the trajectory indices. To accelerate
730+
its execution, prefer using `end_key` over `traj_key`, and consider the following
731+
keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
732+
729733
Keyword Args:
730734
num_slices (int): the number of slices to be sampled. The batch-size
731735
must be greater or equal to the ``num_slices`` argument. Exclusive
@@ -796,6 +800,10 @@ class SliceSampler(Sampler):
796800
that at least `slice_len - i` samples will be gathered for each sampled trajectory.
797801
Using tuples allows a fine grained control over the span on the left (beginning
798802
of the stored trajectory) and on the right (end of the stored trajectory).
803+
use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
804+
will be used to retrieve the indices of the trajectory starts. This can significantly
805+
accelerate the sampling when the buffer content is large.
806+
Defaults to ``False``.
799807
800808
.. note:: To recover the trajectory splits in the storage,
801809
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
@@ -985,6 +993,7 @@ def __init__(
985993
strict_length: bool = True,
986994
compile: bool | dict = False,
987995
span: bool | int | Tuple[bool | int, bool | int] = False,
996+
use_gpu: torch.device | bool = False,
988997
):
989998
self.num_slices = num_slices
990999
self.slice_len = slice_len
@@ -995,6 +1004,14 @@ def __init__(
9951004
self._fetch_traj = True
9961005
self.strict_length = strict_length
9971006
self._cache = {}
1007+
self.use_gpu = bool(use_gpu)
1008+
self._gpu_device = (
1009+
None
1010+
if not self.use_gpu
1011+
else torch.device(use_gpu)
1012+
if not isinstance(use_gpu, bool)
1013+
else _auto_device()
1014+
)
9981015

9991016
if isinstance(span, (bool, int)):
10001017
span = (span, span)
@@ -1086,9 +1103,8 @@ def __repr__(self):
10861103
f"strict_length={self.strict_length})"
10871104
)
10881105

1089-
@classmethod
10901106
def _find_start_stop_traj(
1091-
cls, *, trajectory=None, end=None, at_capacity: bool, cursor=None
1107+
self, *, trajectory=None, end=None, at_capacity: bool, cursor=None
10921108
):
10931109
if trajectory is not None:
10941110
# slower
@@ -1141,10 +1157,15 @@ def _find_start_stop_traj(
11411157
raise RuntimeError(
11421158
"Expected the end-of-trajectory signal to be at least 1-dimensional."
11431159
)
1144-
return cls._end_to_start_stop(length=length, end=end)
1145-
1146-
@staticmethod
1147-
def _end_to_start_stop(end, length):
1160+
return self._end_to_start_stop(length=length, end=end)
1161+
1162+
def _end_to_start_stop(self, end, length):
1163+
device = None
1164+
if self.use_gpu:
1165+
gpu_device = self._gpu_device
1166+
if end.device != gpu_device:
1167+
device = end.device
1168+
end = end.to(self._gpu_device)
11481169
# Using transpose ensures the start and stop are sorted the same way
11491170
stop_idx = end.transpose(0, -1).nonzero()
11501171
stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
@@ -1171,6 +1192,8 @@ def _end_to_start_stop(end, length):
11711192
pass
11721193
lengths = stop_idx[:, 0] - start_idx[:, 0] + 1
11731194
lengths[lengths <= 0] = lengths[lengths <= 0] + length
1195+
if device is not None:
1196+
return start_idx.to(device), stop_idx.to(device), lengths.to(device)
11741197
return start_idx, stop_idx, lengths
11751198

11761199
def _start_to_end(self, st: torch.Tensor, length: int):
@@ -1547,6 +1570,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
15471570
the sampler, and continuous sampling without replacement is currently not
15481571
allowed.
15491572
1573+
.. note:: `SliceSamplerWithoutReplacement` can be slow to retrieve the trajectory indices. To accelerate
1574+
its execution, prefer using `end_key` over `traj_key`, and consider the following
1575+
keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
1576+
15501577
Keyword Args:
15511578
drop_last (bool, optional): if ``True``, the last incomplete sample (if any) will be dropped.
15521579
If ``False``, this last sample will be kept.
@@ -1589,6 +1616,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
15891616
the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
15901617
Keyword arguments can also be passed to torch.compile with this arg.
15911618
Defaults to ``False``.
1619+
use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
1620+
will be used to retrieve the indices of the trajectory starts. This can significantly
1621+
accelerate the sampling when the buffer content is large.
1622+
Defaults to ``False``.
15921623
15931624
.. note:: To recover the trajectory splits in the storage,
15941625
:class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first
@@ -1693,7 +1724,6 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
16931724
tensor([[0., 0., 0., 0., 0.],
16941725
[1., 1., 1., 1., 1.]])
16951726
1696-
16971727
"""
16981728

16991729
def __init__(
@@ -1710,6 +1740,7 @@ def __init__(
17101740
strict_length: bool = True,
17111741
shuffle: bool = True,
17121742
compile: bool | dict = False,
1743+
use_gpu: bool | torch.device = False,
17131744
):
17141745
SliceSampler.__init__(
17151746
self,
@@ -1723,6 +1754,7 @@ def __init__(
17231754
ends=ends,
17241755
trajectories=trajectories,
17251756
compile=compile,
1757+
use_gpu=use_gpu,
17261758
)
17271759
SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle)
17281760

torchrl/data/replay_buffers/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,3 +1034,11 @@ def tree_iter(pytree): # noqa: F811
10341034
def tree_iter(pytree): # noqa: F811
10351035
"""A version-compatible wrapper around tree_iter."""
10361036
yield from torch.utils._pytree.tree_iter(pytree)
1037+
1038+
1039+
def _auto_device() -> torch.device:
1040+
if torch.cuda.is_available():
1041+
return torch.device("cuda:0")
1042+
elif torch.mps.is_available():
1043+
return torch.device("mps:0")
1044+
return torch.device("cpu")

0 commit comments

Comments
 (0)