Skip to content

Commit 0feef11

Browse files
author
Vincent Moens
committed
[Deprecation] Remove InPlaceSampler
ghstack-source-id: eeae1bf Pull Request resolved: #2750
1 parent 0111a87 commit 0feef11

File tree

1 file changed

+4
-25
lines changed

1 file changed

+4
-25
lines changed

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from torch import Tensor
3939
from torch.utils._pytree import tree_map
4040

41-
from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation
41+
from torchrl._utils import accept_remote_rref_udf_invocation
4242
from torchrl.data.replay_buffers.samplers import (
4343
PrioritizedSampler,
4444
RandomSampler,
@@ -1574,33 +1574,12 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
15741574

15751575

15761576
class InPlaceSampler:
1577-
"""A sampler to write tennsordicts in-place.
1578-
1579-
.. warning:: This class is deprecated and will be removed in v0.7.
1580-
1581-
To be used cautiously as this may lead to unexpected behavior (i.e. tensordicts
1582-
overwritten during execution).
1583-
1584-
"""
1577+
"""[Deprecated] A sampler to write tennsordicts in-place."""
15851578

15861579
def __init__(self, device: DEVICE_TYPING | None = None):
1587-
warnings.warn(
1588-
"InPlaceSampler has been deprecated and will be removed in v0.7.",
1589-
category=DeprecationWarning,
1580+
raise RuntimeError(
1581+
"This class has been removed without replacement. In-place sampling should be avoided."
15901582
)
1591-
self.out = None
1592-
if device is None:
1593-
device = "cpu"
1594-
self.device = _make_ordinal_device(torch.device(device))
1595-
1596-
def __call__(self, list_of_tds):
1597-
if self.out is None:
1598-
self.out = torch.stack(list_of_tds, 0).contiguous()
1599-
if self.device is not None:
1600-
self.out = self.out.to(self.device)
1601-
else:
1602-
torch.stack(list_of_tds, 0, out=self.out)
1603-
return self.out
16041583

16051584

16061585
def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:

0 commit comments

Comments
 (0)