Skip to content

Commit ee7a498

Browse files
authored
[Major, BugFix, Test] Refactor Transforms tests (#878)
1 parent fd4634f commit ee7a498

27 files changed

+6345
-2358
lines changed

.circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,45 +14,87 @@ conda activate ./env
1414
$DIR/install.sh
1515

1616
# Extracted from run_test.sh to run once.
17-
yum makecache && yum install libglvnd-devel mesa-libGL mesa-libGL-devel mesa-libEGL mesa-libEGL-devel glfw mesa-libOSMesa-devel glew glew-devel egl-utils freeglut xorg-x11-server-Xvfb -y
17+
#yum makecache && yum install libglvnd-devel mesa-libGL mesa-libGL-devel mesa-libEGL mesa-libEGL-devel glfw mesa-libOSMesa-devel glew glew-devel egl-utils freeglut xorg-x11-server-Xvfb -y
18+
yum makecache && yum install libglvnd-devel glew xorg-x11-server-Xvfb zlib-devel egl-utils mesa-libEGL -y
19+
1820

1921
# This version is installed initially (see environment.yml)
2022
for GYM_VERSION in '0.13'
2123
do
24+
# Create a copy of the conda env and work with this
25+
conda deactivate
26+
conda create --prefix ./cloned_env --clone ./env -y
27+
2228
echo "Testing gym version: ${GYM_VERSION}"
2329
pip3 install 'gym[atari]'==$GYM_VERSION
2430
$DIR/run_test.sh
31+
32+
# delete the conda copy
33+
conda deactivate
34+
conda env remove -n ./cloned_env
2535
done
2636

2737
# gym[atari]==0.19 is broken, so we install only gym without dependencies.
2838
for GYM_VERSION in '0.19'
2939
do
40+
# Create a copy of the conda env and work with this
41+
conda deactivate
42+
conda create --prefix ./cloned_env --clone ./env -y
43+
3044
echo "Testing gym version: ${GYM_VERSION}"
3145
pip3 install gym==$GYM_VERSION
3246
$DIR/run_test.sh
47+
48+
# delete the conda copy
49+
conda deactivate
50+
conda env remove -n ./cloned_env
3351
done
3452

3553
# gym[atari]==0.20 installs ale-py==0.8, but this version is not compatible with gym<0.26, so we downgrade it.
3654
for GYM_VERSION in '0.20'
3755
do
56+
# Create a copy of the conda env and work with this
57+
conda deactivate
58+
conda create --prefix ./cloned_env --clone ./env -y
59+
3860
echo "Testing gym version: ${GYM_VERSION}"
3961
pip3 install 'gym[atari]'==$GYM_VERSION
4062
pip3 install ale-py==0.7
4163
$DIR/run_test.sh
64+
65+
# delete the conda copy
66+
conda deactivate
67+
conda env remove -n ./cloned_env
4268
done
4369

4470
for GYM_VERSION in '0.25'
4571
do
72+
# Create a copy of the conda env and work with this
73+
conda deactivate
74+
conda create --prefix ./cloned_env --clone ./env -y
75+
4676
echo "Testing gym version: ${GYM_VERSION}"
4777
pip3 install 'gym[atari]'==$GYM_VERSION
4878
$DIR/run_test.sh
79+
80+
# delete the conda copy
81+
conda deactivate
82+
conda env remove -n ./cloned_env
4983
done
5084

5185
# For this version "gym[accept-rom-license]" is required.
5286
for GYM_VERSION in '0.26'
5387
do
88+
# Create a copy of the conda env and work with this
89+
conda deactivate
90+
conda create --prefix ./cloned_env --clone ./env -y
91+
5492
echo "Testing gym version: ${GYM_VERSION}"
5593
pip3 install 'gym[accept-rom-license]'==$GYM_VERSION
5694
pip3 install 'gym[atari]'==$GYM_VERSION
5795
$DIR/run_test.sh
96+
97+
# delete the conda copy
98+
conda deactivate
99+
conda env remove -n ./cloned_env
58100
done

.circleci/unittest/linux_libs/scripts_gym/run_test.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ set -e
55
eval "$(./conda/bin/conda shell.bash hook)"
66
conda activate ./env
77

8-
yum makecache && yum install libglvnd-devel mesa-libGL mesa-libGL-devel mesa-libEGL mesa-libEGL-devel glfw mesa-libOSMesa-devel glew glew-devel egl-utils freeglut xorg-x11-server-Xvfb -y
8+
#yum makecache && yum install libglvnd-devel mesa-libGL mesa-libGL-devel mesa-libEGL mesa-libEGL-devel glfw mesa-libOSMesa-devel glew glew-devel egl-utils freeglut xorg-x11-server-Xvfb -y
9+
#yum makecache && yum install glew -y
910

1011
export PYTORCH_TEST_WITH_SLOW='1'
1112
python -m torch.utils.collect_env
@@ -22,5 +23,6 @@ export MKL_THREADING_LAYER=GNU
2223

2324
coverage run -m pytest test/smoke_test.py -v --durations 20
2425
coverage run -m pytest test/smoke_test_deps.py -v --durations 20 -k 'test_gym or test_dm_control_pixels or test_dm_control'
26+
MUJOCO_GL=egl python test/test_libs.py -k test_collect
2527
MUJOCO_GL=egl coverage run -m pytest --instafail -v --durations 20 -k 'test_libs'
2628
coverage xml -i

.circleci/unittest/linux_libs/scripts_gym/setup_env.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ wget http://www.atarimania.com/roms/Roms.rar
9494
./rar/unrar e Roms.rar ./Roms -y
9595
python -m atari_py.import_roms Roms
9696

97-
yum makecache && yum install libglvnd-devel mesa-libGL mesa-libGL-devel mesa-libEGL glfw mesa-libOSMesa-devel glew egl-utils freeglut -y
97+
#yum makecache && yum install libglvnd-devel mesa-libGL mesa-libGL-devel mesa-libEGL glfw mesa-libOSMesa-devel glew egl-utils freeglut -y
98+
yum makecache && yum install libglvnd-devel glew zlib-devel -y
9899

99100
# install mujoco-py locally
100101
cd ${root_dir}/.mujoco/mujoco-py

examples/dreamer/dreamer_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def make_env_transforms(
8989
env.append_transform(Resize(cfg.image_size, cfg.image_size))
9090
if cfg.grayscale:
9191
env.append_transform(GrayScale())
92-
env.append_transform(FlattenObservation(0, -3))
92+
env.append_transform(FlattenObservation(0, -3, allow_positive_dim=True))
9393
env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"], dim=-3))
9494
if stats is None:
9595
obs_stats = {

test/_utils_internal.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,16 @@ def dtype_fixture():
120120
torch.set_default_dtype(dtype)
121121

122122

123+
@contextlib.contextmanager
124+
def set_global_var(module, var_name, value):
125+
old_value = getattr(module, var_name)
126+
setattr(module, var_name, value)
127+
try:
128+
yield
129+
finally:
130+
setattr(module, var_name, old_value)
131+
132+
123133
def _make_envs(
124134
env_name,
125135
frame_skip,
@@ -263,13 +273,3 @@ def t_out():
263273
)
264274

265275
return t_out
266-
267-
268-
@contextlib.contextmanager
269-
def set_global_var(module, var_name, value):
270-
old_value = getattr(module, var_name)
271-
setattr(module, var_name, value)
272-
try:
273-
yield
274-
finally:
275-
setattr(module, var_name, old_value)

test/mocking_classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def __new__(
360360
if categorical_action_encoding
361361
else OneHotDiscreteTensorSpec
362362
)
363-
action_spec = action_spec_cls(*batch_size, 7)
363+
action_spec = action_spec_cls(n=7, shape=(*batch_size, 7))
364364
if reward_spec is None:
365365
reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
366366

test/test_collector.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def _is_consistent_device_type(
149149
_os_is_windows and _python_is_3_10,
150150
reason="Windows Access Violation in torch.multiprocessing / BrokenPipeError in multiprocessing.connection",
151151
)
152-
@pytest.mark.parametrize("num_env", [1, 3])
152+
@pytest.mark.parametrize("num_env", [1, 2])
153153
@pytest.mark.parametrize("device", ["cuda", "cpu", None])
154154
@pytest.mark.parametrize("policy_device", ["cuda", "cpu", None])
155155
@pytest.mark.parametrize("passing_device", ["cuda", "cpu", None])
@@ -233,7 +233,7 @@ def env_fn(seed):
233233
ccollector.shutdown()
234234

235235

236-
@pytest.mark.parametrize("num_env", [1, 3])
236+
@pytest.mark.parametrize("num_env", [1, 2])
237237
@pytest.mark.parametrize("env_name", ["conv", "vec"])
238238
def test_concurrent_collector_consistency(num_env, env_name, seed=40):
239239
if num_env == 1:
@@ -309,7 +309,7 @@ def make_env():
309309
return GymEnv(PONG_VERSIONED, frame_skip=4)
310310

311311
env = SerialEnv(2, make_env)
312-
# env = SerialEnv(3, lambda: GymEnv("CartPole-v1", frame_skip=4))
312+
# env = SerialEnv(2, lambda: GymEnv("CartPole-v1", frame_skip=4))
313313
env.set_seed(0)
314314
collector = SyncDataCollector(
315315
env, total_frames=10000, frames_per_batch=10000, split_trajs=False
@@ -331,7 +331,7 @@ def make_env():
331331
assert _data["reward"].sum(-2).min() == -21
332332

333333

334-
@pytest.mark.parametrize("num_env", [1, 3])
334+
@pytest.mark.parametrize("num_env", [1, 2])
335335
@pytest.mark.parametrize("env_name", ["vec"])
336336
def test_collector_done_persist(num_env, env_name, seed=5):
337337
if num_env == 1:
@@ -381,7 +381,7 @@ def make_env(seed):
381381

382382

383383
@pytest.mark.parametrize("frames_per_batch", [200, 10])
384-
@pytest.mark.parametrize("num_env", [1, 3])
384+
@pytest.mark.parametrize("num_env", [1, 2])
385385
@pytest.mark.parametrize("env_name", ["vec"])
386386
def test_split_trajs(num_env, env_name, frames_per_batch, seed=5):
387387
if num_env == 1:
@@ -475,7 +475,7 @@ def make_env(seed):
475475
# ccollector.shutdown()
476476

477477

478-
@pytest.mark.parametrize("num_env", [1, 3])
478+
@pytest.mark.parametrize("num_env", [1, 2])
479479
@pytest.mark.parametrize("env_name", ["vec", "conv"])
480480
def test_collector_batch_size(num_env, env_name, seed=100):
481481
if num_env == 3 and _os_is_windows:
@@ -498,7 +498,7 @@ def env_fn():
498498

499499
torch.manual_seed(0)
500500
np.random.seed(0)
501-
num_workers = 4
501+
num_workers = 2
502502
frames_per_batch = 20
503503
ccollector = MultiaSyncDataCollector(
504504
create_env_fn=[env_fn for _ in range(num_workers)],
@@ -534,7 +534,7 @@ def env_fn():
534534
ccollector.shutdown()
535535

536536

537-
@pytest.mark.parametrize("num_env", [1, 3])
537+
@pytest.mark.parametrize("num_env", [1, 2])
538538
@pytest.mark.parametrize("env_name", ["vec", "conv"])
539539
def test_concurrent_collector_seed(num_env, env_name, seed=100):
540540
if num_env == 1:
@@ -581,7 +581,7 @@ def env_fn():
581581
ccollector.shutdown()
582582

583583

584-
@pytest.mark.parametrize("num_env", [1, 3])
584+
@pytest.mark.parametrize("num_env", [1, 2])
585585
@pytest.mark.parametrize("env_name", ["conv", "vec"])
586586
def test_collector_consistency(num_env, env_name, seed=100):
587587
if num_env == 1:
@@ -644,7 +644,7 @@ def env_fn(seed):
644644
collector.shutdown()
645645

646646

647-
@pytest.mark.parametrize("num_env", [1, 3])
647+
@pytest.mark.parametrize("num_env", [1, 2])
648648
@pytest.mark.parametrize("collector_class", [SyncDataCollector, aSyncDataCollector])
649649
@pytest.mark.parametrize("env_name", ["conv", "vec"])
650650
def test_traj_len_consistency(num_env, env_name, collector_class, seed=100):
@@ -1100,7 +1100,7 @@ def env_fn(seed):
11001100
],
11011101
)
11021102
class TestAutoWrap:
1103-
num_envs = 3
1103+
num_envs = 2
11041104

11051105
@pytest.fixture
11061106
def env_maker(self):

test/test_env.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from torchrl.envs.gym_like import default_info_dict_reader
4141
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
4242
from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper
43-
from torchrl.envs.transforms import Compose, TransformedEnv
43+
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
4444
from torchrl.envs.utils import step_mdp
4545
from torchrl.envs.vec_env import ParallelEnv, SerialEnv
4646
from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator
@@ -195,6 +195,38 @@ def test_rollout_predictability(device):
195195
).all()
196196

197197

198+
@pytest.mark.skipif(not _has_gym, reason="no gym")
199+
@pytest.mark.parametrize(
200+
"env_name",
201+
[
202+
PENDULUM_VERSIONED,
203+
],
204+
)
205+
@pytest.mark.parametrize(
206+
"frame_skip",
207+
[
208+
1,
209+
],
210+
)
211+
@pytest.mark.parametrize("parallel", [False, True])
212+
def test_rollout_reset(env_name, frame_skip, parallel, seed=0):
213+
envs = []
214+
for horizon in [20, 30, 40]:
215+
envs.append(
216+
lambda horizon=horizon: TransformedEnv(
217+
GymEnv(env_name, frame_skip=frame_skip), StepCounter(horizon)
218+
)
219+
)
220+
if parallel:
221+
env = ParallelEnv(3, envs)
222+
else:
223+
env = SerialEnv(3, envs)
224+
env.set_seed(100)
225+
out = env.rollout(100, break_when_any_done=False)
226+
assert out.shape == torch.Size([3, 100])
227+
assert (out["done"].squeeze().sum(-1) == torch.tensor([5, 3, 2])).all()
228+
229+
198230
class TestModelBasedEnvBase:
199231
@pytest.mark.parametrize("device", get_available_devices())
200232
def test_mb_rollout(self, device, seed=0):
@@ -585,9 +617,9 @@ def test_parallel_env_custom_method(self, parallel):
585617
# define env
586618

587619
if parallel:
588-
env = ParallelEnv(3, lambda: DiscreteActionVecMockEnv())
620+
env = ParallelEnv(2, lambda: DiscreteActionVecMockEnv())
589621
else:
590-
env = SerialEnv(3, lambda: DiscreteActionVecMockEnv())
622+
env = SerialEnv(2, lambda: DiscreteActionVecMockEnv())
591623

592624
# we must start the environment first
593625
env.reset()

test/test_exploration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def test_gsde(
265265
device=device,
266266
)
267267
if gSDE:
268-
gSDENoise().reset(td)
268+
gSDENoise(shape=[batch]).reset(td)
269269
assert "_eps_gSDE" in td.keys()
270270
assert td.get("_eps_gSDE").device == device
271271
actor(td)

0 commit comments

Comments
 (0)