Skip to content

Commit d93551d

Browse files
author
Vincent Moens
authored
[BugFix] make cursor a torch.long tensor (#1639)
1 parent c7d4764 commit d93551d

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchrl/data/replay_buffers/storages.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,9 @@ def set( # noqa: F811
340340
self._init(data)
341341
if not isinstance(cursor, (*INT_CLASSES, slice)):
342342
if not isinstance(cursor, torch.Tensor):
343-
cursor = torch.tensor(cursor)
343+
cursor = torch.tensor(cursor, dtype=torch.long)
344+
elif cursor.dtype != torch.long:
345+
cursor = cursor.to(dtype=torch.long)
344346
if len(cursor) > len(self._storage):
345347
warnings.warn(
346348
"A cursor of length superior to the storage capacity was provided. "

0 commit comments

Comments
 (0)