Skip to content

Commit ea76ffb

Browse files
author
Vincent Moens
committed
[BugFix] Fix calls to _reset_env_preprocess
ghstack-source-id: 5992563 Pull Request resolved: #2798
1 parent dd59290 commit ea76ffb

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,9 @@ def _reset(
322322

323323
def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
324324
"""Inverts the input to :meth:`TransformedEnv._reset`, if needed."""
325-
if self.enable_inv_on_reset:
325+
if self.enable_inv_on_reset and tensordict is not None:
326326
with _set_missing_tolerance(self, True):
327-
tensordict = self.inv(tensordict)
327+
tensordict = self._inv_call(tensordict)
328328
return tensordict
329329

330330
def init(self, tensordict) -> None:
@@ -1166,6 +1166,9 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
11661166
tensordict = tensordict.select(
11671167
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
11681168
)
1169+
# We always call _reset_env_preprocess, even if tensordict is None - that way one can augment that
1170+
# method to do any pre-reset operation.
1171+
# By default, within _reset_env_preprocess we will skip the inv call when tensordict is None.
11691172
tensordict = self.transform._reset_env_preprocess(tensordict)
11701173
tensordict_reset = self.base_env._reset(tensordict, **kwargs)
11711174
if tensordict is None:

0 commit comments

Comments
 (0)