Commit 5c3834d

Add infoclass (#153)
* Fix calculation of average Return
* Add required properties
* Add cloning of observation to prevent bug
* Add ExtraInfo
* Fix Bug in Multiprocess Environment
* Reset_all now returns real info dictionary
* Add assert to prevent None Environment
1 parent 33c8779 commit 5c3834d

File tree

10 files changed: +1001 −31 lines changed


examples/isaac_example.py

Lines changed: 4 additions & 4 deletions
@@ -79,8 +79,8 @@ def experiment(cfg_dict, headless, alg, n_epochs, n_steps, n_steps_per_fit, n_ep
 
     dataset = core.evaluate(n_episodes=n_episodes_test, render=False)
 
-    J = torch.mean(torch.stack(dataset.discounted_return))
-    R = torch.mean(torch.stack(dataset.undiscounted_return))
+    J = torch.mean(dataset.discounted_return)
+    R = torch.mean(dataset.undiscounted_return)
     E = agent.policy.entropy()
 
     logger.epoch_info(0, J=J, R=R, entropy=E)
@@ -89,8 +89,8 @@ def experiment(cfg_dict, headless, alg, n_epochs, n_steps, n_steps_per_fit, n_ep
         core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
         dataset = core.evaluate(n_episodes=n_episodes_test, render=False)
 
-        J = torch.mean(torch.stack(dataset.discounted_return))
-        R = torch.mean(torch.stack(dataset.undiscounted_return))
+        J = torch.mean(dataset.discounted_return)
+        R = torch.mean(dataset.undiscounted_return)
        E = agent.policy.entropy()
 
        logger.epoch_info(it+1, J=J, R=R, entropy=E)
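
As the hunks above suggest, the per-episode returns now come back as a single backend array rather than a list of scalar tensors, so the torch.stack call before averaging is dropped. A minimal, self-contained sketch of the difference, using made-up return values (not taken from the example):

import torch

# Old style: a list of per-episode return scalars had to be stacked first.
returns_list = [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]
J_old = torch.mean(torch.stack(returns_list))

# New style: a single tensor of per-episode returns can be averaged directly.
returns_tensor = torch.tensor([1.0, 2.0, 3.0])
J_new = torch.mean(returns_tensor)

assert torch.isclose(J_old, J_new)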

mushroom_rl/core/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -6,11 +6,13 @@
 from .serialization import Serializable
 from .logger import Logger
 
+from .extra_info import ExtraInfo
+
 from .vectorized_core import VectorCore
 from .vectorized_env import VectorizedEnvironment
 from .multiprocess_environment import MultiprocessEnvironment
 
 import mushroom_rl.environments
 
 __all__ = ['ArrayBackend', 'Core', 'DatasetInfo', 'Dataset', 'Environment', 'MDPInfo', 'Agent', 'AgentInfo',
-           'Serializable', 'Logger', 'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment']
+           'Serializable', 'Logger', 'ExtraInfo', 'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment']

mushroom_rl/core/array_backend.py

Lines changed: 85 additions & 1 deletion
@@ -147,6 +147,26 @@ def from_list(array):
     @staticmethod
     def pack_padded_sequence(array, mask):
         raise NotImplementedError
+
+    @staticmethod
+    def flatten(array):
+        raise NotImplementedError
+
+    @staticmethod
+    def empty(shape, device=None):
+        raise NotImplementedError
+
+    @staticmethod
+    def none():
+        raise NotImplementedError
+
+    @staticmethod
+    def shape(array):
+        raise NotImplementedError
+
+    @staticmethod
+    def full(shape, value):
+        raise NotImplementedError
 
 
 class NumpyBackend(ArrayBackend):
@@ -253,6 +273,28 @@ def pack_padded_sequence(array, mask):
 
         new_shape = (shape[0] * shape[1],) + shape[2:]
         return array.reshape(new_shape, order='F')[mask.flatten(order='F')]
+
+    @staticmethod
+    def flatten(array):
+        shape = array.shape
+        new_shape = (shape[0] * shape[1],) + shape[2:]
+        return array.reshape(new_shape, order='F')
+
+    @staticmethod
+    def empty(shape, device=None):
+        return np.empty(shape)
+
+    @staticmethod
+    def none():
+        return np.nan
+
+    @staticmethod
+    def shape(array):
+        return array.shape
+
+    @staticmethod
+    def full(shape, value):
+        return np.full(shape, value)
 
 
 class TorchBackend(ArrayBackend):
@@ -364,9 +406,31 @@ def pack_padded_sequence(array, mask):
         shape = array.shape
 
         new_shape = (shape[0]*shape[1], ) + shape[2:]
-
+
         return array.transpose(0, 1).reshape(new_shape)[mask.transpose(0, 1).flatten()]
 
+    @staticmethod
+    def flatten(array):
+        shape = array.shape
+        new_shape = (shape[0]*shape[1], ) + shape[2:]
+        return array.transpose(0, 1).reshape(new_shape)
+
+    @staticmethod
+    def empty(shape, device=None):
+        device = TorchUtils.get_device() if device is None else device
+        return torch.empty(shape, device=device)
+
+    @staticmethod
+    def none():
+        return torch.nan
+
+    @staticmethod
+    def shape(array):
+        return array.shape
+
+    @staticmethod
+    def full(shape, value):
+        return torch.full(shape, value)
 
 class ListBackend(ArrayBackend):
 
@@ -421,3 +485,23 @@ def from_list(array):
     @staticmethod
     def pack_padded_sequence(array, mask):
         return NumpyBackend.pack_padded_sequence(array, np.array(mask))
+
+    @staticmethod
+    def flatten(array):
+        return NumpyBackend.flatten(array)
+
+    @staticmethod
+    def empty(shape, device=None):
+        return np.empty(shape)
+
+    @staticmethod
+    def none():
+        return None
+
+    @staticmethod
+    def shape(array):
+        return np.array(array).shape
+
+    @staticmethod
+    def full(shape, value):
+        return np.full(shape, value)
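
The new flatten helpers collapse the leading (time, n_envs) axes of a batched array into a single axis. The NumPy backend uses a Fortran-order reshape, while the Torch backend transposes the first two axes and then reshapes; both produce the same ordering, with every step of environment 0 first, then environment 1, and so on. A small sketch with illustrative values only, checking that equivalence:

import numpy as np
import torch

T, n_envs = 3, 2  # (time steps, parallel environments)

# a[t, e] encodes "step t of environment e" as 10*e + t.
a_np = np.arange(T)[:, None] + 10 * np.arange(n_envs)[None, :]
a_torch = torch.as_tensor(a_np)

# NumpyBackend-style flatten: Fortran-order reshape of the leading two axes.
flat_np = a_np.reshape((T * n_envs,), order='F')

# TorchBackend-style flatten: swap the axes, then use a C-order reshape.
flat_torch = a_torch.transpose(0, 1).reshape((T * n_envs,))

# Both orderings list every step of env 0 first, then every step of env 1.
assert flat_np.tolist() == flat_torch.tolist() == [0, 1, 2, 10, 11, 12]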

mushroom_rl/core/core.py

Lines changed: 2 additions & 0 deletions
@@ -128,6 +128,8 @@ def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_stat
 
         self._end(record)
 
+        dataset.info.parse()
+        dataset.episode_info.parse()
         return dataset
 
     def _step(self, render, record):
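
The parse() calls run once data collection finishes. The actual ExtraInfo.parse implementation is not part of this diff; the sketch below is only a hypothetical stand-in for the pattern the call sites suggest, namely accumulating per-step info dicts as lists and converting them to arrays at the end of a run. All names here are illustrative, not the library's API:

from collections import defaultdict
import numpy as np

class InfoBuffer:
    """Hypothetical stand-in for an ExtraInfo-like container."""

    def __init__(self):
        self._raw = defaultdict(list)   # field name -> list of per-step values
        self._parsed = {}

    def append(self, step_info):
        # Accumulate one environment info dict per step.
        for key, value in step_info.items():
            self._raw[key].append(value)

    def parse(self):
        # Convert the accumulated lists into dense arrays for later slicing.
        self._parsed = {k: np.asarray(v) for k, v in self._raw.items()}

    def __getitem__(self, key):
        return self._parsed[key]

buffer = InfoBuffer()
buffer.append({'success': 0.0})
buffer.append({'success': 1.0})
buffer.parse()
assert buffer['success'].mean() == 0.5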

mushroom_rl/core/dataset.py

Lines changed: 24 additions & 21 deletions
@@ -6,6 +6,7 @@
 
 from mushroom_rl.core.serialization import Serializable
 from .array_backend import ArrayBackend
+from .extra_info import ExtraInfo
 
 from ._impl import *
 
@@ -103,8 +104,8 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):
         else:
             policy_state_shape = None
 
-        self._info = defaultdict(list)
-        self._episode_info = defaultdict(list)
+        self._info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device)
+        self._episode_info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device)
         self._theta_list = list()
 
         if dataset_info.backend == 'numpy':
@@ -195,12 +196,12 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
         dataset = cls.create_raw_instance()
 
         if info is None:
-            dataset._info = defaultdict(list)
+            dataset._info = ExtraInfo(1, backend)
         else:
             dataset._info = info.copy()
 
         if episode_info is None:
-            dataset._episode_info = defaultdict(list)
+            dataset._episode_info = ExtraInfo(1, backend)
         else:
             dataset._episode_info = episode_info.copy()
 
@@ -228,7 +229,7 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
 
     def append(self, step, info):
         self._data.append(*step)
-        self._append_info(self._info, info)
+        self._info.append(info)
 
     def append_episode_info(self, info):
         self._append_info(self._episode_info, info)
@@ -243,21 +244,17 @@ def get_info(self, field, index=None):
             return self._info[field][index]
 
     def clear(self):
-        self._episode_info = defaultdict(list)
+        self._episode_info.clear()
         self._theta_list = list()
-        self._info = defaultdict(list)
+        self._info.clear()
 
         self._data.clear()
 
     def get_view(self, index, copy=False):
         dataset = self.create_raw_instance(dataset=self)
 
-        info_slice = defaultdict(list)
-        for key in self._info.keys():
-            info_slice[key] = self._info[key][index]
-
-        dataset._info = info_slice
-        dataset._episode_info = defaultdict(list)
+        dataset._info = self._info.get_view(index, copy)
+        dataset._episode_info = self._episode_info.get_view(index, copy)
         dataset._data = self._data.get_view(index, copy)
 
         return dataset
@@ -276,11 +273,9 @@ def __getitem__(self, index):
 
     def __add__(self, other):
         result = self.create_raw_instance(dataset=self)
-        new_info = self._merge_info(self.info, other.info)
-        new_episode_info = self._merge_info(self.episode_info, other.episode_info)
 
-        result._info = new_info
-        result._episode_info = new_episode_info
+        result._info = self._info + other._info
+        result._episode_info = self._episode_info + other._episode_info
         result._theta_list = self._theta_list + other._theta_list
         result._data = self._data + other._data
 
@@ -525,8 +520,8 @@ def _convert(self, *arrays, to='numpy'):
 
     def _add_all_save_attr(self):
         self._add_save_attr(
-            _info='pickle',
-            _episode_info='pickle',
+            _info='mushroom',
+            _episode_info='mushroom',
             _theta_list='pickle',
             _data='mushroom',
             _array_backend='primitive',
@@ -557,7 +552,7 @@ def append(self, step, info):
 
     def append_vectorized(self, step, info, mask):
         self._data.append(*step, mask=mask)
-        self._append_info(self._info, {}) # FIXME: handle properly info
+        self._info.append(info)
 
     def append_theta_vectorized(self, theta, mask):
         for i in range(len(theta)):
@@ -581,11 +576,16 @@ def clear(self, n_steps_per_fit=None):
             mask.flatten()[n_extra_steps:] = False
             residual_data.mask = mask.reshape(original_shape)
 
+            residual_info = self._info.get_view(view_size, copy=True)
+            residual_episode_info = self._episode_info.get_view(view_size, copy=True)
+
         super().clear()
         self._initialize_theta_list(n_envs)
 
         if n_steps_per_fit is not None and residual_data is not None:
             self._data = residual_data
+            self._info = residual_info
+            self._episode_info = residual_episode_info
 
     def flatten(self, n_steps_per_fit=None):
         if len(self) == 0:
@@ -622,9 +622,12 @@ def flatten(self, n_steps_per_fit=None):
 
         flat_theta_list = self._flatten_theta_list()
 
+        flat_info = self._info.flatten(self.mask)
+        flat_episode_info = self._episode_info.flatten(self.mask)
+
         return Dataset.from_array(states, actions, rewards, next_states, absorbings, lasts,
                                   policy_state=policy_state, policy_next_state=policy_next_state,
-                                  info=None, episode_info=None, theta_list=flat_theta_list, # FIXME: handle properly info
+                                  info=flat_info, episode_info=flat_episode_info, theta_list=flat_theta_list,
                                   horizon=self._dataset_info.horizon, gamma=self._dataset_info.gamma,
                                   backend=self._array_backend.get_backend_name())