
Commit 55d667e

albertbou92 and vmoens authored
[Feature] Max Value Writer (#1622)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
1 parent 38dfc21 commit 55d667e

File tree

6 files changed, +203 -7 lines changed


docs/source/reference/data.rst
Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@ We also give users the ability to compose a replay buffer using the following co
     Writer
     RoundRobinWriter
     TensorDictRoundRobinWriter
+    TensorDictMaxValueWriter
 
 Storage choice is very influential on replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes.
 :class:`LazyMemmapStorage` is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations for improved node failure recovery.
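
The storage note above is documentation prose; the sketch below shows what composing such a buffer can look like in code. It is a minimal, unverified example: the LazyMemmapStorage(max_size, scratch_dir=...) signature and the default random sampler are assumed rather than taken from this diff, and the scratch directory path is arbitrary.

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyMemmapStorage, TensorDictMaxValueWriter, TensorDictReplayBuffer

    # Memory-mapped storage keeps tensors in files under scratch_dir, which is the
    # kind of file storage location the doc text refers to for shared-storage setups.
    rb = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(100, scratch_dir="/tmp/rb"),  # assumed signature; path is illustrative
        writer=TensorDictMaxValueWriter(rank_key="key"),
        batch_size=32,
    )
    data = TensorDict(
        {"key": torch.rand(200), "obs": torch.rand(200, 4)},
        batch_size=[200],
    )
    rb.extend(data)       # with 200 candidates and 100 slots, the 100 highest-"key" entries are kept
    sample = rb.sample()  # 32 elements drawn from the retained entries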

test/test_rb.py
Lines changed: 63 additions & 1 deletion

@@ -38,7 +38,10 @@
     ListStorage,
     TensorStorage,
 )
-from torchrl.data.replay_buffers.writers import RoundRobinWriter
+from torchrl.data.replay_buffers.writers import (
+    RoundRobinWriter,
+    TensorDictMaxValueWriter,
+)
 from torchrl.envs.transforms.transforms import (
     BinarizeReward,
     CatFrames,
@@ -1209,6 +1212,65 @@ def test_load_state_dict(self, storage_in, storage_out, init_out):
     assert (s.exclude("index") == 1).all()
 
 
+@pytest.mark.parametrize("size", [20, 25, 30])
+@pytest.mark.parametrize("batch_size", [1, 10, 15])
+@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)])
+def test_max_value_writer(size, batch_size, reward_ranges):
+    rb = TensorDictReplayBuffer(
+        storage=LazyTensorStorage(size),
+        sampler=SamplerWithoutReplacement(),
+        batch_size=batch_size,
+        writer=TensorDictMaxValueWriter(rank_key="key"),
+    )
+
+    max_reward1, max_reward2, max_reward3 = reward_ranges
+
+    td = TensorDict(
+        {
+            "key": torch.clamp_max(torch.rand(size), max=max_reward1),
+            "obs": torch.tensor(torch.rand(size)),
+        },
+        batch_size=size,
+        device="cpu",
+    )
+    rb.extend(td)
+    sample = rb.sample()
+    assert (sample.get("key") <= max_reward1).all()
+    assert (0 <= sample.get("key")).all()
+    assert len(sample.get("index").unique()) == len(sample.get("index"))
+
+    td = TensorDict(
+        {
+            "key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2),
+            "obs": torch.tensor(torch.rand(size)),
+        },
+        batch_size=size,
+        device="cpu",
+    )
+    rb.extend(td)
+    sample = rb.sample()
+    assert (sample.get("key") <= max_reward2).all()
+    assert (max_reward1 <= sample.get("key")).all()
+    assert len(sample.get("index").unique()) == len(sample.get("index"))
+
+    td = TensorDict(
+        {
+            "key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3),
+            "obs": torch.tensor(torch.rand(size)),
+        },
+        batch_size=size,
+        device="cpu",
+    )
+
+    for sample in td:
+        rb.add(sample)
+
+    sample = rb.sample()
+    assert (sample.get("key") <= max_reward3).all()
+    assert (max_reward2 <= sample.get("key")).all()
+    assert len(sample.get("index").unique()) == len(sample.get("index"))
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/data/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@
     ReplayBuffer,
     RoundRobinWriter,
     Storage,
+    TensorDictMaxValueWriter,
     TensorDictPrioritizedReplayBuffer,
     TensorDictReplayBuffer,
     TensorDictRoundRobinWriter,

torchrl/data/replay_buffers/__init__.py
Lines changed: 6 additions & 1 deletion

@@ -23,4 +23,9 @@
     Storage,
     TensorStorage,
 )
-from .writers import RoundRobinWriter, TensorDictRoundRobinWriter, Writer
+from .writers import (
+    RoundRobinWriter,
+    TensorDictMaxValueWriter,
+    TensorDictRoundRobinWriter,
+    Writer,
+)

torchrl/data/replay_buffers/replay_buffers.py
Lines changed: 6 additions & 5 deletions

@@ -718,12 +718,13 @@ def add(self, data: TensorDictBase) -> int:
             data_add = data
 
         index = super()._add(data_add)
-        if is_tensor_collection(data_add):
-            data_add.set("index", index)
+        if index is not None:
+            if is_tensor_collection(data_add):
+                data_add.set("index", index)
 
-        # priority = self._get_priority(data)
-        # if priority:
-        self.update_tensordict_priority(data_add)
+            # priority = self._get_priority(data)
+            # if priority:
+            self.update_tensordict_priority(data_add)
         return index
 
     def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
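
The guard above is needed because the new writer can decline an insertion: TensorDictMaxValueWriter.get_insert_index returns None when a sample does not beat the worst element already stored, and that None propagates through the writer's add up to TensorDictReplayBuffer.add. A minimal sketch of the caller-side effect, assuming the default random sampler and reusing only components that appear in this diff:

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictMaxValueWriter, TensorDictReplayBuffer

    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(2),
        writer=TensorDictMaxValueWriter(rank_key="key"),
    )
    # Fill the two slots with ranks 1.0 and 2.0.
    rb.extend(TensorDict({"key": torch.tensor([1.0, 2.0]), "obs": torch.zeros(2)}, batch_size=[2]))

    # This sample ranks below everything already stored, so the writer rejects it
    # and add() returns None instead of a storage index.
    index = rb.add(TensorDict({"key": torch.tensor(0.5), "obs": torch.tensor(0.0)}, batch_size=[]))
    assert index is None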

torchrl/data/replay_buffers/writers.py
Lines changed: 126 additions & 0 deletions

@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import heapq
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Sequence
 
@@ -92,3 +93,128 @@ def extend(self, data: Sequence) -> torch.Tensor:
         data["index"] = index
         self._storage[index] = data
         return index
+
+
+class TensorDictMaxValueWriter(Writer):
+    """A Writer class for composable replay buffers that keeps the top elements based on some ranking key.
+
+    If ``rank_key`` is not provided, it defaults to ``("next", "reward")``.
+
+    Examples:
+        >>> import torch
+        >>> from tensordict import TensorDict
+        >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
+        >>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+        >>> rb = TensorDictReplayBuffer(
+        ...     storage=LazyTensorStorage(1),
+        ...     sampler=SamplerWithoutReplacement(),
+        ...     batch_size=1,
+        ...     writer=TensorDictMaxValueWriter(rank_key="key"),
+        ... )
+        >>> td = TensorDict({
+        ...     "key": torch.tensor(range(10)),
+        ...     "obs": torch.tensor(range(10))
+        ... }, batch_size=10)
+        >>> rb.extend(td)
+        >>> print(rb.sample().get("obs").item())
+        9
+        >>> td = TensorDict({
+        ...     "key": torch.tensor(range(10, 20)),
+        ...     "obs": torch.tensor(range(10, 20))
+        ... }, batch_size=10)
+        >>> rb.extend(td)
+        >>> print(rb.sample().get("obs").item())
+        19
+        >>> td = TensorDict({
+        ...     "key": torch.tensor(range(10)),
+        ...     "obs": torch.tensor(range(10))
+        ... }, batch_size=10)
+        >>> rb.extend(td)
+        >>> print(rb.sample().get("obs").item())
+        19
+    """
+
+    def __init__(self, rank_key=None, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self._cursor = 0
+        self._current_top_values = []
+        self._rank_key = rank_key
+        if self._rank_key is None:
+            self._rank_key = ("next", "reward")
+
+    def get_insert_index(self, data: Any) -> int:
+        """Returns the index where the data should be inserted, or ``None`` if it should not be inserted."""
+        if data.batch_dims > 1:
+            raise RuntimeError(
+                "Expected input tensordict to have no more than 1 dimension, got "
+                f"tensordict.batch_size = {data.batch_size}"
+            )
+
+        ret = None
+        rank_data = data.get(("_data", self._rank_key))
+        if rank_data is None:
+            raise KeyError(f"Rank key {self._rank_key} not found in data.")
+
+        # If there is a time dimension, sum the rank key along it.
+        rank_data = rank_data.sum(-1).item()
+
+        # If the buffer is not full, add the data.
+        if len(self._current_top_values) < self._storage.max_size:
+            ret = self._cursor
+            self._cursor = (self._cursor + 1) % self._storage.max_size
+
+            # Add the new rank value to the heap.
+            heapq.heappush(self._current_top_values, (rank_data, ret))
+
+        # If the buffer is full, check whether the new data beats the worst element in the buffer.
+        elif rank_data > self._current_top_values[0][0]:
+            # Retrieve the position of the smallest value and reuse its slot.
+            min_sample = heapq.heappop(self._current_top_values)
+            ret = min_sample[1]
+
+            # Add the new rank value to the heap.
+            heapq.heappush(self._current_top_values, (rank_data, ret))
+
+        return ret
+
+    def add(self, data: Any) -> int:
+        """Inserts a single element of data at an appropriate index, and returns that index.
+
+        The data passed to this module should be structured as :obj:`[]` or :obj:`[T]`, where
+        :obj:`T` is the time dimension. If the data is a trajectory, the rank key is summed
+        over the time dimension.
+        """
+        index = self.get_insert_index(data)
+        if index is not None:
+            data.set("index", index)
+            self._storage[index] = data
+        return index
+
+    def extend(self, data: Sequence) -> None:
+        """Inserts a series of data points at appropriate indices.
+
+        The data passed to this module should be structured as :obj:`[B]` or :obj:`[B, T]`, where
+        :obj:`B` is the batch size and :obj:`T` the time dimension. If the data is a trajectory,
+        the rank key is summed over the time dimension.
+        """
+        data_to_replace = {}
+        for i, sample in enumerate(data):
+            index = self.get_insert_index(sample)
+            if index is not None:
+                data_to_replace[index] = i
+
+        # Replace the data in the storage all at once.
+        if data_to_replace:
+            keys, values = zip(*data_to_replace.items())
+            index = data.get("index")
+            values = list(values)
+            keys = index[values] = torch.tensor(keys, dtype=index.dtype)
+            data.set("index", index)
+            self._storage[keys] = data[values]
+
+    def _empty(self) -> None:
+        self._cursor = 0
+        self._current_top_values = []
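
Aside from the TorchRL plumbing, the core of get_insert_index is a standard streaming top-k pattern: keep a min-heap of (rank, storage_slot) pairs so the heap root is always the worst retained element, and evict it only when a better candidate arrives. A self-contained illustration of that pattern with plain heapq (the TopKSlots class is hypothetical, written only to mirror the writer's heap logic):

    import heapq
    from typing import List, Optional, Tuple

    class TopKSlots:
        # Minimal top-k slot allocator mirroring the writer's heap logic (illustrative only).
        def __init__(self, max_size: int) -> None:
            self.max_size = max_size
            self._cursor = 0
            # (rank, slot) pairs; heap[0] always holds the worst retained rank.
            self._heap: List[Tuple[float, int]] = []

        def insert_slot(self, rank: float) -> Optional[int]:
            if len(self._heap) < self.max_size:
                # Buffer not full yet: hand out the next free slot.
                slot = self._cursor
                self._cursor = (self._cursor + 1) % self.max_size
                heapq.heappush(self._heap, (rank, slot))
                return slot
            if rank > self._heap[0][0]:
                # Evict the current worst element and reuse its storage slot.
                _, slot = heapq.heappop(self._heap)
                heapq.heappush(self._heap, (rank, slot))
                return slot
            return None  # candidate is not better than anything stored

    slots = TopKSlots(max_size=2)
    assert slots.insert_slot(1.0) == 0
    assert slots.insert_slot(2.0) == 1
    assert slots.insert_slot(0.5) is None   # rejected: worse than the current minimum (1.0)
    assert slots.insert_slot(3.0) == 0      # replaces rank 1.0, reusing its storage slot

The min-heap keeps each insertion decision at O(log max_size), which is presumably why heapq is used rather than re-sorting the stored ranks, and it lets extend() defer the physical storage writes until all slots have been decided.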
