Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit 574004c

Browse files
authored
Merge pull request #68 from facebookresearch/samvelyan/seeding
Fixing the seeding issue
2 parents 2054e7f + 36b174f commit 574004c

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

minihack/base.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,14 @@ def __init__(
173173
environment as a dictionary. Defaults to
174174
``minihack.base.MH_DEFAULT_OBS_KEYS``.
175175
seeds (list or None):
176-
A list of random seeds for sampling episodes. If none, the
177-
entire level distribution is used. Defaults to None.
176+
A list of integers used as level seeds for sampling
177+
episodes. The reset()` function samples a seed from this list
178+
uniformly at random and uses it for setting the level.
179+
When the ``sample_seed`` argument of the reset function is
180+
set to False, a random level will not be sampled from this list
181+
during environment resetting.
182+
If None, the entire level distribution is used.
183+
Defaults to None.
178184
penalty_mode (str):
179185
The name of the mode for calculating the time step penalty.
180186
Can be ``constant``, ``exp``, ``square``, ``linear``, or
@@ -319,10 +325,10 @@ def _get_obs_space_dict(self, space_dict):
319325

320326
return obs_space_dict
321327

322-
def reset(self, *args, **kwargs):
328+
def reset(self, *args, sample_seed=True, **kwargs):
323329
if self.reward_manager is not None:
324330
self.reward_manager.reset()
325-
if self._level_seeds is not None:
331+
if sample_seed and self._level_seeds is not None:
326332
seed = random.choice(self._level_seeds)
327333
self.seed(seed, seed, reseed=False)
328334
return super().reset(*args, **kwargs)

0 commit comments

Comments
 (0)