Skip to content

Commit 105e861

Browse files
author
Vincent Moens
authored
[BugFix] Fix storage device (#1650)
1 parent b7d148b commit 105e861

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchrl/data/replay_buffers/storages.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,9 @@ def set( # noqa: F811
372372
self._init(data)
373373
if not isinstance(cursor, (*INT_CLASSES, slice)):
374374
if not isinstance(cursor, torch.Tensor):
375-
cursor = torch.tensor(cursor)
375+
cursor = torch.tensor(cursor, dtype=torch.long, device=self.device)
376+
elif cursor.dtype != torch.long:
377+
cursor = cursor.to(dtype=torch.long, device=self.device)
376378
if len(cursor) > len(self._storage):
377379
warnings.warn(
378380
"A cursor of length superior to the storage capacity was provided. "

0 commit comments

Comments
 (0)