|
38 | 38 | from torch import Tensor
|
39 | 39 | from torch.utils._pytree import tree_map
|
40 | 40 |
|
41 |
| -from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation |
| 41 | +from torchrl._utils import accept_remote_rref_udf_invocation |
42 | 42 | from torchrl.data.replay_buffers.samplers import (
|
43 | 43 | PrioritizedSampler,
|
44 | 44 | RandomSampler,
|
@@ -1574,33 +1574,12 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
|
1574 | 1574 |
|
1575 | 1575 |
|
1576 | 1576 | class InPlaceSampler:
|
1577 |
| - """A sampler to write tennsordicts in-place. |
1578 |
| -
|
1579 |
| - .. warning:: This class is deprecated and will be removed in v0.7. |
1580 |
| -
|
1581 |
| - To be used cautiously as this may lead to unexpected behavior (i.e. tensordicts |
1582 |
| - overwritten during execution). |
1583 |
| -
|
1584 |
| - """ |
| 1577 | + """[Deprecated] A sampler to write tennsordicts in-place.""" |
1585 | 1578 |
|
1586 | 1579 | def __init__(self, device: DEVICE_TYPING | None = None):
|
1587 |
| - warnings.warn( |
1588 |
| - "InPlaceSampler has been deprecated and will be removed in v0.7.", |
1589 |
| - category=DeprecationWarning, |
| 1580 | + raise RuntimeError( |
| 1581 | + "This class has been removed without replacement. In-place sampling should be avoided." |
1590 | 1582 | )
|
1591 |
| - self.out = None |
1592 |
| - if device is None: |
1593 |
| - device = "cpu" |
1594 |
| - self.device = _make_ordinal_device(torch.device(device)) |
1595 |
| - |
1596 |
| - def __call__(self, list_of_tds): |
1597 |
| - if self.out is None: |
1598 |
| - self.out = torch.stack(list_of_tds, 0).contiguous() |
1599 |
| - if self.device is not None: |
1600 |
| - self.out = self.out.to(self.device) |
1601 |
| - else: |
1602 |
| - torch.stack(list_of_tds, 0, out=self.out) |
1603 |
| - return self.out |
1604 | 1583 |
|
1605 | 1584 |
|
1606 | 1585 | def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:
|
|
0 commit comments