
Commit d556726

Author: Vincent Moens

[BugFix] Fix env.full_done_spec~s~

ghstack-source-id: ba0d371
Pull Request resolved: #2815
(cherry picked from commit f5c0666)

1 parent 2879a76 commit d556726

2 files changed (+82, -56 lines)


torchrl/envs/batched_envs.py

Lines changed: 4 additions & 9 deletions
@@ -1513,15 +1513,10 @@ def _step_and_maybe_reset_no_buffers(

         results = [None] * len(workers_range)

-        consumed_indices = []
-        events = set(workers_range)
-        while len(consumed_indices) < len(workers_range):
-            for i in list(events):
-                if self._events[i].is_set():
-                    results[i] = self.parent_channels[i].recv()
-                    self._events[i].clear()
-                    consumed_indices.append(i)
-                    events.discard(i)
+        self._wait_for_workers(workers_range)
+
+        for i, w in enumerate(workers_range):
+            results[i] = self.parent_channels[w].recv()

         out_next, out_root = zip(*(future for future in results))
         out = TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack(
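
Note: the removed loop busy-polled per-worker events and consumed results in completion order; the new code waits for all workers first and then drains the pipes in worker order. A minimal standalone sketch of that "wait, then drain" pattern follows; the names and signatures are illustrative assumptions, not TorchRL's actual helper.

def wait_for_workers(events, workers_range, timeout=None):
    # Block until every selected worker has signalled completion, clearing
    # each event so it can be reused on the next step.
    for i in workers_range:
        if not events[i].wait(timeout):
            raise TimeoutError(f"worker {i} did not report back")
        events[i].clear()

def drain_results(channels, events, workers_range):
    # Wait for all workers, then read their results in a stable order.
    wait_for_workers(events, workers_range)
    return [channels[w].recv() for w in workers_range]
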

torchrl/envs/common.py

Lines changed: 78 additions & 47 deletions
@@ -9,7 +9,7 @@
 import warnings
 from copy import deepcopy
 from functools import partial, wraps
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Callable, Iterator
 
 import numpy as np
 import torch
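
Note: the typing churn throughout this file replaces typing.List/Dict/Optional/Tuple with built-in generics and PEP 604 unions. As a hedged aside (not part of the diff), these spellings only evaluate cleanly at runtime on Python 3.10+, or on older interpreters when the module defers annotation evaluation, for example:

from __future__ import annotations  # defers evaluation of annotations

def f(x: int | None = None) -> list[str]:
    # Annotations above are never evaluated at definition time, so this
    # works even below Python 3.10.
    return [] if x is None else [str(x)]
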
@@ -457,7 +457,7 @@ def __init__(
         self,
         *,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         run_type_checks: bool = False,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
@@ -568,10 +568,10 @@ def auto_specs_(
         policy: Callable[[TensorDictBase], TensorDictBase],
         *,
         tensordict: TensorDictBase | None = None,
-        action_key: NestedKey | List[NestedKey] = "action",
-        done_key: NestedKey | List[NestedKey] | None = None,
-        observation_key: NestedKey | List[NestedKey] = "observation",
-        reward_key: NestedKey | List[NestedKey] = "reward",
+        action_key: NestedKey | list[NestedKey] = "action",
+        done_key: NestedKey | list[NestedKey] | None = None,
+        observation_key: NestedKey | list[NestedKey] = "observation",
+        reward_key: NestedKey | list[NestedKey] = "reward",
     ):
         """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
@@ -673,7 +673,7 @@ def auto_specs_(
         if full_action_spec is not None:
             self.full_action_spec = full_action_spec
         if full_done_spec is not None:
-            self.full_done_specs = full_done_spec
+            self.full_done_spec = full_done_spec
         if full_observation_spec is not None:
             self.full_observation_spec = full_observation_spec
         if full_reward_spec is not None:
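
Note: this one-character rename is the fix the commit title refers to. full_done_spec is a property with a setter, so assigning to the misspelled full_done_specs silently created an unrelated attribute and never updated the environment's done spec. A self-contained illustration of the failure mode, using a hypothetical class rather than TorchRL code:

class Env:
    def __init__(self):
        self._done_spec = None

    @property
    def full_done_spec(self):
        return self._done_spec

    @full_done_spec.setter
    def full_done_spec(self, spec):
        self._done_spec = spec

env = Env()
env.full_done_specs = "spec"        # typo: creates a new attribute, setter never runs
assert env.full_done_spec is None   # the real spec was never set
env.full_done_spec = "spec"         # corrected assignment goes through the setter
assert env.full_done_spec == "spec"
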
@@ -685,8 +685,7 @@ def auto_specs_(
 
     @wraps(check_env_specs_func)
     def check_env_specs(self, *args, **kwargs):
-        return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
-        kwargs["return_contiguous"] = return_contiguous
+        kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
         return check_env_specs_func(self, *args, **kwargs)
 
     check_env_specs.__doc__ = check_env_specs_func.__doc__
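
Note: dict.setdefault keeps a value the caller passed explicitly and only fills in the default otherwise, so the pop/re-assign pair collapses into one line with the same behaviour. For example:

kwargs = {}
kwargs.setdefault("return_contiguous", True)
assert kwargs == {"return_contiguous": True}     # default applied

kwargs = {"return_contiguous": False}
kwargs.setdefault("return_contiguous", True)
assert kwargs == {"return_contiguous": False}    # caller's value wins
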
@@ -831,8 +830,7 @@ def ndim(self):
 
     def append_transform(
         self,
-        transform: "Transform"  # noqa: F821
-        | Callable[[TensorDictBase], TensorDictBase],
+        transform: Transform | Callable[[TensorDictBase], TensorDictBase],  # noqa: F821
     ) -> EnvBase:
         """Returns a transformed environment where the callable/transform passed is applied.
@@ -976,7 +974,7 @@ def output_spec(self, value: TensorSpec) -> None:
 
     @property
     @_cache_value
-    def action_keys(self) -> List[NestedKey]:
+    def action_keys(self) -> list[NestedKey]:
         """The action keys of an environment.
 
         By default, there will only be one key named "action".
@@ -989,7 +987,7 @@ def action_keys(self) -> List[NestedKey]:
 
     @property
     @_cache_value
-    def state_keys(self) -> List[NestedKey]:
+    def state_keys(self) -> list[NestedKey]:
         """The state keys of an environment.
 
         By default, there will only be one key named "state".
@@ -1186,7 +1184,7 @@ def full_action_spec(self, spec: Composite) -> None:
     # Reward spec
     @property
     @_cache_value
-    def reward_keys(self) -> List[NestedKey]:
+    def reward_keys(self) -> list[NestedKey]:
         """The reward keys of an environment.
 
         By default, there will only be one key named "reward".
@@ -1196,6 +1194,20 @@ def reward_keys(self) -> List[NestedKey]:
         reward_keys = sorted(self.full_reward_spec.keys(True, True), key=_repr_by_depth)
         return reward_keys
 
+    @property
+    @_cache_value
+    def observation_keys(self) -> list[NestedKey]:
+        """The observation keys of an environment.
+
+        By default, there will only be one key named "observation".
+
+        Keys are sorted by depth in the data tree.
+        """
+        observation_keys = sorted(
+            self.full_observation_spec.keys(True, True), key=_repr_by_depth
+        )
+        return observation_keys
+
     @property
     def reward_key(self):
         """The reward key of an environment.
@@ -1383,7 +1395,7 @@ def full_reward_spec(self, spec: Composite) -> None:
     # done spec
     @property
     @_cache_value
-    def done_keys(self) -> List[NestedKey]:
+    def done_keys(self) -> list[NestedKey]:
         """The done keys of an environment.
 
         By default, there will only be one key named "done".
@@ -2113,8 +2125,8 @@ def register_gym(
         id: str,
         *,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         backend: str = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
@@ -2303,8 +2315,8 @@ def _register_gym(
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2345,8 +2357,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2393,8 +2405,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2446,8 +2458,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2502,8 +2514,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2560,8 +2572,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2618,7 +2630,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
 
     def reset(
         self,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         **kwargs,
     ) -> TensorDictBase:
         """Resets the environment.
@@ -2727,8 +2739,8 @@ def numel(self) -> int:
         return prod(self.batch_size)
 
     def set_seed(
-        self, seed: Optional[int] = None, static_seed: bool = False
-    ) -> Optional[int]:
+        self, seed: int | None = None, static_seed: bool = False
+    ) -> int | None:
         """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
 
         Args:
@@ -2749,7 +2761,7 @@ def set_seed(
         return seed
 
     @abc.abstractmethod
-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         raise NotImplementedError
 
     def set_state(self):
@@ -2764,7 +2776,26 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
                 f"got {tensordict.batch_size} and {self.batch_size}"
             )
 
-    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
+    def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
+        """Generates all possible actions from the action spec.
+
+        This only works in environments with fully discrete actions.
+
+        Args:
+            tensordict (TensorDictBase, optional): If given, :meth:`~.reset`
+                is called with this tensordict.
+
+        Returns:
+            a tensordict object with the "action" entry updated with a batch of
+            all possible actions. The actions are stacked together in the
+            leading dimension.
+        """
+        if tensordict is not None:
+            self.reset(tensordict)
+
+        return self.full_action_spec.enumerate(use_mask=True)
+
+    def rand_action(self, tensordict: TensorDictBase | None = None):
         """Performs a random action given the action_spec attribute.
 
         Args:
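
Note: a hedged usage sketch of the new all_actions method with a fully discrete action space (again assuming gymnasium and GymEnv are available); the exact tensor shape depends on the action spec:

from torchrl.envs import GymEnv

env = GymEnv("CartPole-v1")        # two discrete (one-hot) actions
env.reset()
actions = env.all_actions()        # all actions stacked along the leading dim
print(actions["action"].shape)     # e.g. torch.Size([2, 2])
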
@@ -2798,7 +2829,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
         tensordict.update(r)
         return tensordict
 
-    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
+    def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         """Performs a random step in the environment given the action_spec attribute.
 
         Args:
@@ -2834,15 +2865,15 @@ def _has_dynamic_specs(self) -> bool:
     def rollout(
         self,
         max_steps: int,
-        policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
+        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        callback: Callable[[TensorDictBase, ...], Any] | None = None,
         *,
         auto_reset: bool = True,
         auto_cast_to_device: bool = False,
         break_when_any_done: bool | None = None,
         break_when_all_done: bool | None = None,
         return_contiguous: bool | None = False,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         set_truncated: bool = False,
         out=None,
         trust_policy: bool = False,
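
Note: for reference, a hedged rollout call using the keyword-only options listed above (with policy=None, TorchRL samples random actions); the resulting trajectory length depends on when a done flag is hit:

from torchrl.envs import GymEnv

env = GymEnv("CartPole-v1")
td = env.rollout(max_steps=50, break_when_any_done=True)
print(td.shape)  # e.g. torch.Size([23])
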
@@ -3364,7 +3395,7 @@ def _rollout_nonstop(
 
     def step_and_maybe_reset(
         self, tensordict: TensorDictBase
-    ) -> Tuple[TensorDictBase, TensorDictBase]:
+    ) -> tuple[TensorDictBase, TensorDictBase]:
         """Runs a step in the environment and (partially) resets it if needed.
 
         Args:
@@ -3472,7 +3503,7 @@ def empty_cache(self):
 
     @property
     @_cache_value
-    def reset_keys(self) -> List[NestedKey]:
+    def reset_keys(self) -> list[NestedKey]:
         """Returns a list of reset keys.
 
         Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3629,14 +3660,14 @@ class _EnvWrapper(EnvBase):
     """
 
     git_url: str = ""
-    available_envs: Dict[str, Any] = {}
+    available_envs: dict[str, Any] = {}
     libname: str = ""
 
     def __init__(
         self,
         *args,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
         **kwargs,
@@ -3685,7 +3716,7 @@ def _sync_device(self):
         return sync_func
 
     @abc.abstractmethod
-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         raise NotImplementedError
 
     def __getattr__(self, attr: str) -> Any:
@@ -3711,7 +3742,7 @@ def __getattr__(self, attr: str) -> Any:
         )
 
     @abc.abstractmethod
-    def _init_env(self) -> Optional[int]:
+    def _init_env(self) -> int | None:
         """Runs all the necessary steps such that the environment is ready to use.
 
         This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3725,7 +3756,7 @@ def _init_env(self) -> Optional[int]:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
+    def _build_env(self, **kwargs) -> gym.Env:  # noqa: F821
         """Creates an environment from the target library and stores it with the `_env` attribute.
 
         When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3734,7 +3765,7 @@ def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
         raise NotImplementedError
 
     @abc.abstractmethod
-    def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
+    def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
         raise NotImplementedError
 
     def close(self) -> None:
@@ -3748,7 +3779,7 @@ def close(self) -> None:
 
 def make_tensordict(
     env: _EnvWrapper,
-    policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
+    policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
 ) -> TensorDictBase:
     """Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
