@@ -9,7 +9,7 @@
 import warnings
 from copy import deepcopy
 from functools import partial, wraps
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Callable, Iterator

 import numpy as np
 import torch
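A reviewer's note on the typing changes that run through this diff: builtin generics (`list[NestedKey]`, `dict[str, Any]`) and PEP 604 unions (`torch.Size | None`) only parse as runtime annotations on every supported interpreter if annotation evaluation is postponed. A minimal sketch of the assumption the patch rests on (torchrl modules generally carry this import already; the snippet is illustrative, not part of the patch):

```python
# Postponed annotation evaluation (PEP 563): with this future import,
# annotations are stored as strings, so `int | None` and `list[str]` parse
# even on Python 3.8/3.9, and forward references like `gym.Env` need no quotes.
from __future__ import annotations


def next_seed(seed: int | None = None) -> int | None:
    # A trivial function whose annotations would raise at import time on
    # Python < 3.10 without the future import above.
    return seed + 1 if seed is not None else None
```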
@@ -457,7 +457,7 @@ def __init__(
         self,
         *,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         run_type_checks: bool = False,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
@@ -568,10 +568,10 @@ def auto_specs_(
         policy: Callable[[TensorDictBase], TensorDictBase],
         *,
         tensordict: TensorDictBase | None = None,
-        action_key: NestedKey | List[NestedKey] = "action",
-        done_key: NestedKey | List[NestedKey] | None = None,
-        observation_key: NestedKey | List[NestedKey] = "observation",
-        reward_key: NestedKey | List[NestedKey] = "reward",
+        action_key: NestedKey | list[NestedKey] = "action",
+        done_key: NestedKey | list[NestedKey] | None = None,
+        observation_key: NestedKey | list[NestedKey] = "observation",
+        reward_key: NestedKey | list[NestedKey] = "reward",
     ):
         """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.

@@ -673,7 +673,7 @@ def auto_specs_(
         if full_action_spec is not None:
             self.full_action_spec = full_action_spec
         if full_done_spec is not None:
-            self.full_done_specs = full_done_spec
+            self.full_done_spec = full_done_spec
         if full_observation_spec is not None:
             self.full_observation_spec = full_observation_spec
         if full_reward_spec is not None:
@@ -685,8 +685,7 @@ def auto_specs_(

     @wraps(check_env_specs_func)
     def check_env_specs(self, *args, **kwargs):
-        return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
-        kwargs["return_contiguous"] = return_contiguous
+        kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
         return check_env_specs_func(self, *args, **kwargs)

     check_env_specs.__doc__ = check_env_specs_func.__doc__
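The `setdefault` rewrite in this hunk is behavior-preserving: popping a key with a fallback and immediately writing it back is exactly what `dict.setdefault` does in a single call. A self-contained sketch of the equivalence:

```python
# Both branches leave the dict in the same state when the key is absent,
# and both leave an existing value untouched.
old = {"device": "cpu"}
value = old.pop("return_contiguous", True)  # pop with a default...
old["return_contiguous"] = value            # ...then write it back

new = {"device": "cpu"}
new.setdefault("return_contiguous", True)   # one call, same effect

assert old == new == {"device": "cpu", "return_contiguous": True}
```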
@@ -831,8 +830,7 @@ def ndim(self):

     def append_transform(
         self,
-        transform: "Transform"  # noqa: F821
-        | Callable[[TensorDictBase], TensorDictBase],
+        transform: Transform | Callable[[TensorDictBase], TensorDictBase],  # noqa: F821
     ) -> EnvBase:
         """Returns a transformed environment where the callable/transform passed is applied.

@@ -976,7 +974,7 @@ def output_spec(self, value: TensorSpec) -> None:

     @property
     @_cache_value
-    def action_keys(self) -> List[NestedKey]:
+    def action_keys(self) -> list[NestedKey]:
         """The action keys of an environment.

         By default, there will only be one key named "action".
@@ -989,7 +987,7 @@ def action_keys(self) -> List[NestedKey]:

     @property
     @_cache_value
-    def state_keys(self) -> List[NestedKey]:
+    def state_keys(self) -> list[NestedKey]:
         """The state keys of an environment.

         By default, there will only be one key named "state".
@@ -1186,7 +1184,7 @@ def full_action_spec(self, spec: Composite) -> None:
     # Reward spec
     @property
     @_cache_value
-    def reward_keys(self) -> List[NestedKey]:
+    def reward_keys(self) -> list[NestedKey]:
         """The reward keys of an environment.

         By default, there will only be one key named "reward".
@@ -1196,6 +1194,20 @@ def reward_keys(self) -> List[NestedKey]:
         reward_keys = sorted(self.full_reward_spec.keys(True, True), key=_repr_by_depth)
         return reward_keys

+    @property
+    @_cache_value
+    def observation_keys(self) -> list[NestedKey]:
+        """The observation keys of an environment.
+
+        By default, there will only be one key named "observation".
+
+        Keys are sorted by depth in the data tree.
+        """
+        observation_keys = sorted(
+            self.full_observation_spec.keys(True, True), key=_repr_by_depth
+        )
+        return observation_keys
+
     @property
     def reward_key(self):
         """The reward key of an environment.
@@ -1383,7 +1395,7 @@ def full_reward_spec(self, spec: Composite) -> None:
     # done spec
     @property
     @_cache_value
-    def done_keys(self) -> List[NestedKey]:
+    def done_keys(self) -> list[NestedKey]:
         """The done keys of an environment.

         By default, there will only be one key named "done".
@@ -2113,8 +2125,8 @@ def register_gym(
         id: str,
         *,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         backend: str = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
@@ -2303,8 +2315,8 @@ def _register_gym(
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2345,8 +2357,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2393,8 +2405,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2446,8 +2458,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2502,8 +2514,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2560,8 +2572,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2618,7 +2630,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:

     def reset(
         self,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         **kwargs,
     ) -> TensorDictBase:
         """Resets the environment.
@@ -2727,8 +2739,8 @@ def numel(self) -> int:
         return prod(self.batch_size)

     def set_seed(
-        self, seed: Optional[int] = None, static_seed: bool = False
-    ) -> Optional[int]:
+        self, seed: int | None = None, static_seed: bool = False
+    ) -> int | None:
         """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).

         Args:
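The docstring's "returns the next seed to be used" contract is what makes seeding a collection of environments deterministic. A hedged sketch of the intended chaining pattern (the `envs` list is assumed for illustration):

```python
# Chain seeds across several environments: each call consumes one seed and
# returns the next one, so no two environments share a seed.
seed = 0
for env in envs:  # `envs`: an assumed list of EnvBase instances
    seed = env.set_seed(seed)
```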
@@ -2749,7 +2761,7 @@ def set_seed(
         return seed

     @abc.abstractmethod
-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         raise NotImplementedError

     def set_state(self):
@@ -2764,7 +2776,26 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
             f"got {tensordict.batch_size} and {self.batch_size}"
         )

-    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
+    def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
+        """Generates all possible actions from the action spec.
+
+        This only works in environments with fully discrete actions.
+
+        Args:
+            tensordict (TensorDictBase, optional): If given, :meth:`~.reset`
+                is called with this tensordict.
+
+        Returns:
+            a tensordict object with the "action" entry updated with a batch of
+            all possible actions. The actions are stacked together in the
+            leading dimension.
+        """
+        if tensordict is not None:
+            self.reset(tensordict)
+
+        return self.full_action_spec.enumerate(use_mask=True)
+
+    def rand_action(self, tensordict: TensorDictBase | None = None):
         """Performs a random action given the action_spec attribute.

         Args:
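Since `full_action_spec.enumerate(use_mask=True)` stacks every admissible action along a new leading dimension, the new `all_actions` method is a natural building block for exhaustive policies such as greedy evaluation over a discrete action set. A hedged usage sketch; the environment choice and the printed shape are illustrative assumptions:

```python
from torchrl.envs import GymEnv  # assumes torchrl with a gymnasium backend

env = GymEnv("CartPole-v1")  # fully discrete action spec, as required
actions = env.all_actions()
# One row per admissible action, stacked along the leading dimension;
# e.g. torch.Size([2, 2]) for CartPole's two one-hot encoded actions.
print(actions["action"].shape)
```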
@@ -2798,7 +2829,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
         tensordict.update(r)
         return tensordict

-    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
+    def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         """Performs a random step in the environment given the action_spec attribute.

         Args:
@@ -2834,15 +2865,15 @@ def _has_dynamic_specs(self) -> bool:
     def rollout(
         self,
         max_steps: int,
-        policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
+        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        callback: Callable[[TensorDictBase, ...], Any] | None = None,
         *,
         auto_reset: bool = True,
         auto_cast_to_device: bool = False,
         break_when_any_done: bool | None = None,
         break_when_all_done: bool | None = None,
         return_contiguous: bool | None = False,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         set_truncated: bool = False,
         out=None,
         trust_policy: bool = False,
@@ -3364,7 +3395,7 @@ def _rollout_nonstop(

     def step_and_maybe_reset(
         self, tensordict: TensorDictBase
-    ) -> Tuple[TensorDictBase, TensorDictBase]:
+    ) -> tuple[TensorDictBase, TensorDictBase]:
         """Runs a step in the environment and (partially) resets it if needed.

         Args:
@@ -3472,7 +3503,7 @@ def empty_cache(self):

     @property
     @_cache_value
-    def reset_keys(self) -> List[NestedKey]:
+    def reset_keys(self) -> list[NestedKey]:
         """Returns a list of reset keys.

         Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3629,14 +3660,14 @@ class _EnvWrapper(EnvBase):
     """

     git_url: str = ""
-    available_envs: Dict[str, Any] = {}
+    available_envs: dict[str, Any] = {}
     libname: str = ""

     def __init__(
         self,
         *args,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
         **kwargs,
@@ -3685,7 +3716,7 @@ def _sync_device(self):
         return sync_func

     @abc.abstractmethod
-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         raise NotImplementedError

     def __getattr__(self, attr: str) -> Any:
@@ -3711,7 +3742,7 @@ def __getattr__(self, attr: str) -> Any:
         )

     @abc.abstractmethod
-    def _init_env(self) -> Optional[int]:
+    def _init_env(self) -> int | None:
         """Runs all the necessary steps such that the environment is ready to use.

         This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3725,7 +3756,7 @@ def _init_env(self) -> Optional[int]:
         raise NotImplementedError

     @abc.abstractmethod
-    def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
+    def _build_env(self, **kwargs) -> gym.Env:  # noqa: F821
         """Creates an environment from the target library and stores it with the `_env` attribute.

         When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3734,7 +3765,7 @@ def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
         raise NotImplementedError

     @abc.abstractmethod
-    def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
+    def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
         raise NotImplementedError

     def close(self) -> None:
@@ -3748,7 +3779,7 @@ def close(self) -> None:

 def make_tensordict(
     env: _EnvWrapper,
-    policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
+    policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
 ) -> TensorDictBase:
     """Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
