Skip to content

Commit 2e7def8

Browse files
authored
[BugFix] Fix deprecated list index (#3005)
1 parent 0355a01 commit 2e7def8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3599,7 +3599,7 @@ def _call(self, next_tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
35993599
n = buffer_reset.ndimension() + self.dim
36003600
else:
36013601
raise ValueError(self._CAT_DIM_ERR)
3602-
idx = [slice(None, None) for _ in range(n)] + [slice(-d, None)]
3602+
idx = tuple([slice(None, None) for _ in range(n)] + [slice(-d, None)])
36033603
if not _all:
36043604
buffer_reset = buffer[_reset]
36053605
buffer_reset[idx] = data_reset
@@ -3612,7 +3612,7 @@ def _call(self, next_tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
36123612
n = buffer.ndimension() + self.dim
36133613
else:
36143614
raise ValueError(self._CAT_DIM_ERR)
3615-
idx = [slice(None, None) for _ in range(n)] + [slice(-d, None)]
3615+
idx = tuple([slice(None, None) for _ in range(n)] + [slice(-d, None)])
36163616
buffer[idx] = buffer[idx].copy_(data)
36173617
# add to tensordict
36183618
next_tensordict.set(out_key, buffer.clone())

0 commit comments

Comments
 (0)