diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 8be0d1c9a85..d7c25928f30 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -1569,15 +1569,10 @@ def _step_and_maybe_reset_no_buffers(
 
         results = [None] * len(workers_range)
 
-        consumed_indices = []
-        events = set(workers_range)
-        while len(consumed_indices) < len(workers_range):
-            for i in list(events):
-                if self._events[i].is_set():
-                    results[i] = self.parent_channels[i].recv()
-                    self._events[i].clear()
-                    consumed_indices.append(i)
-                    events.discard(i)
+        self._wait_for_workers(workers_range)
+
+        for i, w in enumerate(workers_range):
+            results[i] = self.parent_channels[w].recv()
 
         out_next, out_root = zip(*(future for future in results))
         out = TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack(
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 2518ceab742..c7377a84d9e 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -9,7 +9,7 @@
 import warnings
 from copy import deepcopy
 from functools import partial, wraps
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Callable, Iterator
 
 import numpy as np
 import torch
@@ -476,7 +476,7 @@ def __init__(
         self,
         *,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         run_type_checks: bool = False,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
@@ -587,10 +587,10 @@ def auto_specs_(
         policy: Callable[[TensorDictBase], TensorDictBase],
         *,
         tensordict: TensorDictBase | None = None,
-        action_key: NestedKey | List[NestedKey] = "action",
-        done_key: NestedKey | List[NestedKey] | None = None,
-        observation_key: NestedKey | List[NestedKey] = "observation",
-        reward_key: NestedKey | List[NestedKey] = "reward",
+        action_key: NestedKey | list[NestedKey] = "action",
+        done_key: NestedKey | list[NestedKey] | None = None,
+        observation_key: NestedKey | list[NestedKey] = "observation",
+        reward_key: NestedKey | list[NestedKey] = "reward",
     ):
         """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
 
@@ -692,7 +692,7 @@ def auto_specs_(
         if full_action_spec is not None:
             self.full_action_spec = full_action_spec
         if full_done_spec is not None:
-            self.full_done_specs = full_done_spec
+            self.full_done_spec = full_done_spec
         if full_observation_spec is not None:
             self.full_observation_spec = full_observation_spec
         if full_reward_spec is not None:
@@ -704,8 +704,7 @@ def auto_specs_(
 
    @wraps(check_env_specs_func)
    def check_env_specs(self, *args, **kwargs):
-        return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
-        kwargs["return_contiguous"] = return_contiguous
+        kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
         return check_env_specs_func(self, *args, **kwargs)
 
     check_env_specs.__doc__ = check_env_specs_func.__doc__
@@ -850,8 +849,7 @@ def ndim(self):
 
     def append_transform(
         self,
-        transform: "Transform"  # noqa: F821
-        | Callable[[TensorDictBase], TensorDictBase],
+        transform: Transform | Callable[[TensorDictBase], TensorDictBase],  # noqa: F821
     ) -> EnvBase:
         """Returns a transformed environment where the callable/transform passed is applied.
 
@@ -995,7 +993,7 @@ def output_spec(self, value: TensorSpec) -> None:
 
     @property
     @_cache_value
-    def action_keys(self) -> List[NestedKey]:
+    def action_keys(self) -> list[NestedKey]:
         """The action keys of an environment.
 
         By default, there will only be one key named "action".
@@ -1008,7 +1006,7 @@ def action_keys(self) -> List[NestedKey]:
 
     @property
     @_cache_value
-    def state_keys(self) -> List[NestedKey]:
+    def state_keys(self) -> list[NestedKey]:
         """The state keys of an environment.
 
         By default, there will only be one key named "state".
@@ -1205,7 +1203,7 @@ def full_action_spec(self, spec: Composite) -> None:
     # Reward spec
     @property
     @_cache_value
-    def reward_keys(self) -> List[NestedKey]:
+    def reward_keys(self) -> list[NestedKey]:
         """The reward keys of an environment.
 
         By default, there will only be one key named "reward".
@@ -1217,7 +1215,7 @@ def reward_keys(self) -> List[NestedKey]:
 
     @property
     @_cache_value
-    def observation_keys(self) -> List[NestedKey]:
+    def observation_keys(self) -> list[NestedKey]:
         """The observation keys of an environment.
 
         By default, there will only be one key named "observation".
@@ -1416,7 +1414,7 @@ def full_reward_spec(self, spec: Composite) -> None:
     # done spec
     @property
     @_cache_value
-    def done_keys(self) -> List[NestedKey]:
+    def done_keys(self) -> list[NestedKey]:
         """The done keys of an environment.
 
         By default, there will only be one key named "done".
@@ -2205,8 +2203,8 @@ def register_gym(
         id: str,
         *,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         backend: str = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
@@ -2395,8 +2393,8 @@ def _register_gym(
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2437,8 +2435,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2485,8 +2483,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2538,8 +2536,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2594,8 +2592,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2652,8 +2650,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2710,7 +2708,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
 
     def reset(
         self,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         **kwargs,
     ) -> TensorDictBase:
         """Resets the environment.
@@ -2819,8 +2817,8 @@ def numel(self) -> int:
         return prod(self.batch_size)
 
     def set_seed(
-        self, seed: Optional[int] = None, static_seed: bool = False
-    ) -> Optional[int]:
+        self, seed: int | None = None, static_seed: bool = False
+    ) -> int | None:
         """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
 
         Args:
@@ -2841,7 +2839,7 @@ def set_seed(
         return seed
 
     @abc.abstractmethod
-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         raise NotImplementedError
 
     def set_state(self):
@@ -2856,9 +2854,7 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
                 f"got {tensordict.batch_size} and {self.batch_size}"
             )
 
-    def all_actions(
-        self, tensordict: Optional[TensorDictBase] = None
-    ) -> TensorDictBase:
+    def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         """Generates all possible actions from the action spec.
 
         This only works in environments with fully discrete actions.
@@ -2877,7 +2873,7 @@ def all_actions(
 
         return self.full_action_spec.enumerate(use_mask=True)
 
-    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
+    def rand_action(self, tensordict: TensorDictBase | None = None):
         """Performs a random action given the action_spec attribute.
 
         Args:
@@ -2911,7 +2907,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
         tensordict.update(r)
         return tensordict
 
-    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
+    def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         """Performs a random step in the environment given the action_spec attribute.
 
         Args:
@@ -2947,15 +2943,15 @@ def _has_dynamic_specs(self) -> bool:
     def rollout(
         self,
         max_steps: int,
-        policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
+        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        callback: Callable[[TensorDictBase, ...], Any] | None = None,
         *,
         auto_reset: bool = True,
         auto_cast_to_device: bool = False,
         break_when_any_done: bool | None = None,
         break_when_all_done: bool | None = None,
         return_contiguous: bool | None = False,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         set_truncated: bool = False,
         out=None,
         trust_policy: bool = False,
@@ -3485,7 +3481,7 @@ def _rollout_nonstop(
 
     def step_and_maybe_reset(
         self, tensordict: TensorDictBase
-    ) -> Tuple[TensorDictBase, TensorDictBase]:
+    ) -> tuple[TensorDictBase, TensorDictBase]:
         """Runs a step in the environment and (partially) resets it if needed.
 
         Args:
@@ -3606,7 +3602,7 @@ def empty_cache(self):
 
     @property
     @_cache_value
-    def reset_keys(self) -> List[NestedKey]:
+    def reset_keys(self) -> list[NestedKey]:
         """Returns a list of reset keys.
 
         Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3763,14 +3759,14 @@ class _EnvWrapper(EnvBase):
     """
 
     git_url: str = ""
-    available_envs: Dict[str, Any] = {}
+    available_envs: dict[str, Any] = {}
     libname: str = ""
 
     def __init__(
         self,
         *args,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
         **kwargs,
@@ -3819,7 +3815,7 @@ def _sync_device(self):
         return sync_func
 
     @abc.abstractmethod
-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         raise NotImplementedError
 
     def __getattr__(self, attr: str) -> Any:
@@ -3845,7 +3841,7 @@ def __getattr__(self, attr: str) -> Any:
         )
 
     @abc.abstractmethod
-    def _init_env(self) -> Optional[int]:
+    def _init_env(self) -> int | None:
         """Runs all the necessary steps such that the environment is ready to use.
 
         This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3859,7 +3855,7 @@ def _init_env(self) -> Optional[int]:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
+    def _build_env(self, **kwargs) -> gym.Env:  # noqa: F821
         """Creates an environment from the target library and stores it with the `_env` attribute.
 
         When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3868,7 +3864,7 @@ def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
         raise NotImplementedError
 
     @abc.abstractmethod
-    def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
+    def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
         raise NotImplementedError
 
     def close(self, *, raise_if_closed: bool = True) -> None:
@@ -3882,7 +3878,7 @@ def close(self, *, raise_if_closed: bool = True) -> None:
 
 def make_tensordict(
     env: _EnvWrapper,
-    policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
+    policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
 ) -> TensorDictBase:
     """Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
 
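
Side note (not part of the patch): the batched_envs.py hunk replaces a busy-wait that polled each worker's event with a single blocking wait followed by ordered pipe reads. Below is a minimal standalone sketch of that pattern, assuming `_wait_for_workers` does little more than block on the per-worker events; the `worker` and `run_workers` names are hypothetical and this is not torchrl code.

# Sketch of the "wait once, then drain pipes in order" pattern adopted above.
import multiprocessing as mp


def worker(remote, event, value):
    # Stand-in for an env step: publish a result, then signal readiness.
    remote.send(value * 2)
    event.set()


def run_workers(values):
    parents, procs, events = [], [], []
    for v in values:
        parent, child = mp.Pipe()
        ev = mp.Event()
        p = mp.Process(target=worker, args=(child, ev, v))
        p.start()
        parents.append(parent)
        procs.append(p)
        events.append(ev)

    # Rough analogue of `self._wait_for_workers(workers_range)`: block on each
    # event once instead of spinning over `is_set()` in a polling loop.
    for ev in events:
        ev.wait()
        ev.clear()

    # With every worker done, the pipes can be read back in a fixed order.
    results = [parent.recv() for parent in parents]
    for p in procs:
        p.join()
    return results


if __name__ == "__main__":
    print(run_workers([1, 2, 3]))  # -> [2, 4, 6]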