File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -465,13 +465,22 @@ def _is_batched(self):
465
465
self ._env , tuple_of_classes + (gym_backend ("vector" ).VectorEnv ,)
466
466
)
467
467
468
+ @implement_for ("gym" , None , "0.27" )
468
469
def _get_batch_size (self , env ):
469
470
if hasattr (env , "num_envs" ):
470
471
batch_size = torch .Size ([env .num_envs , * self .batch_size ])
471
472
else :
472
473
batch_size = self .batch_size
473
474
return batch_size
474
475
476
+ @implement_for ("gymnasium" , "0.27" , None ) # gymnasium wants the unwrapped env
477
+ def _get_batch_size (self , env ): # noqa: F811
478
+ if hasattr (env , "num_envs" ):
479
+ batch_size = torch .Size ([env .unwrapped .num_envs , * self .batch_size ])
480
+ else :
481
+ batch_size = self .batch_size
482
+ return batch_size
483
+
475
484
def _check_kwargs (self , kwargs : Dict ):
476
485
if "env" not in kwargs :
477
486
raise TypeError ("Could not find environment key 'env' in kwargs." )
You can’t perform that action at this time.
0 commit comments