Skip to content

Commit 9248c39

Browse files
author
Vincent Moens
committed
[CI] Fix envnames in SOTA tests
ghstack-source-id: 3b518e2
Pull-Request-resolved: #2921
1 parent 425952b commit 9248c39

19 files changed

+143
-63
lines changed

.github/unittest/linux_sota/scripts/environment.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,3 @@ dependencies:
2929
- coverage
3030
- vmas
3131
- transformers
32-
- gym[atari]
33-
- gym[accept-rom-license]

.github/unittest/linux_sota/scripts/run_all.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ python -c """import gym;import d4rl"""
111111

112112
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
113113
# rename them
114-
pip install "gymnasium[atari]>=1.1.0"
115114

116115
# ============================================================================================ #
117116
# ================================ PyTorch & TorchRL ========================================= #
@@ -128,6 +127,9 @@ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
128127
# submodules
129128
git submodule sync && git submodule update --init --recursive
130129

130+
pip3 install ale-py -U
131+
pip3 install "gym[atari,accept-rom-license]" "gymnasium>=1.1.0" -U
132+
131133
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
132134
if [[ "$TORCH_VERSION" == "nightly" ]]; then
133135
if [ "${CU_VERSION:-}" == cpu ] ; then

.github/unittest/linux_sota/scripts/test_sota.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
collector.frames_per_batch=20 \
4040
collector.num_workers=1 \
4141
logger.backend= \
42+
env.backend=gym \
4243
logger.test_interval=10
4344
""",
4445
"ppo_mujoco": """python sota-implementations/ppo/ppo_mujoco.py \
@@ -56,6 +57,7 @@
5657
loss.mini_batch_size=20 \
5758
loss.ppo_epochs=2 \
5859
logger.backend= \
60+
env.backend=gym \
5961
logger.test_interval=10
6062
""",
6163
"ddpg": """python sota-implementations/ddpg/ddpg.py \
@@ -82,6 +84,7 @@
8284
collector.frames_per_batch=20 \
8385
loss.mini_batch_size=20 \
8486
logger.backend= \
87+
env.backend=gym \
8588
logger.test_interval=40
8689
""",
8790
"dqn_atari": """python sota-implementations/dqn/dqn_atari.py \
@@ -91,6 +94,7 @@
9194
buffer.batch_size=10 \
9295
loss.num_updates=1 \
9396
logger.backend= \
97+
env.backend=gym \
9498
buffer.buffer_size=120
9599
""",
96100
"discrete_cql_online": """python sota-implementations/cql/discrete_cql_online.py \

sota-implementations/a2c/a2c_atari.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def main(cfg: DictConfig): # noqa: F821
4747
test_interval = cfg.logger.test_interval // frame_skip
4848

4949
# Create models (check utils_atari.py)
50-
actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device)
50+
actor, critic, critic_head = make_ppo_models(
51+
cfg.env.env_name, device=device, gym_backend=cfg.env.backend
52+
)
5153
with from_module(actor).data.to("meta").to_module(actor):
5254
actor_eval = deepcopy(actor)
5355
actor_eval.eval()
@@ -107,7 +109,13 @@ def main(cfg: DictConfig): # noqa: F821
107109
)
108110

