66
77from  mushroom_rl .core .serialization  import  Serializable 
88from  .array_backend  import  ArrayBackend 
9+ from  .extra_info  import  ExtraInfo 
910
1011from  ._impl  import  * 
1112
@@ -103,8 +104,8 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):
103104        else :
104105            policy_state_shape  =  None 
105106
106-         self ._info  =  defaultdict ( list )
107-         self ._episode_info  =  defaultdict ( list )
107+         self ._info  =  ExtraInfo ( dataset_info . n_envs ,  dataset_info . backend ,  dataset_info . device )
108+         self ._episode_info  =  ExtraInfo ( dataset_info . n_envs ,  dataset_info . backend ,  dataset_info . device )
108109        self ._theta_list  =  list ()
109110
110111        if  dataset_info .backend  ==  'numpy' :
@@ -195,12 +196,12 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
195196        dataset  =  cls .create_raw_instance ()
196197
197198        if  info  is  None :
198-             dataset ._info  =  defaultdict ( list )
199+             dataset ._info  =  ExtraInfo ( 1 ,  backend )
199200        else :
200201            dataset ._info  =  info .copy ()
201202
202203        if  episode_info  is  None :
203-             dataset ._episode_info  =  defaultdict ( list )
204+             dataset ._episode_info  =  ExtraInfo ( 1 ,  backend )
204205        else :
205206            dataset ._episode_info  =  episode_info .copy ()
206207
@@ -228,7 +229,7 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
228229
229230    def  append (self , step , info ):
230231        self ._data .append (* step )
231-         self ._append_info ( self . _info ,  info )
232+         self ._info . append ( info )
232233
233234    def  append_episode_info (self , info ):
234235        self ._append_info (self ._episode_info , info )
@@ -243,21 +244,17 @@ def get_info(self, field, index=None):
243244            return  self ._info [field ][index ]
244245
245246    def  clear (self ):
246-         self ._episode_info   =   defaultdict ( list )
247+         self ._episode_info . clear ( )
247248        self ._theta_list  =  list ()
248-         self ._info   =   defaultdict ( list )
249+         self ._info . clear ( )
249250
250251        self ._data .clear ()
251252
252253    def  get_view (self , index , copy = False ):
253254        dataset  =  self .create_raw_instance (dataset = self )
254255
255-         info_slice  =  defaultdict (list )
256-         for  key  in  self ._info .keys ():
257-             info_slice [key ] =  self ._info [key ][index ]
258- 
259-         dataset ._info  =  info_slice 
260-         dataset ._episode_info  =  defaultdict (list )
256+         dataset ._info  =  self ._info .get_view (index , copy )
257+         dataset ._episode_info  =  self ._episode_info .get_view (index , copy )
261258        dataset ._data  =  self ._data .get_view (index , copy )
262259
263260        return  dataset 
@@ -276,11 +273,9 @@ def __getitem__(self, index):
276273
277274    def  __add__ (self , other ):
278275        result  =  self .create_raw_instance (dataset = self )
279-         new_info  =  self ._merge_info (self .info , other .info )
280-         new_episode_info  =  self ._merge_info (self .episode_info , other .episode_info )
281276
282-         result ._info  =  new_info 
283-         result ._episode_info  =  new_episode_info 
277+         result ._info  =  self . _info   +   other . _info 
278+         result ._episode_info  =  self . _episode_info   +   other . _episode_info 
284279        result ._theta_list  =  self ._theta_list  +  other ._theta_list 
285280        result ._data  =  self ._data  +  other ._data 
286281
@@ -525,8 +520,8 @@ def _convert(self, *arrays, to='numpy'):
525520
526521    def  _add_all_save_attr (self ):
527522        self ._add_save_attr (
528-             _info = 'pickle ' ,
529-             _episode_info = 'pickle ' ,
523+             _info = 'mushroom ' ,
524+             _episode_info = 'mushroom ' ,
530525            _theta_list = 'pickle' ,
531526            _data = 'mushroom' ,
532527            _array_backend = 'primitive' ,
@@ -557,7 +552,7 @@ def append(self, step, info):
557552
558553    def  append_vectorized (self , step , info , mask ):
559554        self ._data .append (* step , mask = mask )
560-         self ._append_info ( self . _info , {})   # FIXME: handle properly  info
555+         self ._info . append ( info ) 
561556
562557    def  append_theta_vectorized (self , theta , mask ):
563558        for  i  in  range (len (theta )):
@@ -581,11 +576,16 @@ def clear(self, n_steps_per_fit=None):
581576                mask .flatten ()[n_extra_steps :] =  False 
582577                residual_data .mask  =  mask .reshape (original_shape )
583578
579+                 residual_info  =  self ._info .get_view (view_size , copy = True )
580+                 residual_episode_info  =  self ._episode_info .get_view (view_size , copy = True )
581+ 
584582        super ().clear ()
585583        self ._initialize_theta_list (n_envs )
586584
587585        if  n_steps_per_fit  is  not None  and  residual_data  is  not None :
588586            self ._data  =  residual_data 
587+             self ._info  =  residual_info 
588+             self ._episode_info  =  residual_episode_info 
589589
590590    def  flatten (self , n_steps_per_fit = None ):
591591        if  len (self ) ==  0 :
@@ -622,9 +622,12 @@ def flatten(self, n_steps_per_fit=None):
622622
623623        flat_theta_list  =  self ._flatten_theta_list ()
624624
625+         flat_info  =  self ._info .flatten (self .mask )
626+         flat_episode_info  =  self ._episode_info .flatten (self .mask )
627+ 
625628        return  Dataset .from_array (states , actions , rewards , next_states , absorbings , lasts ,
626629                                  policy_state = policy_state , policy_next_state = policy_next_state ,
627-                                   info = None , episode_info = None , theta_list = flat_theta_list ,   # FIXME: handle properly info 
630+                                   info = flat_info , episode_info = flat_episode_info , theta_list = flat_theta_list ,
628631                                  horizon = self ._dataset_info .horizon , gamma = self ._dataset_info .gamma ,
629632                                  backend = self ._array_backend .get_backend_name ())
630633
0 commit comments