Skip to content

Commit b583ac1

Browse files
authored
[Refactoring] Replace direct gym version checks with decorated functions (#691)
* [Refactoring] Replace gym version checking with decorated functions (#) Initial commit. Only tests. * Refactoring in gym.py * More refactoring in gym.py * Completed refactoring * amend * amend
1 parent a3bbba0 commit b583ac1

File tree

7 files changed

+162
-154
lines changed

7 files changed

+162
-154
lines changed

test/_utils_internal.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,43 @@
1111
# this returns relative path from current file.
1212
import pytest
1313
import torch.cuda
14-
from torchrl._utils import seed_generator
14+
from torchrl._utils import seed_generator, implement_for
1515
from torchrl.envs import EnvBase
16-
16+
from torchrl.envs.libs.gym import _has_gym
1717

1818
# Specified for test_utils.py
1919
__version__ = "0.3"
2020

21+
# Default versions of the environments.
22+
CARTPOLE_VERSIONED = "CartPole-v1"
23+
HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
24+
PENDULUM_VERSIONED = "Pendulum-v1"
25+
PONG_VERSIONED = "ALE/Pong-v5"
26+
27+
28+
@implement_for("gym", None, "0.21.0")
29+
def _set_gym_environments(): # noqa: F811
30+
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED
31+
32+
CARTPOLE_VERSIONED = "CartPole-v0"
33+
HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
34+
PENDULUM_VERSIONED = "Pendulum-v0"
35+
PONG_VERSIONED = "Pong-v4"
36+
37+
38+
@implement_for("gym", "0.21.0", None)
39+
def _set_gym_environments(): # noqa: F811
40+
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED
41+
42+
CARTPOLE_VERSIONED = "CartPole-v1"
43+
HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
44+
PENDULUM_VERSIONED = "Pendulum-v1"
45+
PONG_VERSIONED = "ALE/Pong-v5"
46+
47+
48+
if _has_gym:
49+
_set_gym_environments()
50+
2151

2252
def get_relative_path(curr_file, *path_components):
2353
return os.path.join(os.path.dirname(curr_file), *path_components)

test/smoke_test_deps.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,10 @@
22
import tempfile
33

44
import pytest
5+
from _utils_internal import PONG_VERSIONED
56
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
67
from torchrl.envs.libs.gym import _has_gym, GymEnv
78

8-
if _has_gym:
9-
import gym
10-
from packaging import version
11-
12-
gym_version = version.parse(gym.__version__)
13-
PONG_VERSIONED = (
14-
"ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4"
15-
)
16-
else:
17-
# placeholders
18-
PONG_VERSIONED = "ALE/Pong-v5"
19-
209
try:
2110
from torch.utils.tensorboard import SummaryWriter
2211

test/test_collector.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import pytest
1010
import torch
11-
from _utils_internal import generate_seeds
11+
from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED
1212
from mocking_classes import (
1313
ContinuousActionVecMockEnv,
1414
DiscreteActionConvMockEnv,
@@ -42,22 +42,6 @@
4242
TensorDictModule,
4343
)
4444

45-
if _has_gym:
46-
import gym
47-
from packaging import version
48-
49-
gym_version = version.parse(gym.__version__)
50-
PENDULUM_VERSIONED = (
51-
"Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0"
52-
)
53-
PONG_VERSIONED = (
54-
"ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4"
55-
)
56-
else:
57-
# placeholders
58-
PENDULUM_VERSIONED = "Pendulum-v1"
59-
PONG_VERSIONED = "ALE/Pong-v5"
60-
6145
# torch.set_default_dtype(torch.double)
6246

6347

test/test_env.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@
1111
import pytest
1212
import torch
1313
import yaml
14-
from _utils_internal import get_available_devices
14+
from _utils_internal import (
15+
get_available_devices,
16+
CARTPOLE_VERSIONED,
17+
PENDULUM_VERSIONED,
18+
PONG_VERSIONED,
19+
HALFCHEETAH_VERSIONED,
20+
)
1521
from mocking_classes import (
1622
ActionObsMergeLinear,
1723
DiscreteActionConvMockEnv,
@@ -49,30 +55,11 @@
4955
)
5056
from torchrl.modules.tensordict_module import WorldModelWrapper
5157

58+
gym_version = None
5259
if _has_gym:
5360
import gym
5461

5562
gym_version = version.parse(gym.__version__)
56-
PENDULUM_VERSIONED = (
57-
"Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0"
58-
)
59-
CARTPOLE_VERSIONED = (
60-
"CartPole-v1" if gym_version > version.parse("0.20.0") else "CartPole-v0"
61-
)
62-
PONG_VERSIONED = (
63-
"ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4"
64-
)
65-
HALFCHEETAH_VERSIONED = (
66-
"HalfCheetah-v4" if gym_version > version.parse("0.20.0") else "HalfCheetah-v2"
67-
)
68-
else:
69-
# placeholder
70-
gym_version = version.parse("0.0.1")
71-
72-
# placeholders
73-
PENDULUM_VERSIONED = "Pendulum-v1"
74-
CARTPOLE_VERSIONED = "CartPole-v1"
75-
PONG_VERSIONED = "ALE/Pong-v5"
7663

7764
try:
7865
this_dir = os.path.dirname(os.path.realpath(__file__))
@@ -1048,7 +1035,7 @@ def test_batch_unlocked_with_batch_size(device):
10481035

10491036
@pytest.mark.skipif(not _has_gym, reason="no gym")
10501037
@pytest.mark.skipif(
1051-
gym_version < version.parse("0.20.0"),
1038+
gym_version is None or gym_version < version.parse("0.20.0"),
10521039
reason="older versions of half-cheetah do not have 'x_position' info key.",
10531040
)
10541041
def test_info_dict_reader(seed=0):

test/test_libs.py

Lines changed: 28 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,27 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55
import argparse
6+
from sys import platform
67

78
import numpy as np
89
import pytest
910
import torch
10-
from _utils_internal import _test_fake_tensordict
11-
from _utils_internal import get_available_devices
11+
from _utils_internal import (
12+
_test_fake_tensordict,
13+
get_available_devices,
14+
HALFCHEETAH_VERSIONED,
15+
PONG_VERSIONED,
16+
PENDULUM_VERSIONED,
17+
)
1218
from packaging import version
19+
from tensordict.tensordict import assert_allclose_td
20+
from torchrl._utils import implement_for
1321
from torchrl.collectors import MultiaSyncDataCollector
1422
from torchrl.collectors.collectors import RandomPolicy
23+
from torchrl.envs import EnvCreator, ParallelEnv
24+
from torchrl.envs.libs.dm_control import DMControlEnv, DMControlWrapper
1525
from torchrl.envs.libs.dm_control import _has_dmc
26+
from torchrl.envs.libs.gym import GymEnv, GymWrapper
1627
from torchrl.envs.libs.gym import _has_gym, _is_from_pixels
1728
from torchrl.envs.libs.habitat import HabitatEnv, _has_habitat
1829
from torchrl.envs.libs.jumanji import JumanjiEnv, _has_jumanji
@@ -32,39 +43,8 @@
3243
from dm_control import suite
3344
from dm_control.suite.wrappers import pixels
3445

35-
from sys import platform
36-
37-
from tensordict.tensordict import assert_allclose_td
38-
from torchrl.envs import EnvCreator, ParallelEnv
39-
from torchrl.envs.libs.dm_control import DMControlEnv, DMControlWrapper
40-
from torchrl.envs.libs.gym import GymEnv, GymWrapper
41-
4246
IS_OSX = platform == "darwin"
4347

44-
if _has_gym:
45-
from packaging import version
46-
47-
gym_version = version.parse(gym.__version__)
48-
PENDULUM_VERSIONED = (
49-
"Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0"
50-
)
51-
HC_VERSIONED = (
52-
"HalfCheetah-v4" if gym_version > version.parse("0.20.0") else "HalfCheetah-v2"
53-
)
54-
PONG_VERSIONED = (
55-
"ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4"
56-
)
57-
58-
# if gym_version < version.parse("0.24.0") and torch.cuda.device_count() > 0:
59-
# from opengl_rendering import create_opengl_context
60-
#
61-
# create_opengl_context()
62-
else:
63-
# placeholders
64-
PENDULUM_VERSIONED = "Pendulum-v1"
65-
HC_VERSIONED = "HalfCheetah-v4"
66-
PONG_VERSIONED = "ALE/Pong-v5"
67-
6848

6949
@pytest.mark.skipif(not _has_gym, reason="no gym library found")
7050
@pytest.mark.parametrize(
@@ -123,10 +103,7 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only):
123103
base_env = gym.make(env_name, frameskip=frame_skip)
124104
frame_skip = 1
125105
else:
126-
if gym_version < version.parse("0.26.0"):
127-
base_env = gym.make(env_name)
128-
else:
129-
base_env = gym.make(env_name, render_mode="rgb_array")
106+
base_env = _make_gym_environment(env_name)
130107

131108
if from_pixels and not _is_from_pixels(base_env):
132109
base_env = PixelObservationWrapper(base_env, pixels_only=pixels_only)
@@ -164,6 +141,16 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
164141
_test_fake_tensordict(env)
165142

166143

144+
@implement_for("gym", None, "0.26")
145+
def _make_gym_environment(env_name): # noqa: F811
146+
return gym.make(env_name)
147+
148+
149+
@implement_for("gym", "0.26", None)
150+
def _make_gym_environment(env_name): # noqa: F811
151+
return gym.make(env_name, render_mode="rgb_array")
152+
153+
167154
@pytest.mark.skipif(not _has_dmc, reason="no dm_control library found")
168155
@pytest.mark.parametrize("env_name,task", [["cheetah", "run"]])
169156
@pytest.mark.parametrize("frame_skip", [1, 3])
@@ -270,9 +257,9 @@ def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only):
270257
"env_lib,env_args,env_kwargs",
271258
[
272259
[DMControlEnv, ("cheetah", "run"), {"from_pixels": True}],
273-
[GymEnv, (HC_VERSIONED,), {"from_pixels": True}],
260+
[GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": True}],
274261
[DMControlEnv, ("cheetah", "run"), {"from_pixels": False}],
275-
[GymEnv, (HC_VERSIONED,), {"from_pixels": False}],
262+
[GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": False}],
276263
[GymEnv, (PONG_VERSIONED,), {}],
277264
],
278265
)
@@ -307,9 +294,9 @@ def test_td_creation_from_spec(env_lib, env_args, env_kwargs):
307294
"env_lib,env_args,env_kwargs",
308295
[
309296
[DMControlEnv, ("cheetah", "run"), {"from_pixels": True}],
310-
[GymEnv, (HC_VERSIONED,), {"from_pixels": True}],
297+
[GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": True}],
311298
[DMControlEnv, ("cheetah", "run"), {"from_pixels": False}],
312-
[GymEnv, (HC_VERSIONED,), {"from_pixels": False}],
299+
[GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": False}],
313300
[GymEnv, (PONG_VERSIONED,), {}],
314301
],
315302
)

test/test_transforms.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
import numpy as np
99
import pytest
1010
import torch
11-
from _utils_internal import get_available_devices, retry, dtype_fixture # noqa
11+
from _utils_internal import ( # noqa
12+
get_available_devices,
13+
retry,
14+
dtype_fixture,
15+
PENDULUM_VERSIONED,
16+
)
1217
from mocking_classes import (
1318
ContinuousActionVecMockEnv,
1419
DiscreteActionConvMockEnvNumpy,
@@ -59,18 +64,6 @@
5964
)
6065
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
6166

62-
if _has_gym:
63-
import gym
64-
from packaging import version
65-
66-
gym_version = version.parse(gym.__version__)
67-
PENDULUM_VERSIONED = (
68-
"Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0"
69-
)
70-
else:
71-
# placeholders
72-
PENDULUM_VERSIONED = "Pendulum-v1"
73-
7467
TIMEOUT = 10.0
7568

7669

0 commit comments

Comments
 (0)