We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b7d148b commit 105e861Copy full SHA for 105e861
torchrl/data/replay_buffers/storages.py
@@ -372,7 +372,9 @@ def set( # noqa: F811
372
self._init(data)
373
if not isinstance(cursor, (*INT_CLASSES, slice)):
374
if not isinstance(cursor, torch.Tensor):
375
- cursor = torch.tensor(cursor)
+ cursor = torch.tensor(cursor, dtype=torch.long, device=self.device)
376
+ elif cursor.dtype != torch.long:
377
+ cursor = cursor.to(dtype=torch.long, device=self.device)
378
if len(cursor) > len(self._storage):
379
warnings.warn(
380
"A cursor of length superior to the storage capacity was provided. "
0 commit comments