@@ -476,7 +476,7 @@ def __init__(
476
476
self ,
477
477
* ,
478
478
device : DEVICE_TYPING = None ,
479
- batch_size : Optional [ torch .Size ] = None ,
479
+ batch_size : torch .Size | None = None ,
480
480
run_type_checks : bool = False ,
481
481
allow_done_after_reset : bool = False ,
482
482
spec_locked : bool = True ,
@@ -587,10 +587,10 @@ def auto_specs_(
587
587
policy : Callable [[TensorDictBase ], TensorDictBase ],
588
588
* ,
589
589
tensordict : TensorDictBase | None = None ,
590
- action_key : NestedKey | List [NestedKey ] = "action" ,
591
- done_key : NestedKey | List [NestedKey ] | None = None ,
592
- observation_key : NestedKey | List [NestedKey ] = "observation" ,
593
- reward_key : NestedKey | List [NestedKey ] = "reward" ,
590
+ action_key : NestedKey | list [NestedKey ] = "action" ,
591
+ done_key : NestedKey | list [NestedKey ] | None = None ,
592
+ observation_key : NestedKey | list [NestedKey ] = "observation" ,
593
+ reward_key : NestedKey | list [NestedKey ] = "reward" ,
594
594
):
595
595
"""Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
596
596
@@ -692,7 +692,7 @@ def auto_specs_(
692
692
if full_action_spec is not None :
693
693
self .full_action_spec = full_action_spec
694
694
if full_done_spec is not None :
695
- self .full_done_specs = full_done_spec
695
+ self .full_done_spec = full_done_spec
696
696
if full_observation_spec is not None :
697
697
self .full_observation_spec = full_observation_spec
698
698
if full_reward_spec is not None :
@@ -704,8 +704,7 @@ def auto_specs_(
704
704
705
705
@wraps (check_env_specs_func )
706
706
def check_env_specs (self , * args , ** kwargs ):
707
- return_contiguous = kwargs .pop ("return_contiguous" , not self ._has_dynamic_specs )
708
- kwargs ["return_contiguous" ] = return_contiguous
707
+ kwargs .setdefault ("return_contiguous" , not self ._has_dynamic_specs )
709
708
return check_env_specs_func (self , * args , ** kwargs )
710
709
711
710
check_env_specs .__doc__ = check_env_specs_func .__doc__
@@ -850,8 +849,7 @@ def ndim(self):
850
849
851
850
def append_transform (
852
851
self ,
853
- transform : "Transform" # noqa: F821
854
- | Callable [[TensorDictBase ], TensorDictBase ],
852
+ transform : Transform | Callable [[TensorDictBase ], TensorDictBase ], # noqa: F821
855
853
) -> EnvBase :
856
854
"""Returns a transformed environment where the callable/transform passed is applied.
857
855
@@ -995,7 +993,7 @@ def output_spec(self, value: TensorSpec) -> None:
995
993
996
994
@property
997
995
@_cache_value
998
- def action_keys (self ) -> List [NestedKey ]:
996
+ def action_keys (self ) -> list [NestedKey ]:
999
997
"""The action keys of an environment.
1000
998
1001
999
By default, there will only be one key named "action".
@@ -1008,7 +1006,7 @@ def action_keys(self) -> List[NestedKey]:
1008
1006
1009
1007
@property
1010
1008
@_cache_value
1011
- def state_keys (self ) -> List [NestedKey ]:
1009
+ def state_keys (self ) -> list [NestedKey ]:
1012
1010
"""The state keys of an environment.
1013
1011
1014
1012
By default, there will only be one key named "state".
@@ -1205,7 +1203,7 @@ def full_action_spec(self, spec: Composite) -> None:
1205
1203
# Reward spec
1206
1204
@property
1207
1205
@_cache_value
1208
- def reward_keys (self ) -> List [NestedKey ]:
1206
+ def reward_keys (self ) -> list [NestedKey ]:
1209
1207
"""The reward keys of an environment.
1210
1208
1211
1209
By default, there will only be one key named "reward".
@@ -1217,7 +1215,7 @@ def reward_keys(self) -> List[NestedKey]:
1217
1215
1218
1216
@property
1219
1217
@_cache_value
1220
- def observation_keys (self ) -> List [NestedKey ]:
1218
+ def observation_keys (self ) -> list [NestedKey ]:
1221
1219
"""The observation keys of an environment.
1222
1220
1223
1221
By default, there will only be one key named "observation".
@@ -1416,7 +1414,7 @@ def full_reward_spec(self, spec: Composite) -> None:
1416
1414
# done spec
1417
1415
@property
1418
1416
@_cache_value
1419
- def done_keys (self ) -> List [NestedKey ]:
1417
+ def done_keys (self ) -> list [NestedKey ]:
1420
1418
"""The done keys of an environment.
1421
1419
1422
1420
By default, there will only be one key named "done".
@@ -2202,8 +2200,8 @@ def register_gym(
2202
2200
id : str ,
2203
2201
* ,
2204
2202
entry_point : Callable | None = None ,
2205
- transform : " Transform" | None = None , # noqa: F821
2206
- info_keys : List [NestedKey ] | None = None ,
2203
+ transform : Transform | None = None , # noqa: F821
2204
+ info_keys : list [NestedKey ] | None = None ,
2207
2205
backend : str = None ,
2208
2206
to_numpy : bool = False ,
2209
2207
reward_threshold : float | None = None ,
@@ -2392,8 +2390,8 @@ def _register_gym(
2392
2390
cls ,
2393
2391
id ,
2394
2392
entry_point : Callable | None = None ,
2395
- transform : " Transform" | None = None , # noqa: F821
2396
- info_keys : List [NestedKey ] | None = None ,
2393
+ transform : Transform | None = None , # noqa: F821
2394
+ info_keys : list [NestedKey ] | None = None ,
2397
2395
to_numpy : bool = False ,
2398
2396
reward_threshold : float | None = None ,
2399
2397
nondeterministic : bool = False ,
@@ -2434,8 +2432,8 @@ def _register_gym( # noqa: F811
2434
2432
cls ,
2435
2433
id ,
2436
2434
entry_point : Callable | None = None ,
2437
- transform : " Transform" | None = None , # noqa: F821
2438
- info_keys : List [NestedKey ] | None = None ,
2435
+ transform : Transform | None = None , # noqa: F821
2436
+ info_keys : list [NestedKey ] | None = None ,
2439
2437
to_numpy : bool = False ,
2440
2438
reward_threshold : float | None = None ,
2441
2439
nondeterministic : bool = False ,
@@ -2482,8 +2480,8 @@ def _register_gym( # noqa: F811
2482
2480
cls ,
2483
2481
id ,
2484
2482
entry_point : Callable | None = None ,
2485
- transform : " Transform" | None = None , # noqa: F821
2486
- info_keys : List [NestedKey ] | None = None ,
2483
+ transform : Transform | None = None , # noqa: F821
2484
+ info_keys : list [NestedKey ] | None = None ,
2487
2485
to_numpy : bool = False ,
2488
2486
reward_threshold : float | None = None ,
2489
2487
nondeterministic : bool = False ,
@@ -2535,8 +2533,8 @@ def _register_gym( # noqa: F811
2535
2533
cls ,
2536
2534
id ,
2537
2535
entry_point : Callable | None = None ,
2538
- transform : " Transform" | None = None , # noqa: F821
2539
- info_keys : List [NestedKey ] | None = None ,
2536
+ transform : Transform | None = None , # noqa: F821
2537
+ info_keys : list [NestedKey ] | None = None ,
2540
2538
to_numpy : bool = False ,
2541
2539
reward_threshold : float | None = None ,
2542
2540
nondeterministic : bool = False ,
@@ -2591,8 +2589,8 @@ def _register_gym( # noqa: F811
2591
2589
cls ,
2592
2590
id ,
2593
2591
entry_point : Callable | None = None ,
2594
- transform : " Transform" | None = None , # noqa: F821
2595
- info_keys : List [NestedKey ] | None = None ,
2592
+ transform : Transform | None = None , # noqa: F821
2593
+ info_keys : list [NestedKey ] | None = None ,
2596
2594
to_numpy : bool = False ,
2597
2595
reward_threshold : float | None = None ,
2598
2596
nondeterministic : bool = False ,
@@ -2649,8 +2647,8 @@ def _register_gym( # noqa: F811
2649
2647
cls ,
2650
2648
id ,
2651
2649
entry_point : Callable | None = None ,
2652
- transform : " Transform" | None = None , # noqa: F821
2653
- info_keys : List [NestedKey ] | None = None ,
2650
+ transform : Transform | None = None , # noqa: F821
2651
+ info_keys : list [NestedKey ] | None = None ,
2654
2652
to_numpy : bool = False ,
2655
2653
reward_threshold : float | None = None ,
2656
2654
nondeterministic : bool = False ,
@@ -2707,7 +2705,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
2707
2705
2708
2706
def reset (
2709
2707
self ,
2710
- tensordict : Optional [ TensorDictBase ] = None ,
2708
+ tensordict : TensorDictBase | None = None ,
2711
2709
** kwargs ,
2712
2710
) -> TensorDictBase :
2713
2711
"""Resets the environment.
@@ -2816,8 +2814,8 @@ def numel(self) -> int:
2816
2814
return prod (self .batch_size )
2817
2815
2818
2816
def set_seed (
2819
- self , seed : Optional [ int ] = None , static_seed : bool = False
2820
- ) -> Optional [ int ] :
2817
+ self , seed : int | None = None , static_seed : bool = False
2818
+ ) -> int | None :
2821
2819
"""Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
2822
2820
2823
2821
Args:
@@ -2838,7 +2836,7 @@ def set_seed(
2838
2836
return seed
2839
2837
2840
2838
@abc .abstractmethod
2841
- def _set_seed (self , seed : Optional [ int ] ):
2839
+ def _set_seed (self , seed : int | None ):
2842
2840
raise NotImplementedError
2843
2841
2844
2842
def set_state (self ):
@@ -2853,9 +2851,7 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
2853
2851
f"got { tensordict .batch_size } and { self .batch_size } "
2854
2852
)
2855
2853
2856
- def all_actions (
2857
- self , tensordict : Optional [TensorDictBase ] = None
2858
- ) -> TensorDictBase :
2854
+ def all_actions (self , tensordict : TensorDictBase | None = None ) -> TensorDictBase :
2859
2855
"""Generates all possible actions from the action spec.
2860
2856
2861
2857
This only works in environments with fully discrete actions.
@@ -2874,7 +2870,7 @@ def all_actions(
2874
2870
2875
2871
return self .full_action_spec .enumerate (use_mask = True )
2876
2872
2877
- def rand_action (self , tensordict : Optional [ TensorDictBase ] = None ):
2873
+ def rand_action (self , tensordict : TensorDictBase | None = None ):
2878
2874
"""Performs a random action given the action_spec attribute.
2879
2875
2880
2876
Args:
@@ -2908,7 +2904,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
2908
2904
tensordict .update (r )
2909
2905
return tensordict
2910
2906
2911
- def rand_step (self , tensordict : Optional [ TensorDictBase ] = None ) -> TensorDictBase :
2907
+ def rand_step (self , tensordict : TensorDictBase | None = None ) -> TensorDictBase :
2912
2908
"""Performs a random step in the environment given the action_spec attribute.
2913
2909
2914
2910
Args:
@@ -2944,15 +2940,15 @@ def _has_dynamic_specs(self) -> bool:
2944
2940
def rollout (
2945
2941
self ,
2946
2942
max_steps : int ,
2947
- policy : Optional [ Callable [[TensorDictBase ], TensorDictBase ]] = None ,
2948
- callback : Optional [ Callable [[TensorDictBase , ...], Any ]] = None ,
2943
+ policy : Callable [[TensorDictBase ], TensorDictBase ] | None = None ,
2944
+ callback : Callable [[TensorDictBase , ...], Any ] | None = None ,
2949
2945
* ,
2950
2946
auto_reset : bool = True ,
2951
2947
auto_cast_to_device : bool = False ,
2952
2948
break_when_any_done : bool | None = None ,
2953
2949
break_when_all_done : bool | None = None ,
2954
2950
return_contiguous : bool | None = False ,
2955
- tensordict : Optional [ TensorDictBase ] = None ,
2951
+ tensordict : TensorDictBase | None = None ,
2956
2952
set_truncated : bool = False ,
2957
2953
out = None ,
2958
2954
trust_policy : bool = False ,
@@ -3479,7 +3475,7 @@ def _rollout_nonstop(
3479
3475
3480
3476
def step_and_maybe_reset (
3481
3477
self , tensordict : TensorDictBase
3482
- ) -> Tuple [TensorDictBase , TensorDictBase ]:
3478
+ ) -> tuple [TensorDictBase , TensorDictBase ]:
3483
3479
"""Runs a step in the environment and (partially) resets it if needed.
3484
3480
3485
3481
Args:
@@ -3600,7 +3596,7 @@ def empty_cache(self):
3600
3596
3601
3597
@property
3602
3598
@_cache_value
3603
- def reset_keys (self ) -> List [NestedKey ]:
3599
+ def reset_keys (self ) -> list [NestedKey ]:
3604
3600
"""Returns a list of reset keys.
3605
3601
3606
3602
Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3757,14 +3753,14 @@ class _EnvWrapper(EnvBase):
3757
3753
"""
3758
3754
3759
3755
git_url : str = ""
3760
- available_envs : Dict [str , Any ] = {}
3756
+ available_envs : dict [str , Any ] = {}
3761
3757
libname : str = ""
3762
3758
3763
3759
def __init__ (
3764
3760
self ,
3765
3761
* args ,
3766
3762
device : DEVICE_TYPING = None ,
3767
- batch_size : Optional [ torch .Size ] = None ,
3763
+ batch_size : torch .Size | None = None ,
3768
3764
allow_done_after_reset : bool = False ,
3769
3765
spec_locked : bool = True ,
3770
3766
** kwargs ,
@@ -3813,7 +3809,7 @@ def _sync_device(self):
3813
3809
return sync_func
3814
3810
3815
3811
@abc .abstractmethod
3816
- def _check_kwargs (self , kwargs : Dict ):
3812
+ def _check_kwargs (self , kwargs : dict ):
3817
3813
raise NotImplementedError
3818
3814
3819
3815
def __getattr__ (self , attr : str ) -> Any :
@@ -3839,7 +3835,7 @@ def __getattr__(self, attr: str) -> Any:
3839
3835
)
3840
3836
3841
3837
@abc .abstractmethod
3842
- def _init_env (self ) -> Optional [ int ] :
3838
+ def _init_env (self ) -> int | None :
3843
3839
"""Runs all the necessary steps such that the environment is ready to use.
3844
3840
3845
3841
This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3853,7 +3849,7 @@ def _init_env(self) -> Optional[int]:
3853
3849
raise NotImplementedError
3854
3850
3855
3851
@abc .abstractmethod
3856
- def _build_env (self , ** kwargs ) -> " gym.Env" : # noqa: F821
3852
+ def _build_env (self , ** kwargs ) -> gym .Env : # noqa: F821
3857
3853
"""Creates an environment from the target library and stores it with the `_env` attribute.
3858
3854
3859
3855
When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3862,7 +3858,7 @@ def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
3862
3858
raise NotImplementedError
3863
3859
3864
3860
@abc .abstractmethod
3865
- def _make_specs (self , env : " gym.Env" ) -> None : # noqa: F821
3861
+ def _make_specs (self , env : gym .Env ) -> None : # noqa: F821
3866
3862
raise NotImplementedError
3867
3863
3868
3864
def close (self , * , raise_if_closed : bool = True ) -> None :
@@ -3876,7 +3872,7 @@ def close(self, *, raise_if_closed: bool = True) -> None:
3876
3872
3877
3873
def make_tensordict (
3878
3874
env : _EnvWrapper ,
3879
- policy : Optional [ Callable [[TensorDictBase , ...], TensorDictBase ]] = None ,
3875
+ policy : Callable [[TensorDictBase , ...], TensorDictBase ] | None = None ,
3880
3876
) -> TensorDictBase :
3881
3877
"""Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
3882
3878
0 commit comments