Skip to content

Commit f12b7cc

Browse files
[BugFix] PettingZoo seeding (#1554)
Signed-off-by: Matteo Bettini <matbet@meta.com>
1 parent 57f1220 commit f12b7cc

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

torchrl/envs/libs/pettingzoo.py

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -372,9 +372,9 @@ def _check_kwargs(self, kwargs: Dict):
372372
def _init_env(self) -> Optional[int]:
373373
# Add info
374374
if self.parallel:
375-
_, info_dict = self._reset_parallel()
375+
_, info_dict = self._reset_parallel(seed=self.seed)
376376
else:
377-
_, info_dict = self._reset_aec()
377+
_, info_dict = self._reset_aec(seed=self.seed)
378378

379379
for group, agents in self.group_map.items():
380380
info_specs = []
@@ -440,19 +440,20 @@ def _init_env(self) -> Optional[int]:
440440
self.cached_step_output_zero.update(self.output_spec["full_reward_spec"].zero())
441441
self.cached_step_output_zero.update(self.output_spec["full_done_spec"].zero())
442442

443-
def _set_seed(self, seed: Optional[int]):
443+
def _set_seed(self, seed: int):
444444
self.seed = seed
445+
self.reset(seed=self.seed)
445446

446447
def _reset(
447448
self, tensordict: Optional[TensorDictBase] = None, **kwargs
448449
) -> TensorDictBase:
449450

450451
if self.parallel:
451452
# This resets when any is done
452-
observation_dict, info_dict = self._reset_parallel()
453+
observation_dict, info_dict = self._reset_parallel(**kwargs)
453454
else:
454455
# This resets when all are done
455-
observation_dict, info_dict = self._reset_aec(tensordict)
456+
observation_dict, info_dict = self._reset_aec(tensordict, **kwargs)
456457

457458
# We start with zeroed data and fill in the data for alive agents
458459
tensordict_out = self.cached_reset_output_zero.clone()
@@ -481,7 +482,7 @@ def _reset(
481482

482483
return tensordict_out
483484

484-
def _reset_aec(self, tensordict=None) -> Tuple[Dict, Dict]:
485+
def _reset_aec(self, tensordict=None, **kwargs) -> Tuple[Dict, Dict]:
485486
all_done = True
486487
if tensordict is not None:
487488
_resets = []
@@ -500,18 +501,16 @@ def _reset_aec(self, tensordict=None) -> Tuple[Dict, Dict]:
500501
break
501502

502503
if all_done:
503-
self._env.reset(seed=self.seed)
504+
self._env.reset(**kwargs)
504505

505506
observation_dict = {
506507
agent: self._env.observe(agent) for agent in self.possible_agents
507508
}
508509
info_dict = self._env.infos
509510
return observation_dict, info_dict
510511

511-
def _reset_parallel(
512-
self,
513-
) -> Tuple[Dict, Dict]:
514-
return self._env.reset(seed=self.seed)
512+
def _reset_parallel(self, **kwargs) -> Tuple[Dict, Dict]:
513+
return self._env.reset(**kwargs)
515514

516515
def _step(
517516
self,

0 commit comments

Comments
 (0)