@@ -372,9 +372,9 @@ def _check_kwargs(self, kwargs: Dict):
    def _init_env(self) -> Optional[int]:
        # Add info
        if self.parallel:
-            _, info_dict = self._reset_parallel()
+            _, info_dict = self._reset_parallel(seed=self.seed)
        else:
-            _, info_dict = self._reset_aec()
+            _, info_dict = self._reset_aec(seed=self.seed)

        for group, agents in self.group_map.items():
            info_specs = []
@@ -440,19 +440,20 @@ def _init_env(self) -> Optional[int]:
        self.cached_step_output_zero.update(self.output_spec["full_reward_spec"].zero())
        self.cached_step_output_zero.update(self.output_spec["full_done_spec"].zero())

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int):
        self.seed = seed
+        self.reset(seed=self.seed)

    def _reset(
        self, tensordict: Optional[TensorDictBase] = None, **kwargs
    ) -> TensorDictBase:

        if self.parallel:
            # This resets when any is done
-            observation_dict, info_dict = self._reset_parallel()
+            observation_dict, info_dict = self._reset_parallel(**kwargs)
        else:
            # This resets when all are done
-            observation_dict, info_dict = self._reset_aec(tensordict)
+            observation_dict, info_dict = self._reset_aec(tensordict, **kwargs)

        # We start with zeroed data and fill in the data for alive agents
        tensordict_out = self.cached_reset_output_zero.clone()
@@ -481,7 +482,7 @@ def _reset(

        return tensordict_out

-    def _reset_aec(self, tensordict=None) -> Tuple[Dict, Dict]:
+    def _reset_aec(self, tensordict=None, **kwargs) -> Tuple[Dict, Dict]:
        all_done = True
        if tensordict is not None:
            _resets = []
@@ -500,18 +501,16 @@ def _reset_aec(self, tensordict=None) -> Tuple[Dict, Dict]:
                    break

        if all_done:
-            self._env.reset(seed=self.seed)
+            self._env.reset(**kwargs)

        observation_dict = {
            agent: self._env.observe(agent) for agent in self.possible_agents
        }
        info_dict = self._env.infos
        return observation_dict, info_dict

-    def _reset_parallel(
-        self,
-    ) -> Tuple[Dict, Dict]:
-        return self._env.reset(seed=self.seed)
+    def _reset_parallel(self, **kwargs) -> Tuple[Dict, Dict]:
+        return self._env.reset(**kwargs)

    def _step(
        self,
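Taken together, the hunks change the seeding flow: `_set_seed` now stores the seed and immediately resets the wrapped PettingZoo env with it, while later resets forward whatever `seed` the caller passes through `**kwargs` instead of always reusing `self.seed`. Below is a minimal, self-contained sketch of that pattern; the `ToyWrapper` and `_ToyBackend` classes are illustrative stand-ins, not the TorchRL or PettingZoo implementation.

```python
from typing import Dict, Optional, Tuple


class _ToyBackend:
    """Stands in for the wrapped PettingZoo env; only reset(seed=...) matters here."""

    def __init__(self) -> None:
        self.last_seed: Optional[int] = None

    def reset(self, seed: Optional[int] = None) -> Tuple[Dict, Dict]:
        self.last_seed = seed
        # Parallel-style reset: return (observations, infos).
        return {"agent_0": 0.0}, {"agent_0": {}}


class ToyWrapper:
    """Hypothetical wrapper illustrating the patched seeding flow."""

    def __init__(self) -> None:
        self._env = _ToyBackend()
        self.seed: Optional[int] = None

    def _set_seed(self, seed: int) -> None:
        # Mirrors the new _set_seed: remember the seed and reset right away,
        # so the backend is actually re-seeded rather than only bookkept.
        self.seed = seed
        self.reset(seed=self.seed)

    def reset(self, **kwargs) -> Tuple[Dict, Dict]:
        # Mirrors the new _reset/_reset_parallel: forward the caller's kwargs
        # (which may or may not contain a seed) to the backend.
        return self._env.reset(**kwargs)


wrapper = ToyWrapper()
wrapper._set_seed(42)
assert wrapper._env.last_seed == 42  # seeding now reaches the backend immediately
wrapper.reset()                      # a plain reset no longer re-applies the stored seed
assert wrapper._env.last_seed is None
```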