
Commit 3d2c161

[BugFix] Vectorized priority update in replay buffers (#1598)
Signed-off-by: Matteo Bettini <matbet@meta.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
1 parent 146af04 commit 3d2c161

File tree: 1 file changed (+62 −50)

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 62 additions & 50 deletions
@@ -662,7 +662,7 @@ def __init__(self, *, priority_key: str = "td_error", **kw) -> None:
         super().__init__(**kw)
         self.priority_key = priority_key
 
-    def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
+    def _get_priority_item(self, tensordict: TensorDictBase) -> float:
         if "_data" in tensordict.keys():
             tensordict = tensordict.get("_data")
 
@@ -682,6 +682,23 @@ def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
             )
         return priority
 
+    def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
+        if "_data" in tensordict.keys():
+            tensordict = tensordict.get("_data")
+
+        priority = tensordict.get(self.priority_key, None)
+        if priority is None:
+            return torch.tensor(
+                self._sampler.default_priority,
+                dtype=torch.float,
+                device=tensordict.device,
+            ).expand(tensordict.shape[0])
+
+        priority = priority.reshape(priority.shape[0], -1)
+        priority = _reduce(priority, self._sampler.reduction, dim=1)
+
+        return priority
+
     def add(self, data: TensorDictBase) -> int:
         if self._transform is not None:
             data = self._transform.inv(data)
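
The core of _get_priority_vector is replacing a per-element Python loop with a single batched reduction: flatten every non-batch dimension, then reduce once along dim=1. The snippet below is a minimal, self-contained sketch of that pattern in plain PyTorch; the shapes and the "max" reduction are illustrative, not taken from the commit.

import torch

# Editor's sketch (not part of the commit): one priority per time step, batch of 256.
td_error = torch.rand(256, 10, 1)

# Old approach, conceptually: one Python-level reduction per batch element.
slow = torch.tensor([t.max().item() for t in td_error])

# New approach: flatten non-batch dims and reduce once along dim=1,
# mirroring _get_priority_vector -> _reduce(..., dim=1).
flat = td_error.reshape(td_error.shape[0], -1)
fast = flat.max(dim=1).values

assert torch.allclose(slow, fast)
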
@@ -709,61 +726,50 @@ def add(self, data: TensorDictBase) -> int:
         self.update_tensordict_priority(data_add)
         return index
 
-    def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
-        if is_tensor_collection(tensordicts):
-            tensordicts = TensorDict(
-                {"_data": tensordicts},
-                batch_size=tensordicts.batch_size[:1],
-            )
-            if tensordicts.batch_dims > 1:
-                # we want the tensordict to have one dimension only. The batch size
-                # of the sampled tensordicts can be changed thereafter
-                if not isinstance(tensordicts, LazyStackedTensorDict):
-                    tensordicts = tensordicts.clone(recurse=False)
-                else:
-                    tensordicts = tensordicts.contiguous()
-                # we keep track of the batch size to reinstantiate it when sampling
-                if "_rb_batch_size" in tensordicts.keys():
-                    raise KeyError(
-                        "conflicting key '_rb_batch_size'. Consider removing from data."
-                    )
-                shape = torch.tensor(tensordicts.batch_size[1:]).expand(
-                    tensordicts.batch_size[0], tensordicts.batch_dims - 1
+    def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
+
+        tensordicts = TensorDict(
+            {"_data": tensordicts},
+            batch_size=tensordicts.batch_size[:1],
+        )
+        if tensordicts.batch_dims > 1:
+            # we want the tensordict to have one dimension only. The batch size
+            # of the sampled tensordicts can be changed thereafter
+            if not isinstance(tensordicts, LazyStackedTensorDict):
+                tensordicts = tensordicts.clone(recurse=False)
+            else:
+                tensordicts = tensordicts.contiguous()
+            # we keep track of the batch size to reinstantiate it when sampling
+            if "_rb_batch_size" in tensordicts.keys():
+                raise KeyError(
+                    "conflicting key '_rb_batch_size'. Consider removing from data."
                 )
-                tensordicts.set("_rb_batch_size", shape)
-            tensordicts.set(
-                "index",
-                torch.zeros(
-                    tensordicts.shape, device=tensordicts.device, dtype=torch.int
-                ),
+            shape = torch.tensor(tensordicts.batch_size[1:]).expand(
+                tensordicts.batch_size[0], tensordicts.batch_dims - 1
             )
-
-        if not is_tensor_collection(tensordicts):
-            stacked_td = torch.stack(tensordicts, 0)
-        else:
-            stacked_td = tensordicts
+            tensordicts.set("_rb_batch_size", shape)
+        tensordicts.set(
+            "index",
+            torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.int),
+        )
 
         if self._transform is not None:
-            tensordicts = self._transform.inv(stacked_td.get("_data"))
-            stacked_td.set("_data", tensordicts)
-            if tensordicts.device is not None:
-                stacked_td = stacked_td.to(tensordicts.device)
+            data = self._transform.inv(tensordicts.get("_data"))
+            tensordicts.set("_data", data)
+            if data.device is not None:
+                tensordicts = tensordicts.to(data.device)
 
-        index = super()._extend(stacked_td)
-        self.update_tensordict_priority(stacked_td)
+        index = super()._extend(tensordicts)
+        self.update_tensordict_priority(tensordicts)
         return index
 
     def update_tensordict_priority(self, data: TensorDictBase) -> None:
         if not isinstance(self._sampler, PrioritizedSampler):
             return
         if data.ndim:
-            priority = torch.tensor(
-                [self._get_priority(td) for td in data],
-                dtype=torch.float,
-                device=data.device,
-            )
+            priority = self._get_priority_vector(data)
         else:
-            priority = self._get_priority(data)
+            priority = self._get_priority_item(data)
         index = data.get("index")
         while index.shape != priority.shape:
             # reduce index
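
With the new helpers in place, update_tensordict_priority computes the whole priority vector in one call when the data has a batch dimension, so extending a prioritized buffer no longer loops over items in Python. Below is a hypothetical usage sketch, not part of the commit: argument values, key names and shapes are illustrative, and the constructor arguments shown (alpha, beta, priority_key, storage) are assumed to match the torchrl API of that release.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer

# Illustrative buffer configuration (values are assumptions, not from the diff).
rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.9,
    priority_key="td_error",
    storage=LazyTensorStorage(1000),
)

data = TensorDict(
    {"obs": torch.randn(128, 4), "td_error": torch.rand(128)},
    batch_size=[128],
)
rb.extend(data)  # priorities for all 128 entries computed in one vectorized pass
sample = rb.sample(32)
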
@@ -1010,17 +1016,23 @@ def __call__(self, list_of_tds):
         return self.out
 
 
-def _reduce(tensor: torch.Tensor, reduction: str):
+def _reduce(
+    tensor: torch.Tensor, reduction: str, dim: Optional[int] = None
+) -> Union[float, torch.Tensor]:
     """Reduces a tensor given the reduction method."""
     if reduction == "max":
-        return tensor.max().item()
+        result = tensor.max(dim=dim)
     elif reduction == "min":
-        return tensor.min().item()
+        result = tensor.min(dim=dim)
     elif reduction == "mean":
-        return tensor.mean().item()
+        result = tensor.mean(dim=dim)
     elif reduction == "median":
-        return tensor.median().item()
-    raise NotImplementedError(f"Unknown reduction method {reduction}")
+        result = tensor.median(dim=dim)
+    else:
+        raise NotImplementedError(f"Unknown reduction method {reduction}")
+    if isinstance(result, tuple):
+        result = result[0]
+    return result.item() if dim is None else result
 
 
 def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:
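
_reduce now optionally reduces along a dimension and returns a tensor instead of a Python float. The tuple check is needed because max, min and median called with a dim argument return a (values, indices) named tuple, while mean returns a plain tensor. A small pure-PyTorch sketch of that behaviour (the values below are illustrative, not from the commit):

import torch

# Editor's sketch: why _reduce keeps result[0] only for tuple-returning reductions.
priority = torch.tensor([[0.1, 0.4, 0.3], [0.9, 0.2, 0.5]])

# max/min/median with dim= return a (values, indices) named tuple...
row_max = priority.max(dim=1)
assert isinstance(row_max, tuple)
values = row_max[0]              # tensor([0.4000, 0.9000])

# ...whereas mean with dim= returns a plain tensor.
row_mean = priority.mean(dim=1)  # tensor([0.2667, 0.5333])
assert not isinstance(row_mean, tuple)
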
