Skip to content

Commit c20a4f1

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 3854ea4 commit c20a4f1

File tree

1 file changed

+46
-50
lines changed

1 file changed

+46
-50
lines changed

torchrl/envs/common.py

Lines changed: 46 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ def __init__(
476476
self,
477477
*,
478478
device: DEVICE_TYPING = None,
479-
batch_size: Optional[torch.Size] = None,
479+
batch_size: torch.Size | None = None,
480480
run_type_checks: bool = False,
481481
allow_done_after_reset: bool = False,
482482
spec_locked: bool = True,
@@ -587,10 +587,10 @@ def auto_specs_(
587587
policy: Callable[[TensorDictBase], TensorDictBase],
588588
*,
589589
tensordict: TensorDictBase | None = None,
590-
action_key: NestedKey | List[NestedKey] = "action",
591-
done_key: NestedKey | List[NestedKey] | None = None,
592-
observation_key: NestedKey | List[NestedKey] = "observation",
593-
reward_key: NestedKey | List[NestedKey] = "reward",
590+
action_key: NestedKey | list[NestedKey] = "action",
591+
done_key: NestedKey | list[NestedKey] | None = None,
592+
observation_key: NestedKey | list[NestedKey] = "observation",
593+
reward_key: NestedKey | list[NestedKey] = "reward",
594594
):
595595
"""Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
596596
@@ -692,7 +692,7 @@ def auto_specs_(
692692
if full_action_spec is not None:
693693
self.full_action_spec = full_action_spec
694694
if full_done_spec is not None:
695-
self.full_done_specs = full_done_spec
695+
self.full_done_spec = full_done_spec
696696
if full_observation_spec is not None:
697697
self.full_observation_spec = full_observation_spec
698698
if full_reward_spec is not None:
@@ -704,8 +704,7 @@ def auto_specs_(
704704

705705
@wraps(check_env_specs_func)
706706
def check_env_specs(self, *args, **kwargs):
707-
return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
708-
kwargs["return_contiguous"] = return_contiguous
707+
kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
709708
return check_env_specs_func(self, *args, **kwargs)
710709

711710
check_env_specs.__doc__ = check_env_specs_func.__doc__
@@ -850,8 +849,7 @@ def ndim(self):
850849

851850
def append_transform(
852851
self,
853-
transform: "Transform" # noqa: F821
854-
| Callable[[TensorDictBase], TensorDictBase],
852+
transform: Transform | Callable[[TensorDictBase], TensorDictBase], # noqa: F821
855853
) -> EnvBase:
856854
"""Returns a transformed environment where the callable/transform passed is applied.
857855
@@ -995,7 +993,7 @@ def output_spec(self, value: TensorSpec) -> None:
995993

996994
@property
997995
@_cache_value
998-
def action_keys(self) -> List[NestedKey]:
996+
def action_keys(self) -> list[NestedKey]:
999997
"""The action keys of an environment.
1000998
1001999
By default, there will only be one key named "action".
@@ -1008,7 +1006,7 @@ def action_keys(self) -> List[NestedKey]:
10081006

10091007
@property
10101008
@_cache_value
1011-
def state_keys(self) -> List[NestedKey]:
1009+
def state_keys(self) -> list[NestedKey]:
10121010
"""The state keys of an environment.
10131011
10141012
By default, there will only be one key named "state".
@@ -1205,7 +1203,7 @@ def full_action_spec(self, spec: Composite) -> None:
12051203
# Reward spec
12061204
@property
12071205
@_cache_value
1208-
def reward_keys(self) -> List[NestedKey]:
1206+
def reward_keys(self) -> list[NestedKey]:
12091207
"""The reward keys of an environment.
12101208
12111209
By default, there will only be one key named "reward".
@@ -1217,7 +1215,7 @@ def reward_keys(self) -> List[NestedKey]:
12171215

12181216
@property
12191217
@_cache_value
1220-
def observation_keys(self) -> List[NestedKey]:
1218+
def observation_keys(self) -> list[NestedKey]:
12211219
"""The observation keys of an environment.
12221220
12231221
By default, there will only be one key named "observation".
@@ -1416,7 +1414,7 @@ def full_reward_spec(self, spec: Composite) -> None:
14161414
# done spec
14171415
@property
14181416
@_cache_value
1419-
def done_keys(self) -> List[NestedKey]:
1417+
def done_keys(self) -> list[NestedKey]:
14201418
"""The done keys of an environment.
14211419
14221420
By default, there will only be one key named "done".
@@ -2202,8 +2200,8 @@ def register_gym(
22022200
id: str,
22032201
*,
22042202
entry_point: Callable | None = None,
2205-
transform: "Transform" | None = None, # noqa: F821
2206-
info_keys: List[NestedKey] | None = None,
2203+
transform: Transform | None = None, # noqa: F821
2204+
info_keys: list[NestedKey] | None = None,
22072205
backend: str = None,
22082206
to_numpy: bool = False,
22092207
reward_threshold: float | None = None,
@@ -2392,8 +2390,8 @@ def _register_gym(
23922390
cls,
23932391
id,
23942392
entry_point: Callable | None = None,
2395-
transform: "Transform" | None = None, # noqa: F821
2396-
info_keys: List[NestedKey] | None = None,
2393+
transform: Transform | None = None, # noqa: F821
2394+
info_keys: list[NestedKey] | None = None,
23972395
to_numpy: bool = False,
23982396
reward_threshold: float | None = None,
23992397
nondeterministic: bool = False,
@@ -2434,8 +2432,8 @@ def _register_gym( # noqa: F811
24342432
cls,
24352433
id,
24362434
entry_point: Callable | None = None,
2437-
transform: "Transform" | None = None, # noqa: F821
2438-
info_keys: List[NestedKey] | None = None,
2435+
transform: Transform | None = None, # noqa: F821
2436+
info_keys: list[NestedKey] | None = None,
24392437
to_numpy: bool = False,
24402438
reward_threshold: float | None = None,
24412439
nondeterministic: bool = False,
@@ -2482,8 +2480,8 @@ def _register_gym( # noqa: F811
24822480
cls,
24832481
id,
24842482
entry_point: Callable | None = None,
2485-
transform: "Transform" | None = None, # noqa: F821
2486-
info_keys: List[NestedKey] | None = None,
2483+
transform: Transform | None = None, # noqa: F821
2484+
info_keys: list[NestedKey] | None = None,
24872485
to_numpy: bool = False,
24882486
reward_threshold: float | None = None,
24892487
nondeterministic: bool = False,
@@ -2535,8 +2533,8 @@ def _register_gym( # noqa: F811
25352533
cls,
25362534
id,
25372535
entry_point: Callable | None = None,
2538-
transform: "Transform" | None = None, # noqa: F821
2539-
info_keys: List[NestedKey] | None = None,
2536+
transform: Transform | None = None, # noqa: F821
2537+
info_keys: list[NestedKey] | None = None,
25402538
to_numpy: bool = False,
25412539
reward_threshold: float | None = None,
25422540
nondeterministic: bool = False,
@@ -2591,8 +2589,8 @@ def _register_gym( # noqa: F811
25912589
cls,
25922590
id,
25932591
entry_point: Callable | None = None,
2594-
transform: "Transform" | None = None, # noqa: F821
2595-
info_keys: List[NestedKey] | None = None,
2592+
transform: Transform | None = None, # noqa: F821
2593+
info_keys: list[NestedKey] | None = None,
25962594
to_numpy: bool = False,
25972595
reward_threshold: float | None = None,
25982596
nondeterministic: bool = False,
@@ -2649,8 +2647,8 @@ def _register_gym( # noqa: F811
26492647
cls,
26502648
id,
26512649
entry_point: Callable | None = None,
2652-
transform: "Transform" | None = None, # noqa: F821
2653-
info_keys: List[NestedKey] | None = None,
2650+
transform: Transform | None = None, # noqa: F821
2651+
info_keys: list[NestedKey] | None = None,
26542652
to_numpy: bool = False,
26552653
reward_threshold: float | None = None,
26562654
nondeterministic: bool = False,
@@ -2707,7 +2705,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
27072705

27082706
def reset(
27092707
self,
2710-
tensordict: Optional[TensorDictBase] = None,
2708+
tensordict: TensorDictBase | None = None,
27112709
**kwargs,
27122710
) -> TensorDictBase:
27132711
"""Resets the environment.
@@ -2816,8 +2814,8 @@ def numel(self) -> int:
28162814
return prod(self.batch_size)
28172815

28182816
def set_seed(
2819-
self, seed: Optional[int] = None, static_seed: bool = False
2820-
) -> Optional[int]:
2817+
self, seed: int | None = None, static_seed: bool = False
2818+
) -> int | None:
28212819
"""Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
28222820
28232821
Args:
@@ -2838,7 +2836,7 @@ def set_seed(
28382836
return seed
28392837

28402838
@abc.abstractmethod
2841-
def _set_seed(self, seed: Optional[int]):
2839+
def _set_seed(self, seed: int | None):
28422840
raise NotImplementedError
28432841

28442842
def set_state(self):
@@ -2853,9 +2851,7 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
28532851
f"got {tensordict.batch_size} and {self.batch_size}"
28542852
)
28552853

2856-
def all_actions(
2857-
self, tensordict: Optional[TensorDictBase] = None
2858-
) -> TensorDictBase:
2854+
def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
28592855
"""Generates all possible actions from the action spec.
28602856
28612857
This only works in environments with fully discrete actions.
@@ -2874,7 +2870,7 @@ def all_actions(
28742870

28752871
return self.full_action_spec.enumerate(use_mask=True)
28762872

2877-
def rand_action(self, tensordict: Optional[TensorDictBase] = None):
2873+
def rand_action(self, tensordict: TensorDictBase | None = None):
28782874
"""Performs a random action given the action_spec attribute.
28792875
28802876
Args:
@@ -2908,7 +2904,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
29082904
tensordict.update(r)
29092905
return tensordict
29102906

2911-
def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
2907+
def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
29122908
"""Performs a random step in the environment given the action_spec attribute.
29132909
29142910
Args:
@@ -2944,15 +2940,15 @@ def _has_dynamic_specs(self) -> bool:
29442940
def rollout(
29452941
self,
29462942
max_steps: int,
2947-
policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
2948-
callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
2943+
policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
2944+
callback: Callable[[TensorDictBase, ...], Any] | None = None,
29492945
*,
29502946
auto_reset: bool = True,
29512947
auto_cast_to_device: bool = False,
29522948
break_when_any_done: bool | None = None,
29532949
break_when_all_done: bool | None = None,
29542950
return_contiguous: bool | None = False,
2955-
tensordict: Optional[TensorDictBase] = None,
2951+
tensordict: TensorDictBase | None = None,
29562952
set_truncated: bool = False,
29572953
out=None,
29582954
trust_policy: bool = False,
@@ -3479,7 +3475,7 @@ def _rollout_nonstop(
34793475

34803476
def step_and_maybe_reset(
34813477
self, tensordict: TensorDictBase
3482-
) -> Tuple[TensorDictBase, TensorDictBase]:
3478+
) -> tuple[TensorDictBase, TensorDictBase]:
34833479
"""Runs a step in the environment and (partially) resets it if needed.
34843480
34853481
Args:
@@ -3600,7 +3596,7 @@ def empty_cache(self):
36003596

36013597
@property
36023598
@_cache_value
3603-
def reset_keys(self) -> List[NestedKey]:
3599+
def reset_keys(self) -> list[NestedKey]:
36043600
"""Returns a list of reset keys.
36053601
36063602
Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3757,14 +3753,14 @@ class _EnvWrapper(EnvBase):
37573753
"""
37583754

37593755
git_url: str = ""
3760-
available_envs: Dict[str, Any] = {}
3756+
available_envs: dict[str, Any] = {}
37613757
libname: str = ""
37623758

37633759
def __init__(
37643760
self,
37653761
*args,
37663762
device: DEVICE_TYPING = None,
3767-
batch_size: Optional[torch.Size] = None,
3763+
batch_size: torch.Size | None = None,
37683764
allow_done_after_reset: bool = False,
37693765
spec_locked: bool = True,
37703766
**kwargs,
@@ -3813,7 +3809,7 @@ def _sync_device(self):
38133809
return sync_func
38143810

38153811
@abc.abstractmethod
3816-
def _check_kwargs(self, kwargs: Dict):
3812+
def _check_kwargs(self, kwargs: dict):
38173813
raise NotImplementedError
38183814

38193815
def __getattr__(self, attr: str) -> Any:
@@ -3839,7 +3835,7 @@ def __getattr__(self, attr: str) -> Any:
38393835
)
38403836

38413837
@abc.abstractmethod
3842-
def _init_env(self) -> Optional[int]:
3838+
def _init_env(self) -> int | None:
38433839
"""Runs all the necessary steps such that the environment is ready to use.
38443840
38453841
This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3853,7 +3849,7 @@ def _init_env(self) -> Optional[int]:
38533849
raise NotImplementedError
38543850

38553851
@abc.abstractmethod
3856-
def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
3852+
def _build_env(self, **kwargs) -> gym.Env: # noqa: F821
38573853
"""Creates an environment from the target library and stores it with the `_env` attribute.
38583854
38593855
When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3862,7 +3858,7 @@ def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
38623858
raise NotImplementedError
38633859

38643860
@abc.abstractmethod
3865-
def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
3861+
def _make_specs(self, env: gym.Env) -> None: # noqa: F821
38663862
raise NotImplementedError
38673863

38683864
def close(self, *, raise_if_closed: bool = True) -> None:
@@ -3876,7 +3872,7 @@ def close(self, *, raise_if_closed: bool = True) -> None:
38763872

38773873
def make_tensordict(
38783874
env: _EnvWrapper,
3879-
policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
3875+
policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
38803876
) -> TensorDictBase:
38813877
"""Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
38823878

0 commit comments

Comments
 (0)