Skip to content

Commit c0e8a1c

Browse files
authored
[BugFix] Propagate args to TransformedEnv's state_dict (#944)
1 parent b14734a commit c0e8a1c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,8 +523,8 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
523523
out_tensordict = self.transform._call(out_tensordict)
524524
return out_tensordict
525525

526-
def state_dict(self) -> OrderedDict:
527-
state_dict = self.transform.state_dict()
526+
def state_dict(self, *args, **kwargs) -> OrderedDict:
527+
state_dict = self.transform.state_dict(*args, **kwargs)
528528
return state_dict
529529

530530
def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:

0 commit comments

Comments
 (0)