Skip to content

Commit 95b7206

Browse files
authored
[BugFix] Make VecNorm Transform pickable (#1596)
1 parent 821d8bc commit 95b7206

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

test/test_transforms.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import argparse
77

88
import itertools
9+
import pickle
910
import sys
1011
from copy import copy
1112
from functools import partial
@@ -7327,6 +7328,15 @@ def test_vecnorm_rollout(self, parallel, thr=0.2, N=200):
73277328
env_t.close()
73287329
self.SEED = 0
73297330

7331+
def test_pickable(self):
7332+
7333+
transform = VecNorm()
7334+
serialized = pickle.dumps(transform)
7335+
transform2 = pickle.loads(serialized)
7336+
assert transform.__dict__.keys() == transform2.__dict__.keys()
7337+
for key in sorted(transform.__dict__.keys()):
7338+
assert isinstance(transform.__dict__[key], type(transform2.__dict__[key]))
7339+
73307340

73317341
def test_added_transforms_are_in_eval_mode_trivial():
73327342
base_env = ContinuousActionVecMockEnv()

torchrl/envs/transforms/transforms.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from copy import copy
1212
from functools import wraps
1313
from textwrap import indent
14-
from typing import Any, List, Optional, OrderedDict, Sequence, Tuple, Union
14+
from typing import Any, Dict, List, Optional, OrderedDict, Sequence, Tuple, Union
1515

1616
import numpy as np
1717

@@ -4337,6 +4337,20 @@ def __repr__(self) -> str:
43374337
f"eps={self.eps:4.4f}, keys={self.in_keys})"
43384338
)
43394339

4340+
def __getstate__(self) -> Dict[str, Any]:
4341+
state = self.__dict__.copy()
4342+
_lock = state.pop("lock", None)
4343+
if _lock is not None:
4344+
state["lock_placeholder"] = None
4345+
return state
4346+
4347+
def __setstate__(self, state: Dict[str, Any]):
4348+
if "lock_placeholder" in state:
4349+
state.pop("lock_placeholder")
4350+
_lock = mp.Lock()
4351+
state["lock"] = _lock
4352+
self.__dict__.update(state)
4353+
43404354

43414355
class RewardSum(Transform):
43424356
"""Tracks episode cumulative rewards.

0 commit comments

Comments
 (0)