Skip to content

Commit 70c650e

Browse files
author
Vincent Moens
authored
[Deprecation] Deprecate ambiguous device for memmap replay buffer (#1624)
1 parent 5e81445 commit 70c650e

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

torchrl/data/replay_buffers/storages.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -586,8 +586,14 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
586586
data.clone()
587587
.expand(self.max_size, *data.shape)
588588
.memmap_like(prefix=self.scratch_dir)
589-
.to(self.device)
590589
)
590+
if self.device.type != "cpu":
591+
warnings.warn(
592+
"Support for Memmap device other than CPU will be deprecated in v0.4.0.",
593+
category=DeprecationWarning,
594+
)
595+
out = out.to(self.device).memmap_()
596+
591597
for key, tensor in sorted(
592598
out.items(include_nested=True, leaves_only=True), key=str
593599
):
@@ -603,8 +609,14 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
603609
data.clone()
604610
.expand(self.max_size, *data.shape)
605611
.memmap_like(prefix=self.scratch_dir)
606-
.to(self.device)
607612
)
613+
if self.device.type != "cpu":
614+
warnings.warn(
615+
"Support for Memmap device other than CPU will be deprecated in v0.4.0.",
616+
category=DeprecationWarning,
617+
)
618+
out = out.to(self.device).memmap_()
619+
608620
for key, tensor in sorted(
609621
out.items(include_nested=True, leaves_only=True), key=str
610622
):

tutorials/sphinx-tutorials/pretrained_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
#
8989
from torchrl.data import LazyMemmapStorage, ReplayBuffer
9090

91-
storage = LazyMemmapStorage(1000, device=device)
91+
storage = LazyMemmapStorage(1000)
9292
rb = ReplayBuffer(storage=storage, transform=r3m)
9393

9494
##############################################################################

0 commit comments

Comments
 (0)