Skip to content

Commit 63579d7

Browse files
Improved serialization functions
- now the array backend can get the preferred serialization method - fixed issues in serialization of preprocessors
1 parent edd0bd5 commit 63579d7

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

mushroom_rl/core/array_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ class ArrayBackend(object):
1111
def get_backend_name():
1212
raise NotImplementedError
1313

14+
@staticmethod
15+
def get_backend_serialization():
16+
raise NotImplementedError
17+
1418
@staticmethod
1519
def get_array_backend(backend_name):
1620
assert type(backend_name) == str, f"Backend has to be string, not {type(backend_name).__name__}."
@@ -174,6 +178,10 @@ class NumpyBackend(ArrayBackend):
174178
def get_backend_name():
175179
return 'numpy'
176180

181+
@staticmethod
182+
def get_backend_serialization():
183+
return 'numpy'
184+
177185
@staticmethod
178186
def to_numpy(array):
179187
return array
@@ -303,6 +311,10 @@ class TorchBackend(ArrayBackend):
303311
def get_backend_name():
304312
return 'torch'
305313

314+
@staticmethod
315+
def get_backend_serialization():
316+
return 'torch'
317+
306318
@staticmethod
307319
def to_numpy(array):
308320
return None if array is None else array.detach().cpu().numpy()
@@ -438,6 +450,10 @@ class ListBackend(ArrayBackend):
438450
def get_backend_name():
439451
return 'list'
440452

453+
@staticmethod
454+
def get_backend_serialization():
455+
return 'numpy'
456+
441457
@staticmethod
442458
def to_numpy(array):
443459
return np.array(array)

mushroom_rl/rl_utils/preprocessors.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import numpy as np
2-
31
from mushroom_rl.core import Serializable, ArrayBackend
42
from mushroom_rl.rl_utils.running_stats import RunningStandardization
53

@@ -122,9 +120,9 @@ def __init__(self, mdp_info, backend, clip_obs=10., alpha=1e-32):
122120
self._add_save_attr(
123121
_array_backend='pickle',
124122
_run_norm_obs='primitive',
125-
_obs_mask='numpy',
126-
_obs_mean='numpy',
127-
_obs_delta='numpy'
123+
_obs_mask=self._array_backend.get_backend_serialization(),
124+
_obs_mean=self._array_backend.get_backend_serialization(),
125+
_obs_delta=self._array_backend.get_backend_serialization()
128126
)
129127

130128
def __call__(self, obs):

0 commit comments

Comments
 (0)