Skip to content

Commit c5f5cda

Browse files
authored
[BugFix]: Tensor map for subtensordict.set_ (#324)
1 parent 32bf860 commit c5f5cda

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,11 @@ def extend(
         else:
             stacked_td = tensordicts
         idx = super().extend(tensordicts, priorities)
-        stacked_td.set("index", idx, inplace=True)
+        stacked_td.set(
+            "index",
+            torch.tensor(idx, dtype=torch.int, device=stacked_td.device),
+            inplace=True,
+        )
         return idx

     def update_priority(self, tensordict: TensorDictBase) -> None:

torchrl/data/tensordict/tensordict.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2766,6 +2766,9 @@ def set_(
                 f"tensor.shape={tensor.shape[:self.batch_dims]} and "
                 f"self.batch_size={self.batch_size} mismatch"
             )
+        tensor = self._process_tensor(
+            tensor, check_device=False, check_tensor_shape=False
+        )
         self._source.set_at_(key, tensor, self.idx)
         if key in self._dict_meta:
             self._dict_meta[key].requires_grad = tensor.requires_grad

0 commit comments

Comments (0)