Skip to content

Commit 530dac3

Browse files
[BugFix] correct the use of step_mdp method in data collector (#637)
* rectify step_mdp usage * run linting * run linting Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 4e1b878 commit 530dac3

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

torchrl/collectors/collectors.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -633,12 +633,7 @@ def rollout(self) -> TensorDictBase:
633633
tensordict_out.append(self._tensordict.clone())
634634

635635
self._reset_if_necessary()
636-
self._tensordict.update(
637-
step_mdp(
638-
self._tensordict.exclude("reward", "done"), keep_other=True
639-
),
640-
inplace=True,
641-
)
636+
self._tensordict.update(step_mdp(self._tensordict), inplace=True)
642637
if self.return_in_place and len(self._tensordict_out.keys()) > 0:
643638
tensordict_out = torch.stack(tensordict_out, len(self.env.batch_size))
644639
tensordict_out = tensordict_out.select(*self._tensordict_out.keys())

0 commit comments

Comments
 (0)