|
3 | 3 | # This source code is licensed under the MIT license found in the
|
4 | 4 | # LICENSE file in the root directory of this source tree.
|
5 | 5 |
|
| 6 | +import heapq |
6 | 7 | from abc import ABC, abstractmethod
|
7 | 8 | from typing import Any, Dict, Sequence
|
8 | 9 |
|
@@ -92,3 +93,128 @@ def extend(self, data: Sequence) -> torch.Tensor:
|
92 | 93 | data["index"] = index
|
93 | 94 | self._storage[index] = data
|
94 | 95 | return index
|
| 96 | + |
| 97 | + |
class TensorDictMaxValueWriter(Writer):
    """A Writer class for composable replay buffers that keeps the top elements based on some ranking key.

    If rank_key is not provided, the key will be ``("next", "reward")``.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
        >>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
        >>> rb = TensorDictReplayBuffer(
        ...     storage=LazyTensorStorage(1),
        ...     sampler=SamplerWithoutReplacement(),
        ...     batch_size=1,
        ...     writer=TensorDictMaxValueWriter(rank_key="key"),
        ... )
        >>> td = TensorDict({
        ...     "key": torch.tensor(range(10)),
        ...     "obs": torch.tensor(range(10))
        ... }, batch_size=10)
        >>> rb.extend(td)
        >>> print(rb.sample().get("obs").item())
        9
        >>> td = TensorDict({
        ...     "key": torch.tensor(range(10, 20)),
        ...     "obs": torch.tensor(range(10, 20))
        ... }, batch_size=10)
        >>> rb.extend(td)
        >>> print(rb.sample().get("obs").item())
        19
        >>> td = TensorDict({
        ...     "key": torch.tensor(range(10)),
        ...     "obs": torch.tensor(range(10))
        ... }, batch_size=10)
        >>> rb.extend(td)
        >>> print(rb.sample().get("obs").item())
        19
    """

    def __init__(self, rank_key=None, **kwargs) -> None:
        super().__init__(**kwargs)
        # Next free slot while the buffer is filling up.
        self._cursor = 0
        # Min-heap of (rank_value, storage_index) pairs; the root is the
        # worst element currently kept, which is the eviction candidate.
        self._current_top_values = []
        self._rank_key = rank_key
        if self._rank_key is None:
            self._rank_key = ("next", "reward")

    def get_insert_index(self, data: Any) -> int:
        """Returns the index where the data should be inserted, or ``None`` if it should not be inserted."""
        if data.batch_dims > 1:
            raise RuntimeError(
                "Expected input tensordict to have no more than 1 dimension, got"
                f"tensordict.batch_size = {data.batch_size}"
            )

        ret = None
        rank_data = data.get(("_data", self._rank_key))

        # The presence check must happen before any use of rank_data:
        # calling .sum() on None would raise AttributeError and mask
        # the intended KeyError.
        if rank_data is None:
            raise KeyError(f"Rank key {self._rank_key} not found in data.")

        # If time dimension, sum along it.
        rank_data = rank_data.sum(-1).item()

        # If the buffer is not full, add the data
        if len(self._current_top_values) < self._storage.max_size:

            ret = self._cursor
            self._cursor = (self._cursor + 1) % self._storage.max_size

            # Add new reward to the heap
            heapq.heappush(self._current_top_values, (rank_data, ret))

        # If the buffer is full, check if the new data is better than the worst data in the buffer
        elif rank_data > self._current_top_values[0][0]:

            # retrieve position of the smallest value
            min_sample = heapq.heappop(self._current_top_values)
            ret = min_sample[1]

            # Add new reward to the heap
            heapq.heappush(self._current_top_values, (rank_data, ret))

        return ret

    def add(self, data: Any) -> int:
        """Inserts a single element of data at an appropriate index, and returns that index.

        The data passed to this module should be structured as :obj:`[]` or :obj:`[T]` where
        :obj:`T` the time dimension. If the data is a trajectory, the rank key will be summed
        over the time dimension.

        Returns ``None`` when the element does not rank high enough to be stored.
        """
        index = self.get_insert_index(data)
        if index is not None:
            data.set("index", index)
            self._storage[index] = data
        return index

    def extend(self, data: Sequence) -> None:
        """Inserts a series of data points at appropriate indices.

        The data passed to this module should be structured as :obj:`[B]` or :obj:`[B, T]` where :obj:`B` is
        the batch size, :obj:`T` the time dimension. If the data is a trajectory, the rank key will be summed over the
        time dimension.
        """
        # Map each storage slot to the batch position that should fill it.
        data_to_replace = {}
        for i, sample in enumerate(data):
            index = self.get_insert_index(sample)
            if index is not None:
                data_to_replace[index] = i

        # Replace the data in the storage all at once. Guard BEFORE zipping:
        # zip(*{}.items()) yields nothing to unpack and would raise ValueError
        # when no sample qualified for insertion.
        if data_to_replace:
            keys, values = zip(*data_to_replace.items())
            index = data.get("index")
            values = list(values)
            keys = index[values] = torch.tensor(keys, dtype=index.dtype)
            data.set("index", index)
            self._storage[keys] = data[values]

    def _empty(self) -> None:
        # Reset to a pristine state; storage contents are handled by the caller.
        self._cursor = 0
        self._current_top_values = []
0 commit comments