@@ -173,8 +173,14 @@ def __init__(
173
173
environment as a dictionary. Defaults to
174
174
``minihack.base.MH_DEFAULT_OBS_KEYS``.
175
175
seeds (list or None):
176
- A list of random seeds for sampling episodes. If none, the
177
- entire level distribution is used. Defaults to None.
176
+ A list of integers used as level seeds for sampling
177
+ episodes. The ``reset()`` function samples a seed from this list
178
+ uniformly at random and uses it for setting the level.
179
+ When the ``sample_seed`` argument of the reset function is
180
+ set to False, a random level will not be sampled from this list
181
+ during environment resetting.
182
+ If None, the entire level distribution is used.
183
+ Defaults to None.
178
184
penalty_mode (str):
179
185
The name of the mode for calculating the time step penalty.
180
186
Can be ``constant``, ``exp``, ``square``, ``linear``, or
@@ -319,10 +325,10 @@ def _get_obs_space_dict(self, space_dict):
319
325
320
326
return obs_space_dict
321
327
322
- def reset (self , * args , ** kwargs ):
328
+ def reset (self , * args , sample_seed = True , ** kwargs ):
323
329
if self .reward_manager is not None :
324
330
self .reward_manager .reset ()
325
- if self ._level_seeds is not None :
331
+ if sample_seed and self ._level_seeds is not None :
326
332
seed = random .choice (self ._level_seeds )
327
333
self .seed (seed , seed , reseed = False )
328
334
return super ().reset (* args , ** kwargs )
0 commit comments