109111
# Create test environment
110-
test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
112+
test_env = make_parallel_env(
113+
cfg.env.env_name,
114+
num_envs=1,
115+
device=device,
116+
gym_backend=cfg.env.backend,
117+
is_test=True,
118+
)
111119
test_env.set_seed(0)
112120
if cfg.logger.video:
113121
test_env = test_env.insert_transform(
@@ -162,7 +170,12 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
162170

163171
# Create collector
164172
collector = SyncDataCollector(
165-
create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
173+
create_env_fn=make_parallel_env(
174+
cfg.env.env_name,
175+
num_envs=cfg.env.num_envs,
176+
device=device,
177+
gym_backend=cfg.env.backend,
178+
),
166179
policy=actor,
167180
frames_per_batch=frames_per_batch,
168181
total_frames=total_frames,

sota-implementations/a2c/config_atari.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Environment
22
env:
33
env_name: PongNoFrameskip-v4
4+
backend: gymnasium
45
num_envs: 16
56

67
# collector

sota-implementations/a2c/utils_atari.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ParallelEnv,
2222
Resize,
2323
RewardSum,
24+
set_gym_backend,
2425
SignTransform,
2526
StepCounter,
2627
ToTensorImage,
@@ -45,27 +46,35 @@
4546

4647

4748
def make_base_env(
48-
env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
49+
env_name="BreakoutNoFrameskip-v4",
50+
gym_backend="gymnasium",
51+
frame_skip=4,
52+
device="cpu",
53+
is_test=False,
4954
):
50-
env = GymEnv(
51-
env_name,
52-
frame_skip=frame_skip,
53-
from_pixels=True,
54-
pixels_only=False,
55-
device=device,
56-
)
55+
with set_gym_backend(gym_backend):
56+
env = GymEnv(
57+
env_name,
58+
frame_skip=frame_skip,
59+
from_pixels=True,
60+
pixels_only=False,
61+
device=device,
62+
)
5763
env = TransformedEnv(env)
5864
env.append_transform(NoopResetEnv(noops=30, random=True))
5965
if not is_test:
6066
env.append_transform(EndOfLifeTransform())
6167
return env
6268

6369

64-
def make_parallel_env(env_name, num_envs, device, is_test=False):
70+
def make_parallel_env(env_name, num_envs, device, gym_backend, is_test=False):
6571
env = ParallelEnv(
6672
num_envs,
67-
EnvCreator(lambda: make_base_env(env_name)),
73+
EnvCreator(
74+
lambda: make_base_env(env_name, gym_backend=gym_backend, is_test=is_test),
75+
),
6876
serial_for_single=True,
77+
gym_backend=gym_backend,
6978
device=device,
7079
)
7180
env = TransformedEnv(env)
@@ -175,9 +184,11 @@ def make_ppo_modules_pixels(proof_environment, device):
175184
return common_module, policy_module, value_module
176185

177186

178-
def make_ppo_models(env_name, device):
187+
def make_ppo_models(env_name, device, gym_backend):
179188

180-
proof_environment = make_parallel_env(env_name, 1, device="cpu")
189+
proof_environment = make_parallel_env(
190+
env_name, num_envs=1, device="cpu", gym_backend=gym_backend
191+
)
181192
common_module, policy_module, value_module = make_ppo_modules_pixels(
182193
proof_environment, device=device
183194
)

sota-implementations/dqn/config_atari.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ device: null
33
# Environment
44
env:
55
env_name: PongNoFrameskip-v4
6+
backend: gymnasium
67

78
# collector
89
collector:

sota-implementations/dqn/dqn_atari.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ def main(cfg: DictConfig): # noqa: F821
4949
test_interval = cfg.logger.test_interval // frame_skip
5050

5151
# Make the components
52-
model = make_dqn_model(cfg.env.env_name, frame_skip, device=device)
52+
model = make_dqn_model(
53+
cfg.env.env_name,
54+
gym_backend=cfg.env.backend,
55+
frame_skip=frame_skip,
56+
device=device,
57+
)
5358
greedy_module = EGreedyModule(
5459
annealing_num_steps=cfg.collector.annealing_frames,
5560
eps_init=cfg.collector.eps_start,
@@ -114,7 +119,13 @@ def transform(td):
114119
)
115120

116121
# Create the test environment
117-
test_env = make_env(cfg.env.env_name, frame_skip, device, is_test=True)
122+
test_env = make_env(
123+
cfg.env.env_name,
124+
frame_skip,
125+
device,
126+
gym_backend=cfg.env.backend,
127+
is_test=True,
128+
)
118129
if cfg.logger.video:
119130
test_env.insert_transform(
120131
0,
@@ -154,7 +165,9 @@ def update(sampled_tensordict):
154165

155166
# Create the collector
156167
collector = SyncDataCollector(
157-
create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
168+
create_env_fn=make_env(
169+
cfg.env.env_name, frame_skip, device, gym_backend=cfg.env.backend
170+
),
158171
policy=model_explore,
159172
frames_per_batch=frames_per_batch,
160173
total_frames=total_frames,

sota-implementations/dqn/utils_atari.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
NoopResetEnv,
1717
Resize,
1818
RewardSum,
19+
set_gym_backend,
1920
SignTransform,
2021
StepCounter,
2122
ToTensorImage,
@@ -32,15 +33,16 @@
3233
# --------------------------------------------------------------------
3334

3435

35-
def make_env(env_name, frame_skip, device, is_test=False):
36-
env = GymEnv(
37-
env_name,
38-
frame_skip=frame_skip,
39-
from_pixels=True,
40-
pixels_only=False,
41-
device=device,
42-
categorical_action_encoding=True,
43-
)
36+
def make_env(env_name, frame_skip, device, gym_backend, is_test=False):
37+
with set_gym_backend(gym_backend):
38+
env = GymEnv(
39+
env_name,
40+
frame_skip=frame_skip,
41+
from_pixels=True,
42+
pixels_only=False,
43+
device=device,
44+
categorical_action_encoding=True,
45+
)
4446
env = TransformedEnv(env)
4547
env.append_transform(NoopResetEnv(noops=30, random=True))
4648
if not is_test:
@@ -94,8 +96,10 @@ def make_dqn_modules_pixels(proof_environment, device):
9496
return qvalue_module
9597

9698

97-
def make_dqn_model(env_name, frame_skip, device):
98-
proof_environment = make_env(env_name, frame_skip, device=device)
99+
def make_dqn_model(env_name, gym_backend, frame_skip, device):
100+
proof_environment = make_env(
101+
env_name, frame_skip, gym_backend=gym_backend, device=device
102+
)
99103
qvalue_module = make_dqn_modules_pixels(proof_environment, device=device)
100104
del proof_environment
101105
return qvalue_module

sota-implementations/impala/config_multi_node_ray.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Environment
22
env:
33
env_name: PongNoFrameskip-v4
4+
backend: gymnasium
45

56
# Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html
67
ray_init_config:

0 commit comments

Comments
 (0)