4
4
# LICENSE file in the root directory of this source tree.
5
5
import warnings
6
6
from types import ModuleType
7
- from typing import List , Optional , Sequence , Dict
7
+ from typing import List , Dict
8
8
9
9
import torch
10
10
from packaging import version
50
50
gym_version = version .parse (gym .__version__ )
51
51
if gym_version >= version .parse ("0.26.0" ):
52
52
from gym .wrappers .compatibility import EnvCompatibility
53
- try :
54
- import retro
55
-
56
- _has_retro = True
57
- except ImportError :
58
- _has_retro = False
59
53
60
- __all__ = ["GymWrapper" , "GymEnv" , "RetroEnv" ]
54
+ __all__ = ["GymWrapper" , "GymEnv" ]
61
55
62
56
63
57
def _gym_to_torchrl_spec_transform (spec , dtype = None , device = "cpu" ) -> TensorSpec :
@@ -79,12 +73,15 @@ def _gym_to_torchrl_spec_transform(spec, dtype=None, device="cpu") -> TensorSpec
79
73
dtype = dtype ,
80
74
device = device ,
81
75
)
82
- elif isinstance (spec , (dict , gym .spaces .dict .Dict )):
83
- spec = {
84
- "next_" + k : _gym_to_torchrl_spec_transform (spec [k ], device = device )
85
- for k in spec
86
- }
87
- return CompositeSpec (** spec )
76
+ elif isinstance (spec , (Dict ,)):
77
+ spec_out = {}
78
+ for k in spec .keys ():
79
+ spec_out ["next_" + k ] = _gym_to_torchrl_spec_transform (
80
+ spec [k ], device = device
81
+ )
82
+ return CompositeSpec (** spec_out )
83
+ elif isinstance (spec , gym .spaces .dict .Dict ):
84
+ return _gym_to_torchrl_spec_transform (spec .spaces , device = device )
88
85
else :
89
86
raise NotImplementedError (
90
87
f"spec of type { type (spec ).__name__ } is currently unaccounted for"
@@ -111,9 +108,12 @@ def _get_gym():
111
108
112
109
def _is_from_pixels (env ):
113
110
observation_spec = env .observation_space
114
- if isinstance (observation_spec , (Dict , gym . spaces . dict . Dict )):
111
+ if isinstance (observation_spec , (Dict ,)):
115
112
if "pixels" in set (observation_spec .keys ()):
116
113
return True
114
+ if isinstance (observation_spec , (gym .spaces .dict .Dict ,)):
115
+ if "pixels" in set (observation_spec .spaces .keys ()):
116
+ return True
117
117
elif (
118
118
isinstance (observation_spec , gym .spaces .Box )
119
119
and (observation_spec .low == 0 ).all ()
@@ -337,23 +337,3 @@ def _check_kwargs(self, kwargs: Dict):
337
337
338
338
def __repr__ (self ) -> str :
339
339
return f"{ self .__class__ .__name__ } (env={ self .env_name } , batch_size={ self .batch_size } , device={ self .device } )"
340
-
341
-
342
- def _get_retro_envs () -> Sequence :
343
- if not _has_retro :
344
- return tuple ()
345
- else :
346
- return retro .data .list_games ()
347
-
348
-
349
- def _get_retro () -> Optional [ModuleType ]:
350
- if _has_retro :
351
- return retro
352
- else :
353
- return None
354
-
355
-
356
- class RetroEnv (GymEnv ):
357
- available_envs = _get_retro_envs ()
358
- lib = "retro"
359
- lib = _get_retro ()
0 commit comments