diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index f1bc3d3f386..b374a4594e8 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -14,7 +14,7 @@ import re import warnings from enum import Enum -from typing import Any, Dict, List +from typing import Any import torch @@ -329,9 +329,9 @@ def step_mdp( exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, - reward_keys: NestedKey | List[NestedKey] = "reward", - done_keys: NestedKey | List[NestedKey] = "done", - action_keys: NestedKey | List[NestedKey] = "action", + reward_keys: NestedKey | list[NestedKey] = "reward", + done_keys: NestedKey | list[NestedKey] = "done", + action_keys: NestedKey | list[NestedKey] = "action", ) -> TensorDictBase: """Creates a new tensordict that reflects a step in time of the input tensordict. @@ -680,8 +680,8 @@ def _per_level_env_check(data0, data1, check_dtype): def check_env_specs( - env, - return_contiguous=True, + env: torchrl.envs.EnvBase, # noqa + return_contiguous: bool | None = None, check_dtype=True, seed: int | None = None, tensordict: TensorDictBase | None = None, @@ -700,7 +700,7 @@ def check_env_specs( env (EnvBase): the env for which the specs have to be checked against data. return_contiguous (bool, optional): if ``True``, the random rollout will be called with return_contiguous=True. This will fail in some cases (e.g. heterogeneous shapes - of inputs/outputs). Defaults to True. + of inputs/outputs). Defaults to ``None`` (determined by the presence of dynamic specs). check_dtype (bool, optional): if False, dtype checks will be skipped. Defaults to True. seed (int, optional): for reproducibility, a seed can be set. @@ -718,6 +718,8 @@ def check_env_specs( of an experiment and as such should be kept out of training scripts. """ + if return_contiguous is None: + return_contiguous = not env._has_dynamic_specs if break_when_any_done == "both": check_env_specs( env, @@ -746,7 +748,7 @@ def check_env_specs( ) fake_tensordict = env.fake_tensordict() - if not env._batch_locked and tensordict is not None: + if not env.batch_locked and tensordict is not None: shape = torch.broadcast_shapes(fake_tensordict.shape, tensordict.shape) fake_tensordict = fake_tensordict.expand(shape) tensordict = tensordict.expand(shape) @@ -786,10 +788,13 @@ def check_env_specs( - List of keys present in fake but not in real: {fake_tensordict_keys-real_tensordict_keys}. """ ) - zeroing_err_msg = ( - "zeroing the two tensordicts did not make them identical. " - f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}" - ) + + def zeroing_err_msg(): + return ( + "zeroing the two tensordicts did not make them identical. " + f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}" + ) + from torchrl.envs.common import _has_dynamic_specs if _has_dynamic_specs(env.specs): @@ -799,7 +804,7 @@ def check_env_specs( ): fake = fake.apply(lambda x, y: x.expand_as(y), real) if (torch.zeros_like(real) != torch.zeros_like(fake)).any(): - raise AssertionError(zeroing_err_msg) + raise AssertionError(zeroing_err_msg()) # Checks shapes and eventually dtypes of keys at all nesting levels _per_level_env_check(fake, real, check_dtype=check_dtype) @@ -809,7 +814,7 @@ def check_env_specs( torch.zeros_like(fake_tensordict_select) != torch.zeros_like(real_tensordict_select) ).any(): - raise AssertionError(zeroing_err_msg) + raise AssertionError(zeroing_err_msg()) # Checks shapes and eventually dtypes of keys at all nesting levels _per_level_env_check( @@ -1028,14 +1033,14 @@ class MarlGroupMapType(Enum): ALL_IN_ONE_GROUP = 1 ONE_GROUP_PER_AGENT = 2 - def get_group_map(self, agent_names: List[str]): + def get_group_map(self, agent_names: list[str]): if self == MarlGroupMapType.ALL_IN_ONE_GROUP: return {"agents": agent_names} elif self == MarlGroupMapType.ONE_GROUP_PER_AGENT: return {agent_name: [agent_name] for agent_name in agent_names} -def check_marl_grouping(group_map: Dict[str, List[str]], agent_names: List[str]): +def check_marl_grouping(group_map: dict[str, list[str]], agent_names: list[str]): """Check MARL group map. Performs checks on the group map of a marl environment to assess its validity. @@ -1379,7 +1384,7 @@ def skim_through(td, reset=reset): def _update_during_reset( tensordict_reset: TensorDictBase, tensordict: TensorDictBase, - reset_keys: List[NestedKey], + reset_keys: list[NestedKey], ): """Updates the input tensordict with the reset data, based on the reset keys.""" if not reset_keys: