Skip to content

Commit e8e511d

Browse files
[BugFix] Improve done checking of collectors (#838)
1 parent 2d1723c commit e8e511d

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

torchrl/collectors/collectors.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -614,14 +614,18 @@ def _reset_if_necessary(self) -> None:
614614
steps = steps.clone()
615615
if len(self.env.batch_size):
616616
self._tensordict.masked_fill_(done_or_terminated, 0)
617-
self._tensordict.set("_reset", done_or_terminated)
617+
_reset = done_or_terminated
618+
self._tensordict.set("_reset", _reset)
618619
else:
620+
_reset = None
619621
self._tensordict.zero_()
620622
self.env.reset(self._tensordict)
621623

622-
if self._tensordict.get("done").any():
624+
if (_reset is None and self._tensordict.get("done").any()) or (
625+
_reset is not None and self._tensordict.get("done")[_reset].any()
626+
):
623627
raise RuntimeError(
624-
f"Got {sum(self._tensordict.get('done'))} done envs after reset."
628+
f"Env {self.env} was done after reset on specified '_reset' dimensions. This is (currently) not allowed."
625629
)
626630
traj_ids[done_or_terminated] = traj_ids.max() + torch.arange(
627631
1, done_or_terminated.sum() + 1, device=traj_ids.device

0 commit comments

Comments
 (0)