Skip to content

Commit f0cda31

Browse files
author
Vincent Moens
committed
[BugFix] Fix PRB serialization
ghstack-source-id: a40d39a Pull-Request-resolved: #2963
1 parent f121f4d commit f0cda31

File tree

2 files changed

+60
-8
lines changed

2 files changed

+60
-8
lines changed

test/test_rb.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2997,6 +2997,49 @@ def test_prb_update_max_priority(self, max_priority_within_buffer):
29972997
assert rb._sampler._max_priority[0] == 21
29982998
assert rb._sampler._max_priority[1] == 0
29992999

3000+
def test_prb_serialization(self, tmpdir):
    """Check that a prioritized replay buffer survives a save/load round trip.

    A first buffer is filled with one all-zeros transition and saved to
    ``tmpdir``. A second buffer, built with *different* sampler
    hyperparameters and filled with one all-ones transition, then loads
    that checkpoint. After loading, the second buffer must agree with the
    first on length, ``_alpha``, ``_beta`` and both entries of
    ``_max_priority``.
    """
    rb = ReplayBuffer(
        storage=LazyMemmapStorage(max_size=10),
        sampler=PrioritizedSampler(max_capacity=10, alpha=0.8, beta=0.6),
    )

    td = TensorDict(
        {
            "observations": torch.zeros(1, 3),
            "actions": torch.zeros(1, 1),
            "rewards": torch.zeros(1, 1),
            "next_observations": torch.zeros(1, 3),
            "terminations": torch.zeros(1, 1, dtype=torch.bool),
        },
        batch_size=[1],
    )
    rb.extend(td)

    rb.save(tmpdir)

    # Deliberately use other alpha/beta values here: if they match the first
    # buffer's after ``load``, they must have come from the checkpoint.
    rb2 = ReplayBuffer(
        storage=LazyMemmapStorage(max_size=10),
        sampler=PrioritizedSampler(max_capacity=10, alpha=0.5, beta=0.5),
    )

    # Pre-populate the second buffer so that ``load`` runs on a buffer whose
    # sampler state has already been initialized by an ``extend`` call.
    td2 = TensorDict(
        {
            "observations": torch.ones(1, 3),
            "actions": torch.ones(1, 1),
            "rewards": torch.ones(1, 1),
            "next_observations": torch.ones(1, 3),
            "terminations": torch.ones(1, 1, dtype=torch.bool),
        },
        batch_size=[1],
    )
    rb2.extend(td2)
    rb2.load(tmpdir)

    # The saved buffer is untouched by the round trip...
    assert len(rb) == 1
    # ...and the loaded buffer holds exactly what was checkpointed.
    assert len(rb2) == 1
    assert rb.sampler._alpha == rb2.sampler._alpha
    assert rb.sampler._beta == rb2.sampler._beta
    assert rb.sampler._max_priority[0] == rb2.sampler._max_priority[0]
    assert rb.sampler._max_priority[1] == rb2.sampler._max_priority[1]
3042+
30003043
def test_prb_ndim(self):
30013044
"""This test lists all the possible ways of updating the priority of a PRB with RB, TRB and TPRB.
30023045

torchrl/data/replay_buffers/samplers.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import torch
1919
from tensordict import MemoryMappedTensor, TensorDict
2020
from tensordict.utils import NestedKey
21+
22+
from torch.utils._pytree import tree_map
2123
from torchrl._extension import EXTENSION_WARNING
2224
from torchrl._utils import _replace_last, logger
2325
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
@@ -676,13 +678,16 @@ def dumps(self, path):
676678
)
677679
with open(path / "sampler_metadata.json", "w") as file:
678680
json.dump(
679-
{
680-
"_alpha": self._alpha,
681-
"_beta": self._beta,
682-
"_eps": self._eps,
683-
"_max_priority": self._max_priority,
684-
"_max_capacity": self._max_capacity,
685-
},
681+
tree_map(
682+
float,
683+
{
684+
"_alpha": self._alpha,
685+
"_beta": self._beta,
686+
"_eps": self._eps,
687+
"_max_priority": self._max_priority,
688+
"_max_capacity": self._max_capacity,
689+
},
690+
),
686691
file,
687692
)
688693

@@ -693,7 +698,11 @@ def loads(self, path):
693698
self._alpha = metadata["_alpha"]
694699
self._beta = metadata["_beta"]
695700
self._eps = metadata["_eps"]
696-
self._max_priority = metadata["_max_priority"]
701+
tree_map(
702+
lambda dest, orig: dest.copy_(orig),
703+
tuple(self._max_priority),
704+
tuple(metadata["_max_priority"]),
705+
)
697706
_max_capacity = metadata["_max_capacity"]
698707
if _max_capacity != self._max_capacity:
699708
raise RuntimeError(

0 commit comments

Comments
 (0)