
Commit 501b3af

[Quality] RB constructors cleanup (#945)
1 parent 732e3a2 commit 501b3af

File tree

3 files changed (+68, -15 lines)


test/test_rb.py

Lines changed: 62 additions & 1 deletion
@@ -75,7 +75,7 @@
 @pytest.mark.parametrize("writer", [writers.RoundRobinWriter])
 @pytest.mark.parametrize("storage", [ListStorage, LazyTensorStorage, LazyMemmapStorage])
 @pytest.mark.parametrize("size", [3, 5, 100])
-class TestPrototypeBuffers:
+class TestComposableBuffers:
     def _get_rb(self, rb_type, size, sampler, writer, storage):

         if storage is not None:
@@ -884,6 +884,67 @@ def test_samplerwithoutrep(size, samples, drop_last):
     assert not visited


+class TestStateDict:
+    @pytest.mark.parametrize("storage_in", ["tensor", "memmap"])
+    @pytest.mark.parametrize("storage_out", ["tensor", "memmap"])
+    @pytest.mark.parametrize("init_out", [True, False])
+    def test_load_state_dict(self, storage_in, storage_out, init_out):
+        buffer_size = 100
+        if storage_in == "memmap":
+            storage_in = LazyMemmapStorage(
+                buffer_size,
+                device="cpu",
+            )
+        elif storage_in == "tensor":
+            storage_in = LazyTensorStorage(
+                buffer_size,
+                device="cpu",
+            )
+        if storage_out == "memmap":
+            storage_out = LazyMemmapStorage(
+                buffer_size,
+                device="cpu",
+            )
+        elif storage_out == "tensor":
+            storage_out = LazyTensorStorage(
+                buffer_size,
+                device="cpu",
+            )
+
+        replay_buffer = TensorDictReplayBuffer(
+            pin_memory=False,
+            prefetch=3,
+            storage=storage_in,
+        )
+        # fill replay buffer with random data
+        transition = TensorDict(
+            {
+                "observation": torch.ones(1, 4),
+                "action": torch.ones(1, 2),
+                "reward": torch.ones(1, 1),
+                "dones": torch.ones(1, 1),
+                "next": {"observation": torch.ones(1, 4)},
+            },
+            batch_size=1,
+        )
+        for _ in range(3):
+            replay_buffer.extend(transition)
+
+        state_dict = replay_buffer.state_dict()
+
+        new_replay_buffer = TensorDictReplayBuffer(
+            pin_memory=False,
+            prefetch=3,
+            storage=storage_out,
+        )
+        if init_out:
+            new_replay_buffer.extend(transition)
+
+        new_replay_buffer.load_state_dict(state_dict)
+        s = new_replay_buffer.sample(3)
+        assert (s.exclude("index") == 1).all()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
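
The new TestStateDict test round-trips a buffer through state_dict()/load_state_dict() across storage backends. Below is a minimal sketch of how that API could be used to checkpoint a buffer to disk; it assumes TensorDictReplayBuffer and LazyTensorStorage are importable from torchrl.data and that the returned state dict is serialisable with torch.save, neither of which is exercised by this commit.

# Hypothetical checkpointing sketch mirroring the test above.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

rb = TensorDictReplayBuffer(storage=LazyTensorStorage(100))
rb.extend(
    TensorDict(
        {"observation": torch.ones(1, 4), "action": torch.ones(1, 2)},
        batch_size=1,
    )
)

# Persist and restore the buffer contents (torch.save on the state dict is an
# assumption; only the in-memory round trip is covered by the test).
torch.save(rb.state_dict(), "rb_checkpoint.pt")

restored = TensorDictReplayBuffer(storage=LazyTensorStorage(100))
restored.load_state_dict(torch.load("rb_checkpoint.pt"))
sample = restored.sample(1)

As the test checks, the destination buffer may use a different storage backend (e.g. LazyMemmapStorage) and may or may not have been written to before load_state_dict is called.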

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 2 additions & 2 deletions
@@ -341,8 +341,8 @@ class TensorDictReplayBuffer(ReplayBuffer):
             within TensorDicts added to this ReplayBuffer.
     """

-    def __init__(self, priority_key: str = "td_error", **kw) -> None:
-        super().__init__(**kw)
+    def __init__(self, *args, priority_key: str = "td_error", **kw) -> None:
+        super().__init__(*args, **kw)
         self.priority_key = priority_key

     def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
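
Forwarding *args lets positional arguments reach ReplayBuffer.__init__ instead of being captured by priority_key, which becomes keyword-only. A toy sketch of the pattern, using hypothetical Base/Child classes rather than the torchrl ones:

# Positional arguments pass straight through to the parent constructor;
# the child-specific option is keyword-only.
class Base:
    def __init__(self, storage=None, sampler=None):
        self.storage = storage
        self.sampler = sampler


class Child(Base):
    def __init__(self, *args, priority_key: str = "td_error", **kw) -> None:
        super().__init__(*args, **kw)  # positional and keyword args reach Base untouched
        self.priority_key = priority_key


c = Child("my_storage", sampler="my_sampler")
assert c.storage == "my_storage"
assert c.priority_key == "td_error"

With the old-style signature (priority_key as the first parameter), Child("my_storage") would silently bind the positional argument to priority_key.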

torchrl/data/replay_buffers/storages.py

Lines changed: 4 additions & 12 deletions
@@ -12,7 +12,7 @@
 import torch
 from tensordict.memmap import MemmapTensor
 from tensordict.prototype import is_tensorclass
-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase

 from torchrl._utils import _CKPT_BACKEND
 from torchrl.data.replay_buffers.utils import INT_CLASSES
@@ -210,11 +210,7 @@ def load_state_dict(self, state_dict):
             if isinstance(self._storage, TensorDictBase):
                 self._storage.load_state_dict(_storage)
             elif self._storage is None:
-                batch_size = _storage.pop("__batch_size")
-                device = _storage.pop("__device")
-                self._storage = TensorDict(
-                    _storage, batch_size=batch_size, device=device
-                )
+                self._storage = TensorDict({}, []).load_state_dict(_storage)
             else:
                 raise RuntimeError(
                     f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}"
@@ -333,15 +329,11 @@ def load_state_dict(self, state_dict):
                     f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}"
                 )
             elif isinstance(_storage, (dict, OrderedDict)):
-                if isinstance(self._storage, TensorDictBase):
+                if is_tensor_collection(self._storage):
                     self._storage.load_state_dict(_storage)
                     self._storage.memmap_()
                 elif self._storage is None:
-                    batch_size = _storage.pop("__batch_size")
-                    device = _storage.pop("__device")
-                    self._storage = TensorDict(
-                        _storage, batch_size=batch_size, device=device
-                    )
+                    self._storage = TensorDict({}, []).load_state_dict(_storage)
                     self._storage.memmap_()
                 else:
                     raise RuntimeError(
