Skip to content

Commit 7f42576

Browse files
author
Vincent Moens
authored
[Minor] Remove ya gymnasium deprecation warning in vectorized envs (#1573)
1 parent 434fe58 commit 7f42576

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

torchrl/envs/libs/gym.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,13 +465,22 @@ def _is_batched(self):
465465
self._env, tuple_of_classes + (gym_backend("vector").VectorEnv,)
466466
)
467467

468+
@implement_for("gym", None, "0.27")
468469
def _get_batch_size(self, env):
469470
if hasattr(env, "num_envs"):
470471
batch_size = torch.Size([env.num_envs, *self.batch_size])
471472
else:
472473
batch_size = self.batch_size
473474
return batch_size
474475

476+
@implement_for("gymnasium", "0.27", None) # gymnasium wants the unwrapped env
477+
def _get_batch_size(self, env): # noqa: F811
478+
if hasattr(env, "num_envs"):
479+
batch_size = torch.Size([env.unwrapped.num_envs, *self.batch_size])
480+
else:
481+
batch_size = self.batch_size
482+
return batch_size
483+
475484
def _check_kwargs(self, kwargs: Dict):
476485
if "env" not in kwargs:
477486
raise TypeError("Could not find environment key 'env' in kwargs.")

0 commit comments

Comments
 (0